├── .codecov.yml ├── .github └── workflows │ ├── dev.yml │ └── release.yml ├── .gitignore ├── .golangci.yml ├── .pre-commit-config.yaml ├── .release.yml ├── LICENSE ├── README.md ├── buf.gen.yaml ├── doc.go ├── errors.go ├── expr ├── binding_test.go ├── builder.go ├── builder_test.go ├── decimal_util.go ├── decimal_util_test.go ├── expression.go ├── expressions_test.go ├── field_reference.go ├── functions.go ├── interval_compound_literal.go ├── interval_compound_literal_test.go ├── interval_year_to_month.go ├── interval_year_to_month_test.go ├── literals.go ├── literals_test.go ├── proto_literals_test.go ├── string_test.go ├── testdata │ ├── expressions.yaml │ └── extended_exprs.yaml └── utils.go ├── extensions ├── extension_mgr.go ├── extension_mgr_test.go ├── simple_extension.go ├── simple_extension_test.go ├── variants.go └── variants_test.go ├── functions ├── dialect.go ├── dialect_test.go ├── functions.go ├── local_functions.go ├── local_functions_test.go ├── registries.go ├── types.go └── types_test.go ├── go.mod ├── go.sum ├── grammar └── generate.go ├── literal ├── utils.go └── utils_test.go ├── plan ├── builders.go ├── common.go ├── ctas_plan_test.go ├── internal │ ├── helper.go │ └── helper_test.go ├── named_write_plan_test.go ├── plan.go ├── plan_builder_test.go ├── plan_test.go ├── relations.go ├── relations_test.go ├── testdata │ ├── ctas_basic.json │ ├── ctas_with_filter.json │ ├── delete_with_filter.json │ ├── insert_from_select.json │ ├── value_with_literal.json │ └── value_with_scalar.json └── virtual_table_from_expr_test.go ├── testcases └── parser │ ├── baseparser │ ├── functestcase_lexer.go │ ├── functestcase_parser.go │ ├── functestcaseparser_base_visitor.go │ └── functestcaseparser_visitor.go │ ├── nodes.go │ ├── parse.go │ ├── parse_test.go │ └── visitor.go └── types ├── any_type.go ├── any_type_test.go ├── integer_parameters ├── concrete_int_param.go ├── integer_parameter_type.go ├── integer_parameter_type_test.go └── variable_int_param.go 
├── interval_compound_type.go ├── interval_compound_type_test.go ├── interval_day_type.go ├── interval_day_type_test.go ├── interval_year_month_type.go ├── interval_year_month_type_test.go ├── parameterized_decimal_type.go ├── parameterized_decimal_type_test.go ├── parameterized_list_type.go ├── parameterized_list_type_test.go ├── parameterized_map_type.go ├── parameterized_map_type_test.go ├── parameterized_single_integer_param_type.go ├── parameterized_single_integer_param_type_test.go ├── parameterized_struct_type.go ├── parameterized_struct_type_test.go ├── parameterized_user_defined_type.go ├── parameterized_user_defined_type_test.go ├── parser ├── baseparser │ ├── README.md │ ├── substrait_lexer.go │ ├── substraittype_base_listener.go │ ├── substraittype_base_visitor.go │ ├── substraittype_lexer.go │ ├── substraittype_listener.go │ ├── substraittype_parser.go │ └── substraittype_visitor.go ├── parse.go ├── parse_test.go ├── util │ └── error_listener.go └── visitor.go ├── precison_timestamp_types.go ├── precison_timestamp_types_test.go ├── type_derivation.go ├── type_derivation_test.go ├── types.go └── types_test.go /.codecov.yml: -------------------------------------------------------------------------------- 1 | # .codecov.yml 2 | coverage: 3 | ignore: 4 | - "types/parser/baseparser/*.go" 5 | - "testcases/parser/baseparser/*.go" 6 | -------------------------------------------------------------------------------- /.github/workflows/dev.yml: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | name: PR Build Check 4 | 5 | on: [push, pull_request] 6 | 7 | permissions: 8 | contents: read 9 | 10 | jobs: 11 | golangci: 12 | name: Code Linting (ubuntu-latest) 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v4 16 | - uses: actions/setup-go@v5 17 | with: 18 | go-version-file: 'go.mod' 19 | - name: golangci-lint 20 | uses: golangci/golangci-lint-action@v6 21 | with: 22 
| version: v1.60.1 23 | build: 24 | name: Build and Test (${{ matrix.os }}) 25 | runs-on: ${{ matrix.os }} 26 | strategy: 27 | matrix: 28 | os: [ ubuntu-latest, windows-latest, macos-latest ] 29 | steps: 30 | - name: Checkout 31 | uses: actions/checkout@v3 32 | with: 33 | fetch-depth: 0 34 | - name: Install go 35 | uses: actions/setup-go@v3 36 | with: 37 | go-version: '>=1.18' 38 | cache: true 39 | - name: Build 40 | run: go build ./... 41 | - name: Run Tests 42 | if: runner.os != 'Linux' 43 | run: go test -v ./... 44 | - name: Run Tests with Coverage 45 | if: runner.os == 'Linux' 46 | run: go test -v -coverprofile=coverage.out $(go list ./... | grep -v /proto) 47 | - name: Upload coverage to Codecov 48 | if: runner.os == 'Linux' && github.repository == 'substrait-io/substrait-go' 49 | uses: codecov/codecov-action@v4 50 | with: 51 | token: ${{ secrets.CODECOV_TOKEN }} 52 | disable_search: true 53 | file: ./coverage.out 54 | fail_ci_if_error: true 55 | codecov_yml_path: .codecov.yml 56 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | pull_request: 5 | schedule: 6 | # 2 AM on Sunday 7 | - cron: "0 2 * * 0" 8 | workflow_dispatch: 9 | 10 | # we do not want more than one release workflow executing at the same time, ever 11 | concurrency: 12 | group: release 13 | # cancelling in the middle of a release would create incomplete releases 14 | # so cancel-in-progress is false 15 | cancel-in-progress: false 16 | 17 | permissions: 18 | contents: write 19 | checks: write 20 | id-token: write 21 | statuses: write 22 | 23 | jobs: 24 | release: 25 | if: github.repository == 'substrait-io/substrait-go' 26 | runs-on: ubuntu-latest 27 | steps: 28 | - name: Checkout Code 29 | uses: actions/checkout@v3 30 | with: 31 | fetch-depth: 0 32 | - name: Run go-semantic-release 33 | env: 34 | GITHUB_TOKEN: ${{ 
secrets.GITHUB_TOKEN }} 35 | run: | 36 | wget https://github.com/Nightapes/go-semantic-release/releases/download/v2.1.1/go-semantic-release.linux_x86_64.zip 37 | unzip go-semantic-release.linux_x86_64.zip 38 | ./go-semantic-release.linux_x86_64 release -l trace 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # If you prefer the allow list template instead of the deny list, see community template: 2 | # https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore 3 | # 4 | # Binaries for programs and plugins 5 | *.exe 6 | *.exe~ 7 | *.dll 8 | *.so 9 | *.dylib 10 | 11 | # Test binary, built with `go test -c` 12 | *.test 13 | 14 | # Output of the go coverage tool, specifically when used with LiteIDE 15 | *.out 16 | 17 | # Dependency directories (remove the comment below to include it) 18 | # vendor/ 19 | 20 | # Go workspace file 21 | go.work 22 | go.work.sum 23 | 24 | # env file 25 | .env 26 | 27 | .idea 28 | /types/parser/baseparser/*.interp 29 | /types/parser/baseparser/*.tokens 30 | /testcases/parser/baseparser/*.interp 31 | /testcases/parser/baseparser/*.tokens 32 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | issues: 2 | exclude-dirs: 3 | - proto 4 | linters-settings: 5 | staticcheck: 6 | checks: 7 | - "-SA1019" #disabled until deprecation warnings are fixed. 
8 | gci: 9 | custom-order: false 10 | linters: 11 | enable: 12 | - gci 13 | 14 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/golangci/golangci-lint/ 3 | rev: v1.52.2 4 | hooks: 5 | - id: golangci-lint 6 | args: 7 | - --fix 8 | - --issues-exit-code=1 9 | - --config=.golangci.yml 10 | -------------------------------------------------------------------------------- /.release.yml: -------------------------------------------------------------------------------- 1 | commitFormat: conventional 2 | branch: 3 | main: release 4 | release: 'github' 5 | github: 6 | repo: "substrait-go" 7 | user: "substrait-io" 8 | changelog: 9 | printAll: true 10 | title: "v{{.Version}}" 11 | showAuthors: false 12 | 13 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # substrait-go 2 | 3 | Experimental Go bindings for [substrait](https://substrait.io) 4 | 5 | [![release status](https://github.com/substrait-io/substrait-go/actions/workflows/release.yml/badge.svg)](https://github.com/substrait-io/substrait-go/actions/workflows/release.yml) 6 | [![codecov](https://codecov.io/gh/substrait-io/substrait-go/branch/main/graph/badge.svg?token=7YXPNM3AMJ)](https://codecov.io/gh/substrait-io/substrait-go) 7 | ## Note: 8 | 9 | This is work in progress still, things still to do: 10 | 11 | - [ ] Expression parsing 12 | - [x] Reading in extension yamls 13 | - [x] CI building and testing the implementation 14 | - [ ] Serialization/Deserialization of some expression types: 15 | - [x] IfThen 16 | - [x] SwitchExpression 17 | - [x] SingularOrList 18 | - [x] MultiOrList 19 | - [x] Cast 20 | - [x] Nested 21 | - [ ] Subquery 22 | - [ ] Serialization/Deserialization of Plan and Relations 23 | - [x] Plan 24 
| - [x] PlanRel 25 | - [x] Rel 26 | - [x] ReadRel 27 | - [x] FilterRel 28 | - [x] FetchRel 29 | - [x] AggregateRel 30 | - [x] SortRel 31 | - [x] JoinRel 32 | - [x] ProjectRel 33 | - [x] SetRel 34 | - [x] ExtensionSingleRel 35 | - [x] ExtensionMultiRel 36 | - [x] ExtensionLeafRel 37 | - [x] CrossRel 38 | - [x] HashJoinRel 39 | - [x] MergeJoinRel 40 | - [ ] DdlRel 41 | - [ ] WriteRel 42 | - [ ] ExchangeRel 43 | - [x] Plan Building helpers 44 | - [ ] ReadRel 45 | - [x] NamedScanReadRel 46 | - [x] VirtualTableReadRel 47 | - [ ] ExtensionTableReadRel 48 | - [ ] LocalFileReadRel 49 | - [x] FilterRel 50 | - [x] FetchRel 51 | - [x] AggregateRel 52 | - [x] SortRel 53 | - [x] JoinRel 54 | - [x] ProjectRel 55 | - [x] SetRel 56 | - [x] CrossRel 57 | - [ ] HashJoinRel 58 | - [ ] MergeJoinRel 59 | - [ ] DdlRel 60 | - [ ] WriteRel 61 | - [ ] ExchangeRel 62 | 63 | As this is built out, you can expect refactors and other changes to the 64 | structure of the package for the time being. **The API should not yet be 65 | considered stable.** 66 | 67 | ## Generate from proto files 68 | 69 | ### Install buf 70 | 71 | First ensure you have `buf` installed by following https://docs.buf.build/installation. 72 | 73 | ### Install the Go plugin 74 | 75 | Run the following to install the Go plugin for protobuf: 76 | 77 | ```bash 78 | $ go install google.golang.org/protobuf/cmd/protoc-gen-go@latest 79 | ``` 80 | 81 | Ensure that the `bin` directory of your GOPATH is on your `PATH`: 82 | 83 | ```bash 84 | $ export PATH="$PATH:$(go env GOPATH)/bin" 85 | ``` 86 | 87 | ### Run go generate 88 | 89 | As long as buf and the Go protobuf plugin are installed, you can 90 | simply run `go generate` to generate the updated `.pb.go` files. It 91 | will generate them by referencing the primary substrait-io repository. 92 | 93 | You can then commit the updated files. 
// Sentinel errors shared across the module.
//
// NOTE(review): messages elsewhere in this repository carry these texts as a
// prefix (e.g. "invalid expression: ..."), so the values are presumably
// wrapped by callers; confirm before relying on errors.Is versus ==.
var (
	// ErrNotImplemented carries the message "not implemented".
	ErrNotImplemented    = errors.New("not implemented")
	// ErrInvalidType carries the message "invalid type".
	ErrInvalidType       = errors.New("invalid type")
	// ErrInvalidExpr carries the message "invalid expression".
	ErrInvalidExpr       = errors.New("invalid expression")
	// ErrNotFound carries the message "not found".
	ErrNotFound          = errors.New("not found")
	// ErrKeyExists carries the message "key already exists".
	ErrKeyExists         = errors.New("key already exists")
	// ErrInvalidRel carries the message "invalid relation".
	ErrInvalidRel        = errors.New("invalid relation")
	// ErrInvalidArg carries the message "invalid argument".
	ErrInvalidArg        = errors.New("invalid argument")
	// ErrInvalidInputCount carries the message "invalid input count".
	ErrInvalidInputCount = errors.New("invalid input count")
	// ErrInvalidDialect carries the message "invalid dialect".
	ErrInvalidDialect    = errors.New("invalid dialect")
)
"github.com/substrait-io/substrait-go/v4/expr" 10 | "github.com/substrait-io/substrait-go/v4/extensions" 11 | "github.com/substrait-io/substrait-go/v4/types" 12 | ) 13 | 14 | var ( 15 | extReg = NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError()) 16 | uPointRef = extReg.GetTypeAnchor(extensions.ID{ 17 | URI: extensions.SubstraitDefaultURIPrefix + "extension_types.yaml", 18 | Name: "point", 19 | }) 20 | 21 | subID = extensions.ID{ 22 | URI: extensions.SubstraitDefaultURIPrefix + "functions_arithmetic.yaml", 23 | Name: "subtract"} 24 | addID = extensions.ID{ 25 | URI: extensions.SubstraitDefaultURIPrefix + "functions_arithmetic.yaml", 26 | Name: "add"} 27 | indexInID = extensions.ID{ 28 | URI: extensions.SubstraitDefaultURIPrefix + "functions_set.yaml", 29 | Name: "index_in"} 30 | rankID = extensions.ID{ 31 | URI: extensions.SubstraitDefaultURIPrefix + "functions_arithmetic.yaml", 32 | Name: "rank"} 33 | extractID = extensions.ID{ 34 | URI: extensions.SubstraitDefaultURIPrefix + "functions_datetime.yaml", 35 | Name: "extract"} 36 | ntileID = extensions.ID{ 37 | URI: extensions.SubstraitDefaultURIPrefix + "functions_arithmetic.yaml", 38 | Name: "ntile"} 39 | 40 | boringSchema = types.NamedStruct{ 41 | Names: []string{ 42 | "bool", "i8", "i32", "i32_req", 43 | "point", "i64", "f32", "f32_req", 44 | "f64", "date_req", "str", "bin"}, 45 | Struct: types.StructType{ 46 | Nullability: types.NullabilityRequired, 47 | Types: []types.Type{ 48 | &types.BooleanType{}, 49 | &types.Int8Type{}, 50 | &types.Int32Type{}, 51 | &types.Int32Type{Nullability: types.NullabilityRequired}, 52 | &types.UserDefinedType{ 53 | TypeReference: uPointRef, 54 | }, 55 | &types.Int64Type{}, 56 | &types.Float32Type{}, 57 | &types.Float32Type{Nullability: types.NullabilityRequired}, 58 | &types.Float64Type{}, 59 | &types.DateType{Nullability: types.NullabilityRequired}, 60 | &types.StringType{}, 61 | &types.BinaryType{}, 62 | }, 63 | }, 64 | } 65 | ) 66 | 67 | func 
// TestBoundExpressions builds a table of expressions (literals, field
// references, scalar and window function calls) and asserts that GetType on
// each one equals the expected output type.
func TestBoundExpressions(t *testing.T) {
	tests := []struct {
		ex           Expression
		initialBound bool // populated in the table but not read by the assertions below
		outputType   types.Type
	}{
		{NewPrimitiveLiteral(int32(1), true), true,
			&types.Int32Type{Nullability: types.NullabilityNullable}},
		// Field 10 of boringSchema is "str".
		{MustExpr(NewRootFieldRef(NewStructFieldRef(10), types.NewRecordTypeFromStruct(boringSchema.Struct))), false,
			&types.StringType{}},
		// Same struct-field reference, but rooted directly on a type instead of the schema.
		{MustExpr(NewRootFieldRefFromType(
			NewStructFieldRef(10), &types.StringType{})), false,
			&types.StringType{}},
		{MustExpr(NewScalarFunc(extReg, subID, nil,
			NewPrimitiveLiteral(int8(1), false),
			NewPrimitiveLiteral(int8(5), false))), false,
			&types.Int8Type{Nullability: types.NullabilityRequired}},
		// Field 1 of boringSchema is "i8".
		{MustExpr(NewScalarFunc(extReg, addID, nil,
			NewPrimitiveLiteral(int8(1), false),
			MustExpr(NewRootFieldRef(NewStructFieldRef(1), types.NewRecordTypeFromStruct(boringSchema.Struct))))), false,
			&types.Int8Type{Nullability: types.NullabilityNullable}},
		// Fields 2 and 3 of boringSchema are "i32" and "i32_req".
		{MustExpr(NewScalarFunc(extReg, indexInID, nil, MustExpr(NewRootFieldRef(NewStructFieldRef(2), types.NewRecordTypeFromStruct(boringSchema.Struct))),
			NewListExpr(false, MustExpr(NewRootFieldRef(NewStructFieldRef(3), types.NewRecordTypeFromStruct(boringSchema.Struct))),
				NewPrimitiveLiteral(int32(10), true)))), false,
			&types.Int64Type{Nullability: types.NullabilityNullable}},
		{MustExpr(NewWindowFunc(extReg, rankID, nil, types.AggInvocationAll, types.AggPhaseInitialToResult)),
			false, &types.Int64Type{Nullability: types.NullabilityNullable}},
		// Field 9 of boringSchema is "date_req".
		{MustExpr(NewScalarFunc(extReg, extractID, nil, types.Enum("YEAR"),
			MustExpr(NewRootFieldRef(NewStructFieldRef(9), types.NewRecordTypeFromStruct(boringSchema.Struct))))), false,
			&types.Int64Type{Nullability: types.NullabilityRequired}},
	}

	for _, tt := range tests {
		// The expression's string form doubles as the subtest name.
		t.Run(tt.ex.String(), func(t *testing.T) {
			assert.Truef(t, tt.outputType.Equals(tt.ex.GetType()), "expected: %s\ngot: %s", tt.outputType, tt.ex.GetType())
		})
	}
}
got 0, expected 2"}, 42 | {"with opt", "index_in(i32(5), list([]), {nan_equality: [NAN_IS_NAN]}) => i64?", 43 | b.ScalarFunc(indexInID, &types.FunctionOption{ 44 | Name: "nan_equality", 45 | Preference: []string{"NAN_IS_NAN"}}).Args( 46 | b.Wrap(expr.NewLiteral(int32(5), false)), 47 | b.Literal(expr.NewEmptyListLiteral(&types.Int32Type{}, false))), ""}, 48 | {"with cast", "subtract(.field(3) => i32, cast(.field(6) => fp32 AS i32, fail: FAILURE_BEHAVIOR_THROW_EXCEPTION)) => i32?", 49 | b.ScalarFunc(subID).Args( 50 | b.RootRef(expr.NewStructFieldRef(3)), 51 | b.Cast(b.RootRef(expr.NewStructFieldRef(6)), &types.Int32Type{}). 52 | FailBehavior(types.BehaviorThrowException), 53 | ), ""}, 54 | {"expression with lit", "subtract(.field(3) => i32, i32(3)) => i32", 55 | b.ScalarFunc(subID).Args(b.RootRef(expr.NewStructFieldRef(3)), 56 | b.Expression(precomputedLiteral)), ""}, 57 | {"expression with expr", "subtract(.field(3) => i32, add(i32(3), i32(3)) => i32) => i32", 58 | b.ScalarFunc(subID).Args(b.RootRef(expr.NewStructFieldRef(3)), 59 | b.Expression(precomputedExpression)), ""}, 60 | {"wrap expression", "subtract(.field(3) => i32, i32(3)) => i32", 61 | b.ScalarFunc(subID).Args(b.RootRef(expr.NewStructFieldRef(3)), 62 | b.Wrap(expr.NewLiteral(int32(3), false))), ""}, 63 | {"window func", "", 64 | b.WindowFunc(rankID), "invalid expression: non-decomposable window or agg function '{https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml rank}' must use InitialToResult phase"}, 65 | {"window func", "rank(; phase: AGGREGATION_PHASE_INITIAL_TO_RESULT, invocation: AGGREGATION_INVOCATION_UNSPECIFIED) => i64?", 66 | b.WindowFunc(rankID).Phase(types.AggPhaseInitialToResult), ""}, 67 | {"nested funcs", "add(extract(YEAR, date(2000-01-01)) => i64, rank(; phase: AGGREGATION_PHASE_INITIAL_TO_RESULT, invocation: AGGREGATION_INVOCATION_ALL) => i64?) 
=> i64?", 68 | b.ScalarFunc(addID).Args( 69 | b.ScalarFunc(extractID).Args(b.Enum("YEAR"), 70 | b.Wrap(expr.NewLiteral(types.Date(10957), false))), 71 | b.WindowFunc(rankID).Phase(types.AggPhaseInitialToResult).Invocation(types.AggInvocationAll)), ""}, 72 | {"nested propagate error", "", 73 | b.ScalarFunc(addID).Args( 74 | b.RootRef(expr.NewListElemRef(0)), 75 | b.Literal(expr.NewPrimitiveLiteral(int32(5), false))), "error resolving ref type: invalid type"}, 76 | {"window func args", "ntile(i32(5); sort: [{expr: .field(1) => i8, SORT_DIRECTION_ASC_NULLS_FIRST}]; phase: AGGREGATION_PHASE_INITIAL_TO_RESULT, invocation: AGGREGATION_INVOCATION_UNSPECIFIED) => i32?", 77 | b.WindowFunc(ntileID).Args(b.Wrap(expr.NewLiteral(int32(5), false))). 78 | Phase(types.AggPhaseInitialToResult). 79 | Sort(expr.SortField{ 80 | Expr: expr.MustExpr(b.RootRef(expr.NewStructFieldRef(1)).Build()), 81 | Kind: types.SortAscNullsFirst}), ""}, 82 | {"window func arg error", "", 83 | b.WindowFunc(ntileID).Args(b.ScalarFunc(extensions.ID{})), 84 | "not found: could not find matching function for id: { :}"}, 85 | } 86 | 87 | for _, tt := range tests { 88 | t.Run(tt.name, func(t *testing.T) { 89 | e, err := tt.ex.BuildExpr() 90 | if tt.err == "" { 91 | require.NoError(t, err) 92 | assert.Equal(t, tt.expected, e.String()) 93 | // Also test that converting to proto does not panic. 94 | e.ToProto() 95 | } else { 96 | assert.EqualError(t, err, tt.err) 97 | } 98 | }) 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /expr/decimal_util.go: -------------------------------------------------------------------------------- 1 | package expr 2 | 3 | import ( 4 | "fmt" 5 | "math/big" 6 | "regexp" 7 | "strings" 8 | 9 | "github.com/cockroachdb/apd/v3" 10 | ) 11 | 12 | var decimalPattern = regexp.MustCompile(`^[+-]?\d*(\.\d*)?([eE][+-]?\d*)?$`) 13 | 14 | // DecimalStringToBytes converts a decimal string to a 16-byte byte array. 
15 | // 16-byte bytes represents a little-endian 128-bit integer, to be divided by 10^Scale to get the decimal value. 16 | // This function also returns the precision and scale of the decimal value. 17 | // The precision is the total number of digits in the decimal value. The precision is limited to 38 digits. 18 | // The scale is the number of digits to the right of the decimal point. The scale is limited to the precision. 19 | func DecimalStringToBytes(decimalStr string) ([16]byte, int32, int32, error) { 20 | var result [16]byte 21 | 22 | strings.Trim(decimalStr, " ") 23 | if !decimalPattern.MatchString(decimalStr) { 24 | return result, 0, 0, fmt.Errorf("invalid decimal string") 25 | } 26 | 27 | // Parse the decimal string using apd 28 | dec, cond, err := apd.NewFromString(decimalStr) 29 | if err != nil || cond.Any() { 30 | return result, 0, 0, fmt.Errorf("invalid decimal string %s: %v", decimalStr, err) 31 | } 32 | 33 | return DecimalToBytes(dec) 34 | } 35 | 36 | // DecimalToBytes converts apd.Decimal to a 16-byte byte array. 37 | // 16-byte bytes represents a little-endian 128-bit integer, to be divided by 10^Scale to get the decimal value. 38 | // This function also returns the precision and scale of the decimal value. 39 | func DecimalToBytes(dec *apd.Decimal) ([16]byte, int32, int32, error) { 40 | var ( 41 | result [16]byte 42 | precision int32 43 | scale int32 44 | ) 45 | 46 | if dec.Exponent > 0 { 47 | precision = int32(apd.NumDigits(&dec.Coeff)) + dec.Exponent 48 | scale = 0 49 | } else { 50 | scale = -dec.Exponent 51 | precision = max(int32(apd.NumDigits(&dec.Coeff)), scale+1) 52 | } 53 | if precision > 38 { 54 | return result, precision, scale, fmt.Errorf("number %s exceeds maximum precision of 38 (%d)", dec.String(), precision) 55 | } 56 | 57 | coefficient := dec.Coeff 58 | if dec.Exponent > 0 { 59 | // Multiply coefficient by 10^exponent. 
// negate flips the sign of a big-endian two's-complement value in place.
//
// It inverts every bit and then adds one, propagating the carry from the
// last (least significant) byte towards the first.
func negate(bytes []byte) {
	for idx, b := range bytes {
		bytes[idx] = ^b
	}
	for idx := len(bytes) - 1; idx >= 0; idx-- {
		bytes[idx]++
		// A byte that did not wrap around to zero absorbed the carry.
		if bytes[idx] != 0 {
			return
		}
	}
}
110 | intValue := new(big.Int).SetBytes(processingValue[:]) 111 | if isNegative { 112 | intValue.Neg(intValue) 113 | } 114 | apdBigInt := new(apd.BigInt).SetMathBigInt(intValue) 115 | return apd.NewWithBigInt(apdBigInt, -scale).String() 116 | } 117 | 118 | func modifyDecimalPrecisionAndScale(decimalBytes [16]byte, scale, targetPrecision, targetScale int32) ([16]byte, int32, int32, error) { 119 | var result [16]byte 120 | if targetPrecision > 38 { 121 | return result, 0, 0, fmt.Errorf("target precision %d exceeds maximum allowed precision of 38", targetPrecision) 122 | } 123 | 124 | isNegative := decimalBytes[15]&0x80 != 0 125 | 126 | // Reverse the byte array to convert from little-endian to big-endian. 127 | processingValue := make([]byte, 16) 128 | for i := 0; i < 16; i++ { 129 | processingValue[i] = decimalBytes[15-i] 130 | } 131 | if isNegative { 132 | negate(processingValue[:]) 133 | } 134 | 135 | // Convert the bytes into a big.Int and wrap it into an apd.Decimal. 136 | intValue := new(big.Int).SetBytes(processingValue[:]) 137 | apdBigInt := new(apd.BigInt).SetMathBigInt(intValue) 138 | dec := apd.NewWithBigInt(apdBigInt, -scale) 139 | 140 | // Normalize the decimal by removing trailing zeros. 141 | dec.Reduce(dec) 142 | 143 | err2 := validatePrecisionAndScale(dec, targetPrecision, targetScale) 144 | if err2 != nil { 145 | return result, 0, 0, err2 146 | } 147 | 148 | // After ensuring the targetScale is sufficient for dec, adjust the scale to the target scale 149 | ctx := apd.BaseContext.WithPrecision(uint32(targetPrecision)) 150 | _, err := ctx.Quantize(dec, dec, -targetScale) 151 | if err != nil { 152 | return result, 0, 0, fmt.Errorf("error adjusting scale: %v", err) 153 | } 154 | 155 | // Convert the adjusted decimal coefficient to a byte array. 
156 | byteArray := dec.Coeff.Bytes() 157 | if len(byteArray) > 16 { 158 | return result, 0, 0, fmt.Errorf("number exceeds 16 bytes") 159 | } 160 | copy(result[16-len(byteArray):], byteArray) 161 | 162 | // Handle the sign by applying two's complement for negative numbers. 163 | if isNegative { 164 | negate(result[:]) 165 | } 166 | 167 | // Reverse the byte array back to little-endian. 168 | for i, j := 0, 15; i < j; i, j = i+1, j-1 { 169 | result[i], result[j] = result[j], result[i] 170 | } 171 | 172 | return result, targetPrecision, targetScale, nil 173 | } 174 | 175 | func validatePrecisionAndScale(dec *apd.Decimal, targetPrecision int32, targetScale int32) error { 176 | // Validate the minimum precision and scale. 177 | minPrecision, minScale := getMinimumPrecisionAndScale(dec) 178 | if targetScale < minScale { 179 | return fmt.Errorf("number %v exceeds target scale %d, minimum scale needed is %d", dec.String(), targetScale, minScale) 180 | } 181 | if targetPrecision < minPrecision { 182 | return fmt.Errorf("number %s exceeds target precision %d, minimum precision needed is %d with target scale %d", dec.String(), targetPrecision, minPrecision, targetScale) 183 | } 184 | if targetPrecision-targetScale < minPrecision-minScale { 185 | return fmt.Errorf("number %v exceeds target precision %d with target scale %d, minimum precision needed is %d with minimum scale %d", dec.String(), targetPrecision, targetScale, minPrecision, minScale) 186 | } 187 | return nil 188 | } 189 | 190 | func getMinimumPrecisionAndScale(dec *apd.Decimal) (precision int32, scale int32) { 191 | if dec.Exponent > 0 { 192 | precision = int32(apd.NumDigits(&dec.Coeff)) + dec.Exponent 193 | scale = 0 194 | } else { 195 | scale = -dec.Exponent 196 | precision = max(int32(apd.NumDigits(&dec.Coeff)), scale) 197 | } 198 | return precision, scale 199 | } 200 | -------------------------------------------------------------------------------- /expr/interval_compound_literal.go: 
-------------------------------------------------------------------------------- 1 | package expr 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | 7 | "github.com/substrait-io/substrait-go/v4/types" 8 | proto "github.com/substrait-io/substrait-protobuf/go/substraitpb" 9 | ) 10 | 11 | // IntervalCompoundLiteral creates an interval compound literal 12 | type IntervalCompoundLiteral struct { 13 | Years int32 14 | Months int32 15 | Days int32 16 | Seconds int32 17 | SubSeconds int64 18 | SubSecondPrecision types.TimePrecision 19 | Nullability types.Nullability 20 | } 21 | 22 | func (m IntervalCompoundLiteral) getType() types.Type { 23 | return types.NewIntervalCompoundType().WithPrecision(m.SubSecondPrecision).WithNullability(m.Nullability) 24 | } 25 | 26 | func (m IntervalCompoundLiteral) ToProtoLiteral() *proto.Expression_Literal { 27 | t := m.getType() 28 | intrCompPB := &proto.Expression_Literal_IntervalCompound{} 29 | 30 | if m.Years != 0 || m.Months != 0 { 31 | yearToMonthProto := &proto.Expression_Literal_IntervalYearToMonth{ 32 | Years: m.Years, 33 | Months: m.Months, 34 | } 35 | intrCompPB.IntervalYearToMonth = yearToMonthProto 36 | } 37 | 38 | if m.Days != 0 || m.Seconds != 0 || m.SubSeconds != 0 { 39 | dayToSecondProto := &proto.Expression_Literal_IntervalDayToSecond{ 40 | Days: m.Days, 41 | Seconds: m.Seconds, 42 | PrecisionMode: &proto.Expression_Literal_IntervalDayToSecond_Precision{Precision: m.SubSecondPrecision.ToProtoVal()}, 43 | Subseconds: m.SubSeconds, 44 | } 45 | intrCompPB.IntervalDayToSecond = dayToSecondProto 46 | } 47 | 48 | return &proto.Expression_Literal{ 49 | LiteralType: &proto.Expression_Literal_IntervalCompound_{IntervalCompound: intrCompPB}, 50 | Nullable: t.GetNullability() == types.NullabilityNullable, 51 | TypeVariationReference: t.GetTypeVariationReference(), 52 | } 53 | } 54 | 55 | func (m IntervalCompoundLiteral) ToProto() *proto.Expression { 56 | return &proto.Expression{RexType: &proto.Expression_Literal_{ 57 | Literal: 
m.ToProtoLiteral(), 58 | }} 59 | } 60 | 61 | func intervalCompoundLiteralFromProto(l *proto.Expression_Literal) Literal { 62 | icLiteral := IntervalCompoundLiteral{Nullability: getNullability(l.Nullable)} 63 | yearToMonth := l.GetIntervalCompound().GetIntervalYearToMonth() 64 | if yearToMonth != nil { 65 | icLiteral.Years = yearToMonth.Years 66 | icLiteral.Months = yearToMonth.Months 67 | } 68 | dayToSecond := l.GetIntervalCompound().GetIntervalDayToSecond() 69 | if dayToSecond == nil { 70 | // no day to second part 71 | return icLiteral 72 | } 73 | err := validateIntervalDayToSecondProto(dayToSecond) 74 | if err != nil { 75 | return nil 76 | } 77 | 78 | // get subSecond/precision value from proto. To get value it takes care of deprecated microseconds 79 | precision, subSeconds, err := intervalCompoundPrecisionSubSecondsFromProto(dayToSecond) 80 | if err != nil { 81 | return nil 82 | } 83 | icLiteral.Days = dayToSecond.Days 84 | icLiteral.Seconds = dayToSecond.Seconds 85 | icLiteral.SubSeconds = subSeconds 86 | icLiteral.SubSecondPrecision = precision 87 | return icLiteral 88 | } 89 | 90 | func (IntervalCompoundLiteral) isRootRef() {} 91 | func (m IntervalCompoundLiteral) GetType() types.Type { return m.getType() } 92 | func (m IntervalCompoundLiteral) String() string { 93 | return fmt.Sprintf("%s(%s)", m.getType(), m.ValueString()) 94 | } 95 | func (m IntervalCompoundLiteral) ValueString() string { 96 | return fmt.Sprintf("%d years, %d months, %d days, %d seconds, %d subseconds", 97 | m.Years, m.Months, m.Days, m.Seconds, m.SubSeconds) 98 | } 99 | func (m IntervalCompoundLiteral) Equals(rhs Expression) bool { 100 | if other, ok := rhs.(IntervalCompoundLiteral); ok { 101 | return m.getType().Equals(other.GetType()) && (m == other) 102 | } 103 | return false 104 | } 105 | 106 | func (m IntervalCompoundLiteral) ToProtoFuncArg() *proto.FunctionArgument { 107 | return &proto.FunctionArgument{ 108 | ArgType: &proto.FunctionArgument_Value{Value: m.ToProto()}, 109 | } 110 
| } 111 | 112 | func (m IntervalCompoundLiteral) Visit(VisitFunc) Expression { return m } 113 | func (IntervalCompoundLiteral) IsScalar() bool { return true } 114 | // validateIntervalDayToSecondProto checks that a precision mode is set and that the deprecated microseconds mode is not combined with a non-zero subseconds value. 115 | func validateIntervalDayToSecondProto(idts *proto.Expression_Literal_IntervalDayToSecond) error { 116 | if idts.PrecisionMode == nil { 117 | // error, precision mode must be set for intervalCompound 118 | return errors.New("missing precision mode for interval compound") 119 | } 120 | if _, ok := idts.PrecisionMode.(*proto.Expression_Literal_IntervalDayToSecond_Microseconds); ok { 121 | // with the deprecated microseconds mode, subseconds must be exactly zero; the field is signed, so negative values are rejected too 122 | if idts.Subseconds != 0 { 123 | return errors.New("both deprecated microseconds and subseconds can't be non zero") 124 | } 125 | } 126 | return nil 127 | } 128 | // intervalCompoundPrecisionSubSecondsFromProto returns the precision and subsecond value, mapping the deprecated microseconds mode to microsecond precision. 129 | func intervalCompoundPrecisionSubSecondsFromProto(protoVal *proto.Expression_Literal_IntervalDayToSecond) (types.TimePrecision, int64, error) { 130 | var precisionVal int32 131 | var subSecondVal int64 132 | switch pmt := protoVal.PrecisionMode.(type) { 133 | case *proto.Expression_Literal_IntervalDayToSecond_Precision: 134 | precisionVal = pmt.Precision 135 | subSecondVal = protoVal.Subseconds 136 | case *proto.Expression_Literal_IntervalDayToSecond_Microseconds: 137 | // deprecated field microseconds is set, treat its value as subseconds at microsecond precision 138 | precisionVal = types.PrecisionMicroSeconds.ToProtoVal() 139 | subSecondVal = int64(pmt.Microseconds) 140 | } 141 | precision, err := types.ProtoToTimePrecision(precisionVal) 142 | if err != nil { 143 | return types.PrecisionUnknown, 0, err 144 | } 145 | return precision, subSecondVal, nil 146 | } 147 | -------------------------------------------------------------------------------- /expr/interval_year_to_month.go: -------------------------------------------------------------------------------- 1 | package expr 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/substrait-io/substrait-go/v4/types" 7 | proto
"github.com/substrait-io/substrait-protobuf/go/substraitpb" 8 | ) 9 | 10 | // IntervalYearToMonthLiteral implements Literal interface for interval year to month type 11 | type IntervalYearToMonthLiteral struct { 12 | Years int32 13 | Months int32 14 | Nullability types.Nullability 15 | } 16 | 17 | func (m IntervalYearToMonthLiteral) getType() types.Type { 18 | return types.NewIntervalYearToMonthType().WithNullability(m.Nullability) 19 | } 20 | 21 | func (m IntervalYearToMonthLiteral) ToProtoLiteral() *proto.Expression_Literal { 22 | t := m.getType() 23 | return &proto.Expression_Literal{ 24 | LiteralType: &proto.Expression_Literal_IntervalYearToMonth_{ 25 | IntervalYearToMonth: &proto.Expression_Literal_IntervalYearToMonth{ 26 | Years: m.Years, 27 | Months: m.Months, 28 | }, 29 | }, 30 | Nullable: t.GetNullability() == types.NullabilityNullable, 31 | TypeVariationReference: t.GetTypeVariationReference(), 32 | } 33 | } 34 | 35 | func (m IntervalYearToMonthLiteral) ToProto() *proto.Expression { 36 | return &proto.Expression{RexType: &proto.Expression_Literal_{ 37 | Literal: m.ToProtoLiteral(), 38 | }} 39 | } 40 | 41 | func intervalYearToMonthLiteralFromProto(l *proto.Expression_Literal) Literal { 42 | return IntervalYearToMonthLiteral{ 43 | Years: l.GetIntervalYearToMonth().Years, 44 | Months: l.GetIntervalYearToMonth().Months, 45 | Nullability: getNullability(l.Nullable), 46 | } 47 | } 48 | 49 | func (IntervalYearToMonthLiteral) isRootRef() {} 50 | func (m IntervalYearToMonthLiteral) GetType() types.Type { return m.getType() } 51 | func (m IntervalYearToMonthLiteral) String() string { 52 | return fmt.Sprintf("%s(%s)", m.getType(), m.ValueString()) 53 | } 54 | func (m IntervalYearToMonthLiteral) ValueString() string { 55 | return fmt.Sprintf("%d years, %d months", m.Years, m.Months) 56 | } 57 | func (m IntervalYearToMonthLiteral) Equals(rhs Expression) bool { 58 | if other, ok := rhs.(IntervalYearToMonthLiteral); ok { 59 | return m.getType().Equals(other.GetType()) 
&& (m == other) 60 | } 61 | return false 62 | } 63 | 64 | func (m IntervalYearToMonthLiteral) ToProtoFuncArg() *proto.FunctionArgument { 65 | return &proto.FunctionArgument{ 66 | ArgType: &proto.FunctionArgument_Value{Value: m.ToProto()}, 67 | } 68 | } 69 | 70 | func (m IntervalYearToMonthLiteral) Visit(VisitFunc) Expression { return m } 71 | func (IntervalYearToMonthLiteral) IsScalar() bool { return true } 72 | -------------------------------------------------------------------------------- /expr/interval_year_to_month_test.go: -------------------------------------------------------------------------------- 1 | package expr 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/google/go-cmp/cmp" 7 | "github.com/stretchr/testify/assert" 8 | "github.com/substrait-io/substrait-go/v4/types" 9 | proto "github.com/substrait-io/substrait-protobuf/go/substraitpb" 10 | "google.golang.org/protobuf/testing/protocmp" 11 | ) 12 | 13 | func TestIntervalYearToMonthToProto(t *testing.T) { 14 | // nullability belong to type. 
In type unit tests they are already tested 15 | // for different values so no need to test for multiple values 16 | nullable := true 17 | nullability := types.NullabilityNullable 18 | var oneYear int32 = 1 19 | var oneMonth int32 = 1 20 | 21 | for _, tc := range []struct { 22 | name string 23 | literal Literal 24 | expectedExpression *proto.Expression 25 | }{ 26 | {"WithOnlyYear", 27 | IntervalYearToMonthLiteral{Nullability: nullability, Years: oneYear}, 28 | &proto.Expression{ 29 | RexType: &proto.Expression_Literal_{Literal: &proto.Expression_Literal{ 30 | LiteralType: &proto.Expression_Literal_IntervalYearToMonth_{ 31 | IntervalYearToMonth: &proto.Expression_Literal_IntervalYearToMonth{Years: oneYear}}, 32 | Nullable: nullable, 33 | }}, 34 | }, 35 | }, 36 | {"WithOnlyMonth", 37 | IntervalYearToMonthLiteral{Nullability: nullability, Months: oneMonth}, 38 | &proto.Expression{ 39 | RexType: &proto.Expression_Literal_{Literal: &proto.Expression_Literal{ 40 | LiteralType: &proto.Expression_Literal_IntervalYearToMonth_{ 41 | IntervalYearToMonth: &proto.Expression_Literal_IntervalYearToMonth{Months: oneMonth}}, 42 | Nullable: nullable, 43 | }}, 44 | }, 45 | }, 46 | } { 47 | t.Run(tc.name, func(t *testing.T) { 48 | toProto := tc.literal.ToProto() 49 | if diff := cmp.Diff(toProto, tc.expectedExpression, protocmp.Transform()); diff != "" { 50 | t.Errorf("expression proto didn't match, diff:\n%v", diff) 51 | } 52 | // verify ToProtoFuncArg 53 | funcArgProto := &proto.FunctionArgument{ 54 | ArgType: &proto.FunctionArgument_Value{Value: toProto}, 55 | } 56 | if diff := cmp.Diff(tc.literal.ToProtoFuncArg(), funcArgProto, protocmp.Transform()); diff != "" { 57 | t.Errorf("expression proto didn't match, diff:\n%v", diff) 58 | } 59 | }) 60 | 61 | } 62 | } 63 | 64 | func TestIntervalYearToMonthFromProto(t *testing.T) { 65 | nullable := true 66 | nullability := types.NullabilityNullable 67 | var oneYear int32 = 1 68 | var oneMonth int32 = 1 69 | for _, tc := range []struct { 70 | 
name string 71 | inputProto *proto.Expression_Literal 72 | expectedLiteral IntervalYearToMonthLiteral 73 | }{ 74 | {"OnlyYearToMonth", 75 | &proto.Expression_Literal{ 76 | LiteralType: &proto.Expression_Literal_IntervalYearToMonth_{ 77 | IntervalYearToMonth: &proto.Expression_Literal_IntervalYearToMonth{Years: oneYear, Months: oneMonth}}, 78 | Nullable: nullable}, 79 | IntervalYearToMonthLiteral{Years: oneYear, Months: oneMonth, Nullability: nullability}, 80 | }, 81 | } { 82 | t.Run(tc.name, func(t *testing.T) { 83 | gotLiteral := intervalYearToMonthLiteralFromProto(tc.inputProto) 84 | assert.Equal(t, tc.expectedLiteral, gotLiteral) 85 | // verify equal method too returns true 86 | assert.True(t, tc.expectedLiteral.Equals(gotLiteral)) 87 | assert.True(t, gotLiteral.IsScalar()) 88 | // got literal after serialization is different from empty literal 89 | assert.False(t, IntervalYearToMonthLiteral{}.Equals(gotLiteral)) 90 | }) 91 | 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /expr/proto_literals_test.go: -------------------------------------------------------------------------------- 1 | package expr 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/google/go-cmp/cmp" 7 | "github.com/stretchr/testify/assert" 8 | "github.com/substrait-io/substrait-go/v4/types" 9 | proto "github.com/substrait-io/substrait-protobuf/go/substraitpb" 10 | "google.golang.org/protobuf/testing/protocmp" 11 | ) 12 | 13 | func TestToProtoLiteral(t *testing.T) { 14 | for _, tc := range []struct { 15 | name string 16 | constructedLiteral *ProtoLiteral 17 | expectedExpressionLiteral *proto.Expression_Literal 18 | }{ 19 | {"TimeStampType", 20 | &ProtoLiteral{Value: int64(12345678), Type: types.NewPrecisionTimestampType(types.PrecisionEMinus4Seconds).WithNullability(types.NullabilityNullable)}, 21 | &proto.Expression_Literal{LiteralType: &proto.Expression_Literal_PrecisionTimestamp_{PrecisionTimestamp: 
&proto.Expression_Literal_PrecisionTimestamp{Precision: 4, Value: 12345678}}, Nullable: true}, 22 | }, 23 | {"TimeStampTzType", 24 | &ProtoLiteral{Value: int64(12345678), Type: types.NewPrecisionTimestampTzType(types.PrecisionNanoSeconds).WithNullability(types.NullabilityNullable)}, 25 | &proto.Expression_Literal{LiteralType: &proto.Expression_Literal_PrecisionTimestampTz{PrecisionTimestampTz: &proto.Expression_Literal_PrecisionTimestamp{Precision: 9, Value: 12345678}}, Nullable: true}, 26 | }, 27 | } { 28 | t.Run(tc.name, func(t *testing.T) { 29 | toProto := tc.constructedLiteral.ToProtoLiteral() 30 | if diff := cmp.Diff(toProto, tc.expectedExpressionLiteral, protocmp.Transform()); diff != "" { 31 | t.Errorf("proto didn't match, diff:\n%v", diff) 32 | } 33 | }) 34 | 35 | } 36 | } 37 | 38 | func TestLiteralFromProtoLiteral(t *testing.T) { 39 | intDayToSecVal := &proto.Expression_Literal_IntervalDayToSecond{Days: 1, Seconds: 2, PrecisionMode: &proto.Expression_Literal_IntervalDayToSecond_Precision{Precision: 5}} 40 | for _, tc := range []struct { 41 | name string 42 | constructedProto *proto.Expression_Literal 43 | expectedLiteral interface{} 44 | }{ 45 | {"TimeStampType", 46 | &proto.Expression_Literal{LiteralType: &proto.Expression_Literal_PrecisionTimestamp_{PrecisionTimestamp: &proto.Expression_Literal_PrecisionTimestamp{Precision: 4, Value: 12345678}}, Nullable: true}, 47 | &ProtoLiteral{Value: int64(12345678), Type: types.NewPrecisionTimestampType(types.PrecisionEMinus4Seconds).WithNullability(types.NullabilityNullable)}, 48 | }, 49 | {"TimeStampTzType", 50 | &proto.Expression_Literal{LiteralType: &proto.Expression_Literal_PrecisionTimestampTz{PrecisionTimestampTz: &proto.Expression_Literal_PrecisionTimestamp{Precision: 9, Value: 12345678}}, Nullable: true}, 51 | &ProtoLiteral{Value: int64(12345678), Type: types.NewPrecisionTimestampTzType(types.PrecisionNanoSeconds).WithNullability(types.NullabilityNullable)}, 52 | }, 53 | {"IntervalDayType", 54 | 
&proto.Expression_Literal{LiteralType: &proto.Expression_Literal_IntervalDayToSecond_{IntervalDayToSecond: intDayToSecVal}, Nullable: true}, 55 | &ProtoLiteral{Value: intDayToSecVal, Type: &types.IntervalDayType{Precision: types.PrecisionEMinus5Seconds, Nullability: types.NullabilityNullable}}, 56 | }, 57 | {"IntervalYearToMonthType", 58 | &proto.Expression_Literal{LiteralType: &proto.Expression_Literal_IntervalYearToMonth_{IntervalYearToMonth: &proto.Expression_Literal_IntervalYearToMonth{Years: 1234, Months: 5}}, Nullable: true}, 59 | IntervalYearToMonthLiteral{Years: 1234, Months: 5, Nullability: types.NullabilityNullable}, 60 | }, 61 | {"IntervalCompoundType", 62 | &proto.Expression_Literal{LiteralType: &proto.Expression_Literal_IntervalCompound_{ 63 | IntervalCompound: &proto.Expression_Literal_IntervalCompound{ 64 | IntervalYearToMonth: &proto.Expression_Literal_IntervalYearToMonth{Years: 1234, Months: -5}, 65 | IntervalDayToSecond: &proto.Expression_Literal_IntervalDayToSecond{Days: 6, Seconds: -7, 66 | PrecisionMode: &proto.Expression_Literal_IntervalDayToSecond_Precision{Precision: 8}, 67 | Subseconds: -9, 68 | }, 69 | }}, Nullable: true}, 70 | IntervalCompoundLiteral{Years: 1234, Months: -5, Days: 6, Seconds: -7, SubSecondPrecision: 8, SubSeconds: -9, Nullability: types.NullabilityNullable}, 71 | }, 72 | } { 73 | t.Run(tc.name, func(t *testing.T) { 74 | literal := LiteralFromProto(tc.constructedProto) 75 | assert.Equal(t, tc.expectedLiteral, literal) 76 | }) 77 | 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /expr/testdata/expressions.yaml: -------------------------------------------------------------------------------- 1 | name: tests 2 | baseSchema: 3 | names: [a, b, c, d] 4 | struct: 5 | nullability: NULLABILITY_REQUIRED 6 | types: 7 | - i32: { nullability: NULLABILITY_REQUIRED } 8 | - i8: { nullability: NULLABILITY_REQUIRED } 9 | - i16: { nullability: NULLABILITY_REQUIRED } 10 | - i16: { 
nullability: NULLABILITY_NULLABLE } 11 | cases: 12 | - name: scalar-func 13 | __test: 14 | type: fp64 15 | string: | 16 | add(fp64(1), subtract(.field(3) => i16?, multiply(i64(2), [root:(struct([binary?([98 97 122]) string(foobar) i32(5)]))].field(2) => i32) => i64) => fp32) => fp64 17 | expression: 18 | scalarFunction: 19 | functionReference: 2 20 | arguments: 21 | - value: 22 | literal: { fp64: 1.0 } 23 | - value: 24 | scalarFunction: 25 | functionReference: 3 26 | outputType: { fp32: {} } 27 | arguments: 28 | - value: 29 | selection: 30 | rootReference: {} 31 | directReference: { structField: { field: 3 }} 32 | - value: 33 | scalarFunction: 34 | functionReference: 4 35 | outputType: { i64: {} } 36 | arguments: 37 | - value: { literal: { i64: 2 } } 38 | - value: 39 | selection: 40 | expression: 41 | literal: 42 | struct: 43 | fields: 44 | - { binary: "YmF6", nullable: true } 45 | - { string: "foobar", nullable: false } 46 | - { i32: 5, nullable: false } 47 | directReference: { structField: { field: 2 } } 48 | outputType: { fp64: {} } 49 | - name: window-func 50 | __test: 51 | type: i32? 52 | string: | 53 | ntile(.field(1) => i8; sort: [{expr: ((.field(0) => i32) ?.field(1) => i8)(i32(1)), SORT_DIRECTION_CLUSTERED}]; [options: {DECOMPOSABLE => [NONE]}]; partitions: [.field(2) => i16]; phase: AGGREGATION_PHASE_INITIAL_TO_RESULT, invocation: AGGREGATION_INVOCATION_ALL) => i32? 
54 | expression: 55 | windowFunction: 56 | functionReference: 5 57 | outputType: { i32: { nullability: "NULLABILITY_NULLABLE" } } 58 | options: 59 | - name: "DECOMPOSABLE" 60 | preference: ["NONE"] 61 | arguments: 62 | - value: 63 | selection: 64 | rootReference: {} 65 | directReference: { structField: { field: 1 } } 66 | phase: "AGGREGATION_PHASE_INITIAL_TO_RESULT" 67 | invocation: "AGGREGATION_INVOCATION_ALL" 68 | sorts: 69 | - direction: "SORT_DIRECTION_CLUSTERED" 70 | expr: 71 | ifThen: 72 | ifs: 73 | - if: 74 | selection: 75 | rootReference: {} 76 | directReference: { structField: { field: 0 } } 77 | then: 78 | selection: 79 | rootReference: {} 80 | directReference: { structField: { field: 1 } } 81 | else: 82 | literal: { i32: 1, nullable: false } 83 | lowerBound: { unbounded: {} } 84 | upperBound: { currentRow: {} } 85 | partitions: 86 | - selection: 87 | rootReference: {} 88 | directReference: { structField: { field: 2 } } 89 | - name: ifthen 90 | __test: { type: "i8" } 91 | expression: 92 | ifThen: 93 | ifs: 94 | - if: 95 | selection: 96 | rootReference: {} 97 | directReference: { structField: { field: 0 } } 98 | then: 99 | selection: 100 | rootReference: {} 101 | directReference: { structField: { field: 1 } } 102 | else: 103 | selection: 104 | rootReference: {} 105 | directReference: { structField: { field: 1 } } 106 | - name: switch-expr 107 | __test: { type: "i16?"} 108 | expression: 109 | switchExpression: 110 | match: 111 | selection: 112 | rootReference: {} 113 | directReference: { structField: { field: 0 }} 114 | ifs: 115 | - if: { i32: 0 } 116 | then: 117 | selection: 118 | rootReference: {} 119 | directReference: { structField: { field: 2 }} 120 | - if: { i32: 1 } 121 | then: 122 | selection: 123 | rootReference: {} 124 | directReference: { structField: { field: 3 } } 125 | else: 126 | selection: 127 | rootReference: {} 128 | directReference: { structField: { field: 2 } } 129 | - name: singular-or-list 130 | __test: 131 | type: "boolean" 132 | 
string: | 133 | .field(2) => i16 IN [i16(1),i16(2),i16(3)] 134 | expression: 135 | singularOrList: 136 | value: 137 | selection: 138 | rootReference: {} 139 | directReference: { structField: { field: 2 } } 140 | options: 141 | - literal: { i16: 1 } 142 | - literal: { i16: 2 } 143 | - literal: { i16: 3 } 144 | - name: multi-or-list 145 | __test: 146 | type: "boolean" 147 | string: | 148 | [.field(1) => i8, .field(2) => i16] IN [[i8(1), i16(2)], [i8(3), i16(4)], [i8(5), i16(6)]] 149 | expression: 150 | multiOrList: 151 | value: 152 | - selection: 153 | rootReference: {} 154 | directReference: { structField: { field: 1 }} 155 | - selection: 156 | rootReference: {} 157 | directReference: { structField: { field: 2 }} 158 | options: 159 | - fields: 160 | - literal: { i8: 1 } 161 | - literal: { i16: 2 } 162 | - fields: 163 | - literal: { i8: 3 } 164 | - literal: { i16: 4 } 165 | - fields: 166 | - literal: { i8: 5 } 167 | - literal: { i16: 6 } 168 | - name: nested-map-expr 169 | __test: { type: "map?" 
} 170 | expression: 171 | nested: 172 | nullable: true 173 | map: 174 | keyValues: 175 | - key: 176 | literal: { i8: 1 } 177 | value: 178 | selection: 179 | rootReference: {} 180 | directReference: { structField: { field: 2 }} 181 | - key: 182 | selection: 183 | rootReference: {} 184 | directReference: { structField: { field: 1 }} 185 | value: 186 | literal: { i16: 2 } 187 | - key: 188 | selection: 189 | rootReference: {} 190 | directReference: { structField: { field: 0 }} 191 | value: 192 | selection: 193 | rootReference: {} 194 | directReference: { structField: { field: 3 }} 195 | - name: nested-struct-expr 196 | __test: { type: "struct" } 197 | expression: 198 | nested: 199 | nullable: false 200 | struct: 201 | fields: 202 | - literal: { fp32: 1.5 } 203 | - selection: 204 | rootReference: {} 205 | directReference: { structField: { field: 0 } } 206 | - literal: { fp64: 1.5 } 207 | - selection: 208 | rootReference: {} 209 | directReference: { structField: { field: 3 } } 210 | - name: nested-list-expr 211 | __test: { type: "list?" 
} 212 | expression: 213 | nested: 214 | nullable: true 215 | list: 216 | values: 217 | - selection: 218 | rootReference: {} 219 | directReference: { structField: { field: 3 } } 220 | - literal: { i16: 1 } 221 | - name: cast 222 | __test: 223 | type: i64 224 | string: | 225 | cast(i16(6) AS i64, fail: FAILURE_BEHAVIOR_RETURN_NULL) 226 | expression: 227 | cast: 228 | type: { i64: {} } 229 | failureBehavior: FAILURE_BEHAVIOR_RETURN_NULL 230 | input: 231 | literal: { i16: 6 } -------------------------------------------------------------------------------- /expr/testdata/extended_exprs.yaml: -------------------------------------------------------------------------------- 1 | tests: 2 | - version: { producer: substraitgo-test } 3 | extensionUris: 4 | - extensionUriAnchor: 1 5 | uri: https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml 6 | extensions: 7 | - extensionFunction: 8 | extensionUriReference: 1 9 | functionAnchor: 2 10 | name: add:i64_i64 11 | - extensionFunction: 12 | extensionUriReference: 1 13 | functionAnchor: 3 14 | name: subtract:i64_i64 15 | - extensionFunction: 16 | extensionUriReference: 1 17 | functionAnchor: 4 18 | name: multiply:i64_i64 19 | - extensionFunction: 20 | extensionUriReference: 1 21 | functionAnchor: 5 22 | name: ntile:i64_i64 23 | - extensionFunction: 24 | extensionUriReference: 1 25 | functionAnchor: 6 26 | name: sum:i64 27 | baseSchema: 28 | names: [a, b, c, d] 29 | struct: 30 | nullability: NULLABILITY_REQUIRED 31 | types: 32 | - i32: { nullability: NULLABILITY_REQUIRED } 33 | - i8: { nullability: NULLABILITY_REQUIRED } 34 | - i16: { nullability: NULLABILITY_REQUIRED } 35 | - i16: { nullability: NULLABILITY_NULLABLE } 36 | expectedTypeUrls: 37 | - substrait.Plan 38 | referredExpr: 39 | - outputNames: [sum] 40 | measure: 41 | functionReference: 6 42 | outputType: { i64: {} } 43 | arguments: 44 | - type: { i64: {} } 45 | - value: 46 | selection: 47 | rootReference: {} 48 | directReference: { 
structField: { field: 0 }} 49 | - outputNames: [x] 50 | expression: 51 | scalarFunction: 52 | functionReference: 2 53 | arguments: 54 | - value: 55 | selection: 56 | rootReference: {} 57 | directReference: { structField: { field: 1 }} 58 | - value: 59 | selection: 60 | rootReference: {} 61 | directReference: { structField: { field: 2 }} 62 | outputType: { i32: {} } 63 | -------------------------------------------------------------------------------- /expr/utils.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package expr 4 | 5 | import "github.com/substrait-io/substrait-go/v4/extensions" 6 | 7 | type ExtensionRegistry struct { 8 | extensions.Set 9 | c *extensions.Collection 10 | } 11 | 12 | // NewExtensionRegistry creates a new registry. If you have an existing plan you can use GetExtensionSet() to 13 | // populate an extensions.Set. 14 | func NewExtensionRegistry(extSet extensions.Set, c *extensions.Collection) ExtensionRegistry { 15 | if c == nil { 16 | panic("cannot create registry with nil collection") 17 | } 18 | return ExtensionRegistry{Set: extSet, c: c} 19 | } 20 | 21 | // NewEmptyExtensionRegistry creates an empty registry useful starting from scratch. 22 | func NewEmptyExtensionRegistry(c *extensions.Collection) ExtensionRegistry { 23 | return NewExtensionRegistry(extensions.NewSet(), c) 24 | } 25 | 26 | func (e *ExtensionRegistry) LookupTypeVariation(anchor uint32) (extensions.TypeVariation, bool) { 27 | return e.Set.LookupTypeVariation(anchor, e.c) 28 | } 29 | 30 | func (e *ExtensionRegistry) LookupType(anchor uint32) (extensions.Type, bool) { 31 | return e.Set.LookupType(anchor, e.c) 32 | } 33 | 34 | // LookupScalarFunction returns a ScalarFunctionVariant associated with a previously used function's anchor. 
35 | func (e *ExtensionRegistry) LookupScalarFunction(anchor uint32) (*extensions.ScalarFunctionVariant, bool) { 36 | return e.Set.LookupScalarFunction(anchor, e.c) 37 | } 38 | 39 | // LookupAggregateFunction returns an AggregateFunctionVariant associated with a previously used function's anchor. 40 | func (e *ExtensionRegistry) LookupAggregateFunction(anchor uint32) (*extensions.AggregateFunctionVariant, bool) { 41 | return e.Set.LookupAggregateFunction(anchor, e.c) 42 | } 43 | 44 | // LookupWindowFunction returns a WindowFunctionVariant associated with a previously used function's anchor. 45 | func (e *ExtensionRegistry) LookupWindowFunction(anchor uint32) (*extensions.WindowFunctionVariant, bool) { 46 | return e.Set.LookupWindowFunction(anchor, e.c) 47 | } 48 | -------------------------------------------------------------------------------- /functions/functions.go: -------------------------------------------------------------------------------- 1 | package functions 2 | 3 | import ( 4 | "github.com/substrait-io/substrait-go/v4/extensions" 5 | ) 6 | 7 | type functionRegistryImpl struct { 8 | scalarFunctions map[string][]*extensions.ScalarFunctionVariant 9 | aggregateFunctions map[string][]*extensions.AggregateFunctionVariant 10 | windowFunctions map[string][]*extensions.WindowFunctionVariant 11 | allFunctions []extensions.FunctionVariant 12 | } 13 | 14 | func getOrEmpty[K comparable, V any](key K, m map[K][]V) []V { 15 | if value, exists := m[key]; exists { 16 | return value 17 | } 18 | 19 | return make([]V, 0) 20 | } 21 | 22 | var _ FunctionRegistry = &functionRegistryImpl{} 23 | 24 | func NewFunctionRegistry(collection *extensions.Collection) FunctionRegistry { 25 | scalarFunctions := make(map[string][]*extensions.ScalarFunctionVariant) 26 | aggregateFunctions := make(map[string][]*extensions.AggregateFunctionVariant) 27 | windowFunctions := make(map[string][]*extensions.WindowFunctionVariant) 28 | allFunctions := make([]extensions.FunctionVariant, 0) 29 | 30 | 
processFunctions(collection.GetAllScalarFunctions(), scalarFunctions, &allFunctions) 31 | processFunctions(collection.GetAllAggregateFunctions(), aggregateFunctions, &allFunctions) 32 | processFunctions(collection.GetAllWindowFunctions(), windowFunctions, &allFunctions) 33 | 34 | return &functionRegistryImpl{ 35 | scalarFunctions: scalarFunctions, 36 | aggregateFunctions: aggregateFunctions, 37 | windowFunctions: windowFunctions, 38 | allFunctions: allFunctions, 39 | } 40 | } 41 | 42 | func processFunctions[T extensions.FunctionVariant](functions []T, funcMap map[string][]T, allFunctions *[]extensions.FunctionVariant) { 43 | for _, f := range functions { 44 | name := f.Name() 45 | if _, ok := funcMap[name]; !ok { 46 | funcMap[name] = make([]T, 0) 47 | } 48 | funcMap[name] = append(funcMap[name], f) 49 | *allFunctions = append(*allFunctions, f) 50 | } 51 | } 52 | 53 | func (f *functionRegistryImpl) GetAllFunctions() []extensions.FunctionVariant { 54 | return f.allFunctions 55 | } 56 | 57 | func (f *functionRegistryImpl) GetScalarFunctionsByName(name string) []*extensions.ScalarFunctionVariant { 58 | return getOrEmpty(name, f.scalarFunctions) 59 | } 60 | 61 | func (f *functionRegistryImpl) GetAggregateFunctionsByName(name string) []*extensions.AggregateFunctionVariant { 62 | return getOrEmpty(name, f.aggregateFunctions) 63 | } 64 | 65 | func (f *functionRegistryImpl) GetWindowFunctionsByName(name string) []*extensions.WindowFunctionVariant { 66 | return getOrEmpty(name, f.windowFunctions) 67 | } 68 | 69 | func (f *functionRegistryImpl) GetScalarFunctions(name string, numArgs int) []*extensions.ScalarFunctionVariant { 70 | return getFunctionVariantsByCount(f.GetScalarFunctionsByName(name), numArgs) 71 | } 72 | 73 | func (f *functionRegistryImpl) GetAggregateFunctions(name string, numArgs int) []*extensions.AggregateFunctionVariant { 74 | return getFunctionVariantsByCount(f.GetAggregateFunctionsByName(name), numArgs) 75 | } 76 | 77 | func (f *functionRegistryImpl) 
GetWindowFunctions(name string, numArgs int) []*extensions.WindowFunctionVariant { 78 | return getFunctionVariantsByCount(f.GetWindowFunctionsByName(name), numArgs) 79 | } 80 | 81 | func getFunctionVariantsByCount[T extensions.FunctionVariant](functions []T, numArgs int) []T { 82 | ret := make([]T, 0) 83 | for _, f := range functions { 84 | if len(f.Args()) == numArgs || f.Variadic().IsValidArgumentCount(numArgs) { 85 | ret = append(ret, f) 86 | } 87 | } 88 | return ret 89 | } 90 | 91 | var _ FunctionRegistry = &functionRegistryImpl{} 92 | -------------------------------------------------------------------------------- /functions/local_functions.go: -------------------------------------------------------------------------------- 1 | package functions 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/substrait-io/substrait-go/v4/expr" 7 | "github.com/substrait-io/substrait-go/v4/extensions" 8 | ) 9 | 10 | type localFunctionRegistryImpl struct { 11 | dialect Dialect 12 | 13 | // substrait function name to local function variants 14 | scalarFunctions map[FunctionName][]*LocalScalarFunctionVariant 15 | aggregateFunctions map[FunctionName][]*LocalAggregateFunctionVariant 16 | windowFunctions map[FunctionName][]*LocalWindowFunctionVariant 17 | 18 | allFunctions []extensions.FunctionVariant 19 | 20 | idToLocalFunctionMap map[extensions.ID]localFunctionVariant 21 | localTypeRegistry LocalTypeRegistry 22 | funcRegistry FunctionRegistry 23 | } 24 | 25 | func makeLocalFunctionVariantsMap(functions []extensions.FunctionVariant) map[extensions.ID]localFunctionVariant { 26 | localFunctionVariants := make(map[extensions.ID]localFunctionVariant) 27 | for _, f := range functions { 28 | switch variant := f.(type) { 29 | case *LocalScalarFunctionVariant: 30 | localFunctionVariants[variant.ID()] = variant 31 | case *LocalAggregateFunctionVariant: 32 | localFunctionVariants[variant.ID()] = variant 33 | case *LocalWindowFunctionVariant: 34 | localFunctionVariants[variant.ID()] = variant 35 
| } 36 | } 37 | return localFunctionVariants 38 | } 39 | 40 | func (l *localFunctionRegistryImpl) GetAllFunctions() []extensions.FunctionVariant { 41 | return l.allFunctions 42 | } 43 | 44 | func (l *localFunctionRegistryImpl) GetDialect() Dialect { 45 | return l.dialect 46 | } 47 | 48 | func (l *localFunctionRegistryImpl) GetFunctionRegistry() FunctionRegistry { 49 | return l.funcRegistry 50 | } 51 | 52 | func (l *localFunctionRegistryImpl) GetScalarFunctions(name FunctionName, numArgs int) []*LocalScalarFunctionVariant { 53 | return getFunctionVariantsByCount(getOrEmpty(name, l.scalarFunctions), numArgs) 54 | } 55 | 56 | func (l *localFunctionRegistryImpl) GetAggregateFunctions(name FunctionName, numArgs int) []*LocalAggregateFunctionVariant { 57 | return getFunctionVariantsByCount(getOrEmpty(name, l.aggregateFunctions), numArgs) 58 | } 59 | 60 | func (l *localFunctionRegistryImpl) GetWindowFunctions(name FunctionName, numArgs int) []*LocalWindowFunctionVariant { 61 | return getFunctionVariantsByCount(getOrEmpty(name, l.windowFunctions), numArgs) 62 | } 63 | 64 | func (l *localFunctionRegistryImpl) GetScalarFunctionByInvocation(scalarFuncInvocation *expr.ScalarFunction) (*LocalScalarFunctionVariant, error) { 65 | return getFunctionVariantByInvocation[*LocalScalarFunctionVariant](scalarFuncInvocation, l) 66 | } 67 | 68 | func (l *localFunctionRegistryImpl) GetAggregateFunctionByInvocation(aggregateFuncInvocation *expr.AggregateFunction) (*LocalAggregateFunctionVariant, error) { 69 | return getFunctionVariantByInvocation[*LocalAggregateFunctionVariant](aggregateFuncInvocation, l) 70 | } 71 | 72 | func (l *localFunctionRegistryImpl) GetWindowFunctionByInvocation(windowFuncInvocation *expr.WindowFunction) (*LocalWindowFunctionVariant, error) { 73 | return getFunctionVariantByInvocation[*LocalWindowFunctionVariant](windowFuncInvocation, l) 74 | } 75 | 76 | func getFunctionVariantByInvocation[V localFunctionVariant](invocation expr.FunctionInvocation, registry 
*localFunctionRegistryImpl) (V, error) { 77 | var zeroV V 78 | f, ok := registry.idToLocalFunctionMap[invocation.ID()] 79 | if !ok { 80 | return zeroV, fmt.Errorf("function variant not found for function: %s", invocation.ID()) 81 | } 82 | argTypes := invocation.GetArgTypes() 83 | for i, argType := range argTypes { 84 | _, err := registry.localTypeRegistry.GetLocalTypeFromSubstraitType(argType) 85 | if err != nil { 86 | return zeroV, fmt.Errorf("unsupported substrait type: %v as argument %d in %s", argType, i, invocation.CompoundName()) 87 | } 88 | } 89 | for _, option := range invocation.GetOptions() { 90 | for _, value := range option.Preference { 91 | if !f.IsOptionSupported(option.Name, value) { 92 | return zeroV, fmt.Errorf("unsupported option [%s:%s] in function %s", option.Name, value, invocation.CompoundName()) 93 | } 94 | } 95 | } 96 | return f.(V), nil 97 | } 98 | 99 | var _ LocalFunctionRegistry = &localFunctionRegistryImpl{} 100 | -------------------------------------------------------------------------------- /functions/types.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package functions 4 | 5 | import ( 6 | "strconv" 7 | "strings" 8 | 9 | substraitgo "github.com/substrait-io/substrait-go/v4" 10 | "github.com/substrait-io/substrait-go/v4/types" 11 | ) 12 | 13 | var ( 14 | nameToTypeMap map[string]types.Type 15 | toShortNameMap map[string]string 16 | ) 17 | 18 | func init() { 19 | initTypeMaps() 20 | } 21 | 22 | func initTypeMaps() { 23 | nameToTypeMap = types.GetTypeNameToTypeMap() 24 | toShortNameMap = make(map[string]string) 25 | for k := range nameToTypeMap { 26 | shortName := types.GetShortTypeName(types.TypeName(k)) 27 | if shortName != k { 28 | toShortNameMap[k] = shortName 29 | } 30 | } 31 | for k, v := range toShortNameMap { 32 | nameToTypeMap[v] = nameToTypeMap[k] 33 | } 34 | } 35 | 36 | func getTypeFromBaseTypeName(baseType string) (types.Type, error) { 
37 | if typ, ok := nameToTypeMap[baseType]; ok { 38 | return typ, nil 39 | } 40 | return nil, substraitgo.ErrNotFound 41 | } 42 | 43 | var substraitEnclosure = &substraitTypeEnclosure{} 44 | 45 | func isSupportedType(typeString string) bool { 46 | _, ok := nameToTypeMap[typeString] 47 | return ok 48 | } 49 | 50 | type typeRegistryImpl struct { 51 | typeMap map[string]types.Type 52 | } 53 | 54 | func NewTypeRegistry() TypeRegistry { 55 | return &typeRegistryImpl{typeMap: nameToTypeMap} 56 | } 57 | 58 | func (t *typeRegistryImpl) GetTypeFromTypeString(typeString string) (types.Type, error) { 59 | return getTypeFromTypeString(typeString, t.typeMap, substraitEnclosure) 60 | } 61 | 62 | func getTypeFromTypeString(typeString string, typeMap map[string]types.Type, enclosure typeEnclosure) (types.Type, error) { 63 | baseType, parameters, err := extractTypeAndParameters(typeString, enclosure) 64 | if err != nil { 65 | return nil, err 66 | } 67 | 68 | nullable := types.NullabilityRequired 69 | if strings.HasSuffix(baseType, "?") { 70 | baseType = baseType[:len(baseType)-1] 71 | nullable = types.NullabilityNullable 72 | } 73 | if typ, ok := typeMap[baseType]; ok { 74 | if typ, err = getTypeWithParameters(typ, parameters); err != nil { 75 | return nil, err 76 | } 77 | return typ.WithNullability(nullable), nil 78 | } 79 | return nil, substraitgo.ErrNotFound 80 | } 81 | 82 | func getTypeWithParameters(typ types.Type, parameters []int32) (types.Type, error) { 83 | switch typ.(type) { 84 | case *types.DecimalType: 85 | if len(parameters) != 2 { 86 | return nil, substraitgo.ErrInvalidType 87 | } 88 | return &types.DecimalType{Precision: parameters[0], Scale: parameters[1]}, nil 89 | case *types.FixedBinaryType, *types.FixedCharType, *types.VarCharType: 90 | if len(parameters) != 1 { 91 | return nil, substraitgo.ErrInvalidType 92 | } 93 | switch typ.(type) { 94 | case *types.FixedBinaryType: 95 | return &types.FixedBinaryType{Length: parameters[0]}, nil 96 | case 
// extractTypeAndParameters splits a type string such as "decimal<10,2>" into
// its base name ("decimal") and its integer parameters ([10 2]).
//
// The delimiters are supplied by enclosure. When typeString does not end with
// the closing delimiter, does not contain the opening delimiter, or the only
// opening delimiter overlaps the trailing closing delimiter, the string is
// returned unchanged with nil parameters. Otherwise everything before the
// first opening delimiter is the base type, and the text between that
// delimiter and the trailing closing delimiter is split on "," and parsed as
// base-10 values that fit in an int32. A value that fails to parse yields an
// empty base type, nil parameters and the strconv.ParseInt error.
func extractTypeAndParameters(typeString string, enclosure typeEnclosure) (string, []int32, error) {
	conStart, conEnd := enclosure.containerStart(), enclosure.containerEnd()
	if !strings.HasSuffix(typeString, conEnd) {
		return typeString, nil, nil
	}
	startIdx := strings.Index(typeString, conStart)
	if startIdx < 0 {
		return typeString, nil, nil
	}

	// Skip the whole opening delimiter, not just one byte, so multi-character
	// delimiters do not leak into the parameter list.
	paramsBegin := startIdx + len(conStart)
	paramsEnd := len(typeString) - len(conEnd)
	if paramsBegin > paramsEnd {
		// The opening delimiter overlaps the closing one: there is no
		// parameter list to slice out.
		return typeString, nil, nil
	}

	fields := strings.Split(typeString[paramsBegin:paramsEnd], ",")
	parameters := make([]int32, len(fields))
	for i, field := range fields {
		value, err := strconv.ParseInt(field, 10, 32)
		if err != nil {
			return "", nil, err
		}
		parameters[i] = int32(value)
	}
	return typeString[:startIdx], parameters, nil
}

// typeEnclosure describes the delimiters that surround a type's parameter
// list in a textual type representation.
type typeEnclosure interface {
	// containerStart returns the delimiter that opens the parameter list.
	containerStart() string
	// containerEnd returns the delimiter that closes the parameter list.
	containerEnd() string
}

// substraitTypeEnclosure is the typeEnclosure that uses angle brackets,
// as in "decimal<10,2>".
type substraitTypeEnclosure struct{}

// containerStart returns "<".
func (t *substraitTypeEnclosure) containerStart() string {
	return "<"
}

// containerEnd returns ">".
func (t *substraitTypeEnclosure) containerEnd() string {
	return ">"
}
{ 160 | return ti.localName + enclosure.containerStart() + paramType.ParameterString() + enclosure.containerEnd() 161 | } 162 | return ti.localName 163 | } 164 | 165 | type localTypeRegistryImpl struct { 166 | nameToType map[string]types.Type 167 | localNameToType map[string]types.Type 168 | typeInfoMap map[string]typeInfo 169 | } 170 | 171 | func NewLocalTypeRegistry(typeInfos []typeInfo) LocalTypeRegistry { 172 | nameToType := make(map[string]types.Type) 173 | localNameToType := make(map[string]types.Type) 174 | typeInfoMap := make(map[string]typeInfo) 175 | for _, ti := range typeInfos { 176 | nameToType[ti.shortName] = ti.typ 177 | localNameToType[ti.localName] = ti.typ 178 | typeInfoMap[ti.shortName] = ti 179 | longName := ti.getLongName() 180 | if longName != ti.shortName { 181 | nameToType[longName] = ti.typ 182 | typeInfoMap[longName] = ti 183 | } 184 | } 185 | return &localTypeRegistryImpl{ 186 | nameToType: nameToType, 187 | localNameToType: localNameToType, 188 | typeInfoMap: typeInfoMap, 189 | } 190 | } 191 | 192 | func (t *localTypeRegistryImpl) containerStart() string { 193 | return "(" 194 | } 195 | 196 | func (t *localTypeRegistryImpl) containerEnd() string { 197 | return ")" 198 | } 199 | 200 | func (t *localTypeRegistryImpl) GetTypeFromTypeString(typeString string) (types.Type, error) { 201 | return getTypeFromTypeString(typeString, t.nameToType, substraitEnclosure) 202 | } 203 | 204 | func (t *localTypeRegistryImpl) GetSubstraitTypeFromLocalType(localType string) (types.Type, error) { 205 | return getTypeFromTypeString(localType, t.localNameToType, t) 206 | } 207 | 208 | func (t *localTypeRegistryImpl) GetLocalTypeFromSubstraitType(typ types.Type) (string, error) { 209 | // TODO handle nullable 210 | name := typ.ShortString() 211 | if ti, ok := t.typeInfoMap[name]; ok { 212 | return ti.getLocalTypeString(typ, t), nil 213 | } 214 | return "", substraitgo.ErrNotFound 215 | } 216 | 217 | func (t *localTypeRegistryImpl) GetSupportedTypes() 
map[string]types.Type { 218 | return t.localNameToType 219 | } 220 | 221 | func (t *localTypeRegistryImpl) IsTypeSupportedInTables(typ types.Type) bool { 222 | if ti, ok := t.typeInfoMap[typ.ShortString()]; ok { 223 | return ti.supportedAsColumn 224 | } 225 | return false 226 | } 227 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | module github.com/substrait-io/substrait-go/v4 4 | 5 | go 1.22.0 6 | 7 | toolchain go1.22.3 8 | 9 | require ( 10 | cloud.google.com/go v0.118.0 11 | github.com/antlr4-go/antlr/v4 v4.13.1 12 | github.com/cockroachdb/apd/v3 v3.2.1 13 | github.com/creasty/defaults v1.8.0 14 | github.com/goccy/go-yaml v1.11.0 15 | github.com/google/go-cmp v0.6.0 16 | github.com/google/uuid v1.6.0 17 | github.com/stretchr/testify v1.10.0 18 | github.com/substrait-io/substrait v0.66.1-0.20250205013839-a30b3e2d7ec6 19 | github.com/substrait-io/substrait-protobuf/go v0.66.1 20 | golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 21 | google.golang.org/protobuf v1.35.2 22 | gopkg.in/yaml.v3 v3.0.1 23 | ) 24 | 25 | require ( 26 | github.com/davecgh/go-spew v1.1.1 // indirect 27 | github.com/fatih/color v1.15.0 // indirect 28 | github.com/go-playground/validator/v10 v10.11.1 // indirect 29 | github.com/kr/pretty v0.3.1 // indirect 30 | github.com/lib/pq v1.10.9 // indirect 31 | github.com/mattn/go-colorable v0.1.13 // indirect 32 | github.com/mattn/go-isatty v0.0.20 // indirect 33 | github.com/pmezard/go-difflib v1.0.0 // indirect 34 | github.com/rogpeppe/go-internal v1.12.0 // indirect 35 | golang.org/x/sys v0.28.0 // indirect 36 | golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect 37 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect 38 | ) 39 | -------------------------------------------------------------------------------- /grammar/generate.go: 
package grammar

// The ANTLR jar is versioned by its file name, so "wget -nc" (skip the
// download when the file already exists) is sufficient for it.
//
// The grammar files are fetched with "wget -O <name>" so that re-running
// "go generate" overwrites the local copy. Without -O, -nc or -N, wget keeps
// an existing file and stores the new download as "<name>.1", which would
// leave antlr compiling a stale grammar after a pinned commit is changed.

//go:generate wget -nc https://www.antlr.org/download/antlr-4.13.2-complete.jar
//go:generate wget -O SubstraitLexer.g4 https://raw.githubusercontent.com/substrait-io/substrait/a30b3e2d7ec667a6da8fee083d7823b11768bd2c/grammar/SubstraitLexer.g4
//go:generate wget -O SubstraitType.g4 https://raw.githubusercontent.com/substrait-io/substrait/a30b3e2d7ec667a6da8fee083d7823b11768bd2c/grammar/SubstraitType.g4
//go:generate wget -O FuncTestCaseLexer.g4 https://raw.githubusercontent.com/substrait-io/substrait/a30b3e2d7ec667a6da8fee083d7823b11768bd2c/grammar/FuncTestCaseLexer.g4
//go:generate wget -O FuncTestCaseParser.g4 https://raw.githubusercontent.com/substrait-io/substrait/3d2ff77575a7177f82a4d5b53408a059e9818922/grammar/FuncTestCaseParser.g4
//go:generate -command antlr java -Xmx500M -cp "./antlr-4.13.2-complete.jar:$CLASSPATH" org.antlr.v4.Tool
//go:generate antlr -Dlanguage=Go -visitor -package baseparser -o "../types/parser/baseparser" SubstraitLexer.g4 SubstraitType.g4
//go:generate antlr -Dlanguage=Go -visitor -no-listener -package baseparser -o "../testcases/parser/baseparser" FuncTestCaseLexer.g4 FuncTestCaseParser.g4
19 | type RelCommon struct { 20 | hint *Hint 21 | mapping []int32 22 | advExtension *extensions.AdvancedExtension 23 | } 24 | 25 | func (rc *RelCommon) fromProtoCommon(c *proto.RelCommon) { 26 | rc.hint = c.Hint 27 | rc.advExtension = c.AdvancedExtension 28 | 29 | if emit, ok := c.GetEmitKind().(*proto.RelCommon_Emit_); ok { 30 | rc.mapping = emit.Emit.OutputMapping 31 | } else { 32 | rc.mapping = nil 33 | } 34 | } 35 | 36 | func (rc *RelCommon) remap(initial types.RecordType) types.RecordType { 37 | if rc.mapping == nil { 38 | return initial 39 | } 40 | 41 | outTypes := make([]types.Type, len(rc.mapping)) 42 | 43 | for i, m := range rc.mapping { 44 | outTypes[i] = initial.GetFieldRef(m) 45 | } 46 | 47 | return *types.NewRecordTypeFromTypes(outTypes) 48 | } 49 | 50 | func (rc *RelCommon) OutputMapping() []int32 { 51 | if rc.mapping == nil { 52 | return nil 53 | } 54 | // Make a copy of the output mapping to prevent accidental modification. 55 | mapCopy := make([]int32, len(rc.mapping)) 56 | copy(mapCopy, rc.mapping) 57 | return mapCopy 58 | } 59 | 60 | func (rc *RelCommon) setMapping(mapping []int32) { 61 | rc.mapping = mapping 62 | } 63 | 64 | func (rc *RelCommon) GetAdvancedExtension() *extensions.AdvancedExtension { 65 | return rc.advExtension 66 | } 67 | 68 | func (rc *RelCommon) Hint() *Hint { 69 | return rc.hint 70 | } 71 | 72 | func (rc *RelCommon) toProto() *proto.RelCommon { 73 | ret := &proto.RelCommon{ 74 | Hint: rc.hint, 75 | AdvancedExtension: rc.advExtension, 76 | } 77 | 78 | if rc.mapping == nil { 79 | ret.EmitKind = &proto.RelCommon_Direct_{ 80 | Direct: &proto.RelCommon_Direct{}, 81 | } 82 | } else { 83 | ret.EmitKind = &proto.RelCommon_Emit_{ 84 | Emit: &proto.RelCommon_Emit{OutputMapping: rc.mapping}, 85 | } 86 | } 87 | return ret 88 | } 89 | -------------------------------------------------------------------------------- /plan/ctas_plan_test.go: -------------------------------------------------------------------------------- 1 | package 
plan_test 2 | 3 | import ( 4 | "embed" 5 | "fmt" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/require" 9 | "github.com/substrait-io/substrait-go/v4/expr" 10 | "github.com/substrait-io/substrait-go/v4/extensions" 11 | "github.com/substrait-io/substrait-go/v4/literal" 12 | "github.com/substrait-io/substrait-go/v4/plan" 13 | "github.com/substrait-io/substrait-go/v4/types" 14 | substraitproto "github.com/substrait-io/substrait-protobuf/go/substraitpb" 15 | ) 16 | 17 | // Embed test JSON files for expected output comparison. 18 | // 19 | //go:embed testdata/*.json 20 | var testdata embed.FS 21 | 22 | // schema structures for testing purposes. 23 | var ( 24 | employeeSchema = types.NamedStruct{Names: []string{"employee_id", "name", "department_id", "salary", "role"}, 25 | Struct: types.StructType{ 26 | Nullability: types.NullabilityRequired, 27 | Types: []types.Type{ 28 | &types.Int32Type{Nullability: types.NullabilityRequired}, 29 | &types.StringType{Nullability: types.NullabilityNullable}, 30 | &types.Int32Type{Nullability: types.NullabilityNullable}, 31 | &types.DecimalType{Precision: 10, Scale: 2, Nullability: types.NullabilityNullable}, 32 | &types.StringType{Nullability: types.NullabilityNullable}, 33 | }, 34 | }} 35 | 36 | employeeSalariesSchema = types.NamedStruct{Names: []string{"name", "salary"}, 37 | Struct: types.StructType{ 38 | Types: []types.Type{ 39 | &types.StringType{Nullability: types.NullabilityNullable}, 40 | &types.DecimalType{Precision: 10, Scale: 2, Nullability: types.NullabilityNullable}, 41 | }, 42 | }} 43 | 44 | employeeSchemaNullable = types.NamedStruct{Names: []string{"employee_id", "name", "department_id", "salary", "role"}, 45 | Struct: types.StructType{ 46 | Types: []types.Type{ 47 | &types.Int32Type{Nullability: types.NullabilityNullable}, 48 | &types.StringType{Nullability: types.NullabilityNullable}, 49 | &types.Int32Type{Nullability: types.NullabilityNullable}, 50 | &types.DecimalType{Precision: 10, Scale: 2, Nullability: 
types.NullabilityNullable}, 51 | &types.StringType{Nullability: types.NullabilityNullable}, 52 | }, 53 | }} 54 | ) 55 | 56 | // makeProjectionMaskExpr generates a MaskExpression to project or reorder columns by the given IDs. 57 | func makeProjectionMaskExpr(columnIds []int) *expr.MaskExpression { 58 | structItems := make([]*substraitproto.Expression_MaskExpression_StructItem, len(columnIds)) 59 | 60 | for index, columnId := range columnIds { 61 | structItems[index] = &substraitproto.Expression_MaskExpression_StructItem{ 62 | Field: int32(columnId), 63 | } 64 | } 65 | 66 | return expr.MaskExpressionFromProto( 67 | &substraitproto.Expression_MaskExpression{ 68 | Select: &substraitproto.Expression_MaskExpression_StructSelect{ 69 | StructItems: structItems, 70 | }, 71 | MaintainSingularStruct: true, 72 | }, 73 | ) 74 | } 75 | 76 | // makeNamedTableReadRel creates a named table read relation with the selected column IDs. 77 | func makeNamedTableReadRel(b plan.Builder, tableNames []string, tableSchema types.NamedStruct, columnIds []int) plan.Rel { 78 | namedTableReadRel := b.NamedScan(tableNames, tableSchema) 79 | namedTableReadRel.SetProjection(makeProjectionMaskExpr(columnIds)) 80 | return namedTableReadRel 81 | } 82 | 83 | // makeConditionExprForLike constructs a LIKE condition expression for the specified column and value. 
84 | func makeConditionExprForLike(t *testing.T, b plan.Builder, scan plan.Rel, colId int, valueLiteral expr.Literal) expr.Expression { 85 | id := extensions.ID{ 86 | URI: "https://github.com/substrait-io/substrait/blob/main/extensions/functions_string.yaml", 87 | Name: "contains:str_str", 88 | } 89 | b.GetFunctionRef(id.URI, id.Name) 90 | colIdRef, err := b.RootFieldRef(scan, int32(colId)) 91 | require.NoError(t, err) 92 | scalarExpr, err := b.ScalarFn(id.URI, id.Name, nil, colIdRef, valueLiteral) 93 | require.NoError(t, err) 94 | return scalarExpr 95 | } 96 | 97 | func makeFilterRel(t *testing.T, b plan.Builder, input plan.Rel, condition expr.Expression) plan.Rel { 98 | filterRel, err := b.Filter(input, condition) 99 | require.NoError(t, err) 100 | return filterRel 101 | } 102 | 103 | func makeProjectRel(t *testing.T, b plan.Builder, input plan.Rel, columnIds []int) plan.Rel { 104 | refs := make([]expr.Expression, len(columnIds)) 105 | for i, c := range columnIds { 106 | ref, err := b.RootFieldRef(input, int32(c)) 107 | require.NoError(t, err) 108 | refs[i] = ref 109 | } 110 | project, err := b.Project(input, refs...) 
// getProjectionForTest2 returns project rel for "Select * from employees where role LIKE 'Engineer'".
//
// The plan is built as scan -> filter -> project:
//   - the scan reads "employees" with a projection that moves role to the front;
//   - the filter applies the condition built by makeConditionExprForLike;
//   - the project reorders the columns back to the table's original order.
func getProjectionForTest2(t *testing.T, b plan.Builder) plan.Rel {
	// scanRel outputs role, employee_id, name, department_id, salary
	namedScanRel := makeNamedTableReadRel(b, []string{"employees"}, employeeSchema, []int{4, 0, 1, 2, 3})

	// column 0 from the output of namedScanRel is role
	// Build the filter with condition `role LIKE 'Engineer'`
	// NOTE(review): the column index passed below is 1, not 0. With the scan
	// output order stated above, index 1 would be employee_id rather than
	// role — confirm which column is intended and how the reference is
	// resolved against the projected scan before changing either the comment
	// or the index (the expected JSON fixture depends on it).
	l := literal.NewString("Engineer", false)
	roleLikeEngineer := makeConditionExprForLike(t, b, namedScanRel, 1, l)
	filterRel := makeFilterRel(t, b, namedScanRel, roleLikeEngineer)

	// projectRel output employee_id, name, department_id, salary, role
	return makeProjectRel(t, b, filterRel, []int{1, 2, 3, 4, 0})
}
149 | expectedJson, err := testdata.ReadFile(fmt.Sprintf("testdata/%s.json", td.name)) 150 | require.NoError(t, err) 151 | 152 | // build plan for CTAS 153 | b := plan.NewBuilderDefault() 154 | ctasRel, err := b.CreateTableAsSelect(td.getProjection(t, b), td.ctasTableName, td.ctasTableSchema) 155 | require.NoError(t, err) 156 | ctasPlan, err := b.Plan(ctasRel, td.ctasTableSchema.Names) 157 | require.NoError(t, err) 158 | 159 | // Check that the generated plan matches the expected JSON. 160 | checkRoundTrip(t, string(expectedJson), ctasPlan) 161 | }) 162 | } 163 | } 164 | -------------------------------------------------------------------------------- /plan/internal/helper.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "github.com/substrait-io/substrait-go/v4/expr" 5 | proto "github.com/substrait-io/substrait-protobuf/go/substraitpb" 6 | ) 7 | 8 | func VirtualTableExpressionFromProto(s *proto.Expression_Nested_Struct, reg expr.ExtensionRegistry) (expr.VirtualTableExpressionValue, error) { 9 | fields := make(expr.VirtualTableExpressionValue, len(s.Fields)) 10 | for i, f := range s.Fields { 11 | val, err := expr.ExprFromProto(f, nil, reg) 12 | if err != nil { 13 | return nil, err 14 | } 15 | fields[i] = val 16 | } 17 | return fields, nil 18 | } 19 | 20 | func VirtualTableExprFromLiteralProto(s *proto.Expression_Literal_Struct) expr.VirtualTableExpressionValue { 21 | fields := make(expr.VirtualTableExpressionValue, len(s.Fields)) 22 | for i, f := range s.Fields { 23 | fields[i] = expr.LiteralFromProto(f) 24 | } 25 | return fields 26 | } 27 | -------------------------------------------------------------------------------- /plan/internal/helper_test.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/require" 7 | "github.com/substrait-io/substrait-go/v4/expr" 8 | ext 
"github.com/substrait-io/substrait-go/v4/extensions" 9 | proto "github.com/substrait-io/substrait-protobuf/go/substraitpb" 10 | "google.golang.org/protobuf/encoding/protojson" 11 | ) 12 | 13 | func TestVirtualTableExpressionFromProto(t *testing.T) { 14 | // define extensions with no plan for now 15 | const planExt = `{ 16 | "extensionUris": [ 17 | { 18 | "extensionUriAnchor": 1, 19 | "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" 20 | } 21 | ], 22 | "extensions": [ 23 | { 24 | "extensionFunction": { 25 | "extensionUriReference": 1, 26 | "functionAnchor": 2, 27 | "name": "add:i32_i32" 28 | } 29 | } 30 | ], 31 | "relations": [] 32 | }` 33 | 34 | var plan proto.Plan 35 | if err := protojson.Unmarshal([]byte(planExt), &plan); err != nil { 36 | panic(err) 37 | } 38 | 39 | // get the extension set 40 | extSet := ext.GetExtensionSet(&plan) 41 | literal1 := expr.NewPrimitiveLiteral(int32(1), false) 42 | expr1 := literal1.ToProto() 43 | 44 | reg := expr.NewExtensionRegistry(extSet, ext.GetDefaultCollectionWithNoError()) 45 | rows := &proto.Expression_Nested_Struct{Fields: []*proto.Expression{ 46 | expr1, 47 | }} 48 | exprRows, err := VirtualTableExpressionFromProto(rows, reg) 49 | require.NoError(t, err) 50 | require.Len(t, exprRows, 1) 51 | } 52 | -------------------------------------------------------------------------------- /plan/named_write_plan_test.go: -------------------------------------------------------------------------------- 1 | package plan_test 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/require" 8 | "github.com/substrait-io/substrait-go/v4/literal" 9 | "github.com/substrait-io/substrait-go/v4/plan" 10 | "github.com/substrait-io/substrait-go/v4/types" 11 | ) 12 | 13 | // getFilterForTest1 returns filter rel for "name LIKE 'Alice'" 14 | func getFilterForTest1(t *testing.T, b plan.Builder) plan.Rel { 15 | namedTableReadRel := b.NamedScan([]string{"employee_salaries"}, 
employeeSalariesSchema) 16 | 17 | // column 0 from the output of namedTableReadRel is name 18 | // Build the filter with condition `name LIKE 'Alice'` 19 | l := literal.NewString("Alice", false) 20 | nameLikeAlice := makeConditionExprForLike(t, b, namedTableReadRel, 0, l) 21 | return makeFilterRel(t, b, namedTableReadRel, nameLikeAlice) 22 | } 23 | 24 | // TestNamedTableInsertRoundTrip verifies that generated plans match the expected JSON. 25 | func TestNamedTableInsertRoundTrip(t *testing.T) { 26 | for _, td := range []struct { 27 | name string 28 | tableName []string 29 | tableSchema types.NamedStruct 30 | getInputRel func(t *testing.T, b plan.Builder) plan.Rel 31 | }{ 32 | {"insert_from_select", []string{"main", "employee_salaries"}, employeeSalariesSchema, getProjectionForTest1}, 33 | } { 34 | t.Run(td.name, func(t *testing.T) { 35 | // Load the expected JSON. This will be our baseline for comparison. 36 | expectedJson, err := testdata.ReadFile(fmt.Sprintf("testdata/%s.json", td.name)) 37 | require.NoError(t, err) 38 | 39 | // build plan for Insert 40 | b := plan.NewBuilderDefault() 41 | namedInsertRel, err := b.NamedInsert(td.getInputRel(t, b), td.tableName, td.tableSchema) 42 | require.NoError(t, err) 43 | namedInsertPlan, err := b.Plan(namedInsertRel, nil) 44 | require.NoError(t, err) 45 | 46 | // Check that the generated plan matches the expected JSON. 47 | checkRoundTrip(t, string(expectedJson), namedInsertPlan) 48 | }) 49 | } 50 | } 51 | 52 | // TestNamedTableDeleteRoundTrip verifies that generated plans match the expected JSON. 53 | func TestNamedTableDeleteRoundTrip(t *testing.T) { 54 | for _, td := range []struct { 55 | name string 56 | tableName []string 57 | tableSchema types.NamedStruct 58 | getInputRel func(t *testing.T, b plan.Builder) plan.Rel 59 | }{ 60 | {"delete_with_filter", []string{"main", "employee_salaries"}, employeeSalariesSchema, getFilterForTest1}, 61 | } { 62 | t.Run(td.name, func(t *testing.T) { 63 | // Load the expected JSON. 
This will be our baseline for comparison. 64 | expectedJson, err := testdata.ReadFile(fmt.Sprintf("testdata/%s.json", td.name)) 65 | require.NoError(t, err) 66 | 67 | // build plan for Delete 68 | b := plan.NewBuilderDefault() 69 | namedDeleteRel, err := b.NamedDelete(td.getInputRel(t, b), td.tableName, td.tableSchema) 70 | require.NoError(t, err) 71 | namedDeletePlan, err := b.Plan(namedDeleteRel, nil) 72 | require.NoError(t, err) 73 | 74 | // Check that the generated plan matches the expected JSON. 75 | checkRoundTrip(t, string(expectedJson), namedDeletePlan) 76 | }) 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /plan/plan_test.go: -------------------------------------------------------------------------------- 1 | package plan 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/google/go-cmp/cmp" 7 | "github.com/stretchr/testify/require" 8 | "github.com/substrait-io/substrait-go/v4/expr" 9 | "github.com/substrait-io/substrait-go/v4/extensions" 10 | proto "github.com/substrait-io/substrait-protobuf/go/substraitpb" 11 | "google.golang.org/protobuf/testing/protocmp" 12 | ) 13 | 14 | func TestRelFromProto(t *testing.T) { 15 | 16 | registry := expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError()) 17 | literal5 := &proto.Expression_Literal{LiteralType: &proto.Expression_Literal_I64{I64: 5}} 18 | exprLiteral5 := &proto.Expression{RexType: &proto.Expression_Literal_{Literal: literal5}} 19 | 20 | nestedStructExpr1 := &proto.Expression_Nested_Struct{Fields: []*proto.Expression{exprLiteral5}} 21 | virtualTableWithExpression := &proto.ReadRel_VirtualTable_{VirtualTable: &proto.ReadRel_VirtualTable{Expressions: []*proto.Expression_Nested_Struct{nestedStructExpr1}}} 22 | readRelWithExpression := &proto.ReadRel{ReadType: virtualTableWithExpression} 23 | 24 | literalStruct := &proto.Expression_Literal_Struct{Fields: []*proto.Expression_Literal{literal5}} 25 | virtualTableWithLiteral := 
&proto.ReadRel_VirtualTable_{VirtualTable: &proto.ReadRel_VirtualTable{Values: []*proto.Expression_Literal_Struct{literalStruct}}} 26 | readRelWithLiteral := &proto.ReadRel{ReadType: virtualTableWithLiteral} 27 | 28 | for _, td := range []struct { 29 | name string 30 | readType *proto.ReadRel 31 | }{ 32 | {"virtual table with expression", readRelWithExpression}, 33 | {"virtual table with deprecated literal", readRelWithLiteral}, 34 | } { 35 | t.Run(td.name, func(t *testing.T) { 36 | rel := &proto.Rel{RelType: &proto.Rel_Read{Read: td.readType}} 37 | 38 | outRel, err := RelFromProto(rel, registry) 39 | require.NoError(t, err) 40 | gotRel := outRel.ToProto() 41 | gotReadRel, ok := gotRel.RelType.(*proto.Rel_Read) 42 | require.True(t, ok) 43 | gotVirtualTableReadRel, ok := gotReadRel.Read.ReadType.(*proto.ReadRel_VirtualTable_) 44 | require.True(t, ok) 45 | // in case of both deprecated or new expression, the output should be the same as the new expression 46 | if diff := cmp.Diff(gotVirtualTableReadRel, virtualTableWithExpression, protocmp.Transform()); diff != "" { 47 | t.Errorf("expression proto didn't match, diff:\n%v", diff) 48 | } 49 | }) 50 | } 51 | 52 | } 53 | -------------------------------------------------------------------------------- /plan/testdata/ctas_basic.json: -------------------------------------------------------------------------------- 1 | { 2 | "relations": [ 3 | { 4 | "root": { 5 | "input": { 6 | "write": { 7 | "common": {"direct": {}}, 8 | "namedTable": { 9 | "names": [ 10 | "main", 11 | "employee_salaries" 12 | ] 13 | }, 14 | "tableSchema": { 15 | "names": [ 16 | "name", 17 | "salary" 18 | ], 19 | "struct": { 20 | "types": [ 21 | { 22 | "string": { 23 | "nullability": "NULLABILITY_NULLABLE" 24 | } 25 | }, 26 | { 27 | "decimal": { 28 | "scale": 2, 29 | "precision": 10, 30 | "nullability": "NULLABILITY_NULLABLE" 31 | } 32 | } 33 | ] 34 | } 35 | }, 36 | "op": "WRITE_OP_CTAS", 37 | "input": { 38 | "project": { 39 | "common": {"direct": {}}, 40 | 
"input": { 41 | "read": { 42 | "common": {"direct": {}}, 43 | "baseSchema": { 44 | "names": [ 45 | "employee_id", 46 | "name", 47 | "department_id", 48 | "salary", 49 | "role" 50 | ], 51 | "struct": { 52 | "types": [ 53 | { 54 | "i32": { 55 | "nullability": "NULLABILITY_REQUIRED" 56 | } 57 | }, 58 | { 59 | "string": { 60 | "nullability": "NULLABILITY_NULLABLE" 61 | } 62 | }, 63 | { 64 | "i32": { 65 | "nullability": "NULLABILITY_NULLABLE" 66 | } 67 | }, 68 | { 69 | "decimal": { 70 | "scale": 2, 71 | "precision": 10, 72 | "nullability": "NULLABILITY_NULLABLE" 73 | } 74 | }, 75 | { 76 | "string": { 77 | "nullability": "NULLABILITY_NULLABLE" 78 | } 79 | } 80 | ], 81 | "nullability": "NULLABILITY_REQUIRED" 82 | } 83 | }, 84 | "projection": { 85 | "select": { 86 | "structItems": [ 87 | { 88 | "field": 1 89 | }, 90 | { 91 | "field": 3 92 | } 93 | ] 94 | }, 95 | "maintainSingularStruct": true 96 | }, 97 | "namedTable": { 98 | "names": [ 99 | "employees" 100 | ] 101 | } 102 | } 103 | }, 104 | "expressions": [ 105 | { 106 | "selection": { 107 | "directReference": { 108 | "structField": { 109 | "field": 0 110 | } 111 | }, 112 | "rootReference": {} 113 | } 114 | }, 115 | { 116 | "selection": { 117 | "directReference": { 118 | "structField": { 119 | "field": 1 120 | } 121 | }, 122 | "rootReference": {} 123 | } 124 | } 125 | ] 126 | } 127 | } 128 | } 129 | }, 130 | "names": [ 131 | "name", 132 | "salary" 133 | ] 134 | } 135 | } 136 | ], 137 | "version": { 138 | "majorNumber": 0, 139 | "minorNumber": 29, 140 | "patchNumber": 0, 141 | "producer": "substrait-go" 142 | } 143 | } -------------------------------------------------------------------------------- /plan/testdata/delete_with_filter.json: -------------------------------------------------------------------------------- 1 | { 2 | "extensionUris":[ 3 | { 4 | "extensionUriAnchor":1, 5 | "uri":"https://github.com/substrait-io/substrait/blob/main/extensions/functions_string.yaml" 6 | } 7 | ], 8 | "extensions":[ 9 | { 10 | 
"extensionFunction":{ 11 | "extensionUriReference":1, 12 | "functionAnchor":1, 13 | "name":"contains:str_str" 14 | } 15 | } 16 | ], 17 | "relations":[ 18 | { 19 | "root":{ 20 | "input":{ 21 | "write":{ 22 | "common":{ 23 | "direct":{ 24 | 25 | } 26 | }, 27 | "namedTable":{ 28 | "names":[ 29 | "main", 30 | "employee_salaries" 31 | ] 32 | }, 33 | "tableSchema":{ 34 | "names":[ 35 | "name", 36 | "salary" 37 | ], 38 | "struct":{ 39 | "types":[ 40 | { 41 | "string":{ 42 | "nullability":"NULLABILITY_NULLABLE" 43 | } 44 | }, 45 | { 46 | "decimal":{ 47 | "scale":2, 48 | "precision":10, 49 | "nullability":"NULLABILITY_NULLABLE" 50 | } 51 | } 52 | ] 53 | } 54 | }, 55 | "op":"WRITE_OP_DELETE", 56 | "input":{ 57 | "filter":{ 58 | "common":{ 59 | "direct":{ 60 | 61 | } 62 | }, 63 | "input":{ 64 | "read":{ 65 | "baseSchema":{ 66 | "names":[ 67 | "name", 68 | "salary" 69 | ], 70 | "struct":{ 71 | "types":[ 72 | { 73 | "string":{ 74 | "nullability":"NULLABILITY_NULLABLE" 75 | } 76 | }, 77 | { 78 | "decimal":{ 79 | "scale":2, 80 | "precision":10, 81 | "nullability":"NULLABILITY_NULLABLE" 82 | } 83 | } 84 | ] 85 | } 86 | }, 87 | "common":{ 88 | "direct":{ 89 | 90 | } 91 | }, 92 | "namedTable":{ 93 | "names":[ 94 | "employee_salaries" 95 | ] 96 | } 97 | } 98 | }, 99 | "condition":{ 100 | "scalarFunction":{ 101 | "functionReference":1, 102 | "outputType":{ 103 | "bool":{ 104 | "nullability":"NULLABILITY_NULLABLE" 105 | } 106 | }, 107 | "arguments":[ 108 | { 109 | "value":{ 110 | "selection":{ 111 | "directReference":{ 112 | "structField":{ 113 | "field":0 114 | } 115 | }, 116 | "rootReference":{ 117 | 118 | } 119 | } 120 | } 121 | }, 122 | { 123 | "value":{ 124 | "literal":{ 125 | "string":"Alice" 126 | } 127 | } 128 | } 129 | ] 130 | } 131 | } 132 | } 133 | } 134 | } 135 | } 136 | } 137 | } 138 | ], 139 | "version":{ 140 | "majorNumber":0, 141 | "minorNumber":29, 142 | "patchNumber":0, 143 | "producer":"substrait-go" 144 | } 145 | } 
-------------------------------------------------------------------------------- /plan/testdata/insert_from_select.json: -------------------------------------------------------------------------------- 1 | { 2 | "relations":[ 3 | { 4 | "root":{ 5 | "input":{ 6 | "write":{ 7 | "common": { 8 | "direct": { 9 | } 10 | }, 11 | "namedTable":{ 12 | "names":[ 13 | "main", 14 | "employee_salaries" 15 | ] 16 | }, 17 | "tableSchema":{ 18 | "names":[ 19 | "name", 20 | "salary" 21 | ], 22 | "struct":{ 23 | "types":[ 24 | { 25 | "string":{ 26 | "nullability":"NULLABILITY_NULLABLE" 27 | } 28 | }, 29 | { 30 | "decimal":{ 31 | "scale":2, 32 | "precision":10, 33 | "nullability":"NULLABILITY_NULLABLE" 34 | } 35 | } 36 | ] 37 | } 38 | }, 39 | "op":"WRITE_OP_INSERT", 40 | "input":{ 41 | "project":{ 42 | "common": { 43 | "direct": { 44 | } 45 | }, 46 | "input":{ 47 | "read":{ 48 | "common": { 49 | "direct": { 50 | } 51 | }, 52 | "baseSchema":{ 53 | "names":[ 54 | "employee_id", 55 | "name", 56 | "department_id", 57 | "salary", 58 | "role" 59 | ], 60 | "struct":{ 61 | "types":[ 62 | { 63 | "i32":{ 64 | "nullability":"NULLABILITY_REQUIRED" 65 | } 66 | }, 67 | { 68 | "string":{ 69 | "nullability":"NULLABILITY_NULLABLE" 70 | } 71 | }, 72 | { 73 | "i32":{ 74 | "nullability":"NULLABILITY_NULLABLE" 75 | } 76 | }, 77 | { 78 | "decimal":{ 79 | "scale":2, 80 | "precision":10, 81 | "nullability":"NULLABILITY_NULLABLE" 82 | } 83 | }, 84 | { 85 | "string":{ 86 | "nullability":"NULLABILITY_NULLABLE" 87 | } 88 | } 89 | ], 90 | "nullability":"NULLABILITY_REQUIRED" 91 | } 92 | }, 93 | "projection":{ 94 | "select":{ 95 | "structItems":[ 96 | { 97 | "field":1 98 | }, 99 | { 100 | "field":3 101 | } 102 | ] 103 | }, 104 | "maintainSingularStruct":true 105 | }, 106 | "namedTable":{ 107 | "names":[ 108 | "employees" 109 | ] 110 | } 111 | } 112 | }, 113 | "expressions":[ 114 | { 115 | "selection":{ 116 | "directReference":{ 117 | "structField":{ 118 | "field":0 119 | } 120 | }, 121 | "rootReference":{ 122 | 
123 | } 124 | } 125 | }, 126 | { 127 | "selection":{ 128 | "directReference":{ 129 | "structField":{ 130 | "field":1 131 | } 132 | }, 133 | "rootReference":{ 134 | 135 | } 136 | } 137 | } 138 | ] 139 | } 140 | } 141 | } 142 | } 143 | } 144 | } 145 | ], 146 | "version": { 147 | "majorNumber": 0, 148 | "minorNumber": 29, 149 | "patchNumber": 0, 150 | "producer": "substrait-go" 151 | } 152 | } -------------------------------------------------------------------------------- /plan/testdata/value_with_literal.json: -------------------------------------------------------------------------------- 1 | { 2 | "relations": [ 3 | { 4 | "root": { 5 | "input": { 6 | "read": { 7 | "baseSchema": { 8 | "names": [ 9 | "col0", 10 | "col1" 11 | ], 12 | "struct": { 13 | "nullability": "NULLABILITY_REQUIRED", 14 | "types": [ 15 | { 16 | "i32": { 17 | "nullability": "NULLABILITY_REQUIRED" 18 | } 19 | }, 20 | { 21 | "i32": { 22 | "nullability": "NULLABILITY_REQUIRED" 23 | } 24 | } 25 | ] 26 | } 27 | }, 28 | "common": { 29 | "direct": {} 30 | }, 31 | "virtualTable": { 32 | "expressions": [ 33 | { 34 | "fields": [ 35 | { 36 | "literal": { 37 | "i32": 1 38 | } 39 | }, 40 | { 41 | "literal": { 42 | "i32": 2 43 | } 44 | } 45 | ] 46 | } 47 | ] 48 | } 49 | } 50 | }, 51 | "names": [ 52 | "col0", 53 | "col1" 54 | ] 55 | } 56 | } 57 | ], 58 | "version": { 59 | "majorNumber": 0, 60 | "minorNumber": 29, 61 | "patchNumber": 0, 62 | "producer": "substrait-go" 63 | } 64 | } -------------------------------------------------------------------------------- /plan/testdata/value_with_scalar.json: -------------------------------------------------------------------------------- 1 | { 2 | "extensionUris": [ 3 | { 4 | "extensionUriAnchor": 1, 5 | "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" 6 | } 7 | ], 8 | "extensions": [ 9 | { 10 | "extensionFunction": { 11 | "extensionUriReference": 1, 12 | "functionAnchor": 1, 13 | "name": "add:i32_i32" 14 | } 15 | } 16 | 
], 17 | "relations": [ 18 | { 19 | "root": { 20 | "input": { 21 | "read": { 22 | "baseSchema": { 23 | "names": [ 24 | "col0", 25 | "col1" 26 | ], 27 | "struct": { 28 | "nullability": "NULLABILITY_REQUIRED", 29 | "types": [ 30 | { 31 | "i32": { 32 | "nullability": "NULLABILITY_REQUIRED" 33 | } 34 | }, 35 | { 36 | "i32": { 37 | "nullability": "NULLABILITY_REQUIRED" 38 | } 39 | } 40 | ] 41 | } 42 | }, 43 | "common": { 44 | "direct": { 45 | } 46 | }, 47 | "virtualTable": { 48 | "expressions": [ 49 | { 50 | "fields": [ 51 | { 52 | "scalarFunction": { 53 | "arguments": [ 54 | { 55 | "value": { 56 | "literal": { 57 | "i32": 1 58 | } 59 | } 60 | }, 61 | { 62 | "value": { 63 | "literal": { 64 | "i32": 1 65 | } 66 | } 67 | } 68 | ], 69 | "functionReference": 1, 70 | "outputType": { 71 | "i32": { 72 | "nullability": "NULLABILITY_REQUIRED" 73 | } 74 | } 75 | } 76 | }, 77 | { 78 | "scalarFunction": { 79 | "arguments": [ 80 | { 81 | "value": { 82 | "literal": { 83 | "i32": 2 84 | } 85 | } 86 | }, 87 | { 88 | "value": { 89 | "literal": { 90 | "i32": 2 91 | } 92 | } 93 | } 94 | ], 95 | "functionReference": 1, 96 | "outputType": { 97 | "i32": { 98 | "nullability": "NULLABILITY_REQUIRED" 99 | } 100 | } 101 | } 102 | } 103 | ] 104 | } 105 | ] 106 | } 107 | } 108 | }, 109 | "names": [ 110 | "col0", 111 | "col1" 112 | ] 113 | } 114 | } 115 | ], 116 | "version": { 117 | "majorNumber": 0, 118 | "minorNumber": 29, 119 | "patchNumber": 0, 120 | "producer": "substrait-go" 121 | } 122 | } -------------------------------------------------------------------------------- /plan/virtual_table_from_expr_test.go: -------------------------------------------------------------------------------- 1 | package plan_test 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/require" 8 | "github.com/substrait-io/substrait-go/v4/expr" 9 | "github.com/substrait-io/substrait-go/v4/extensions" 10 | "github.com/substrait-io/substrait-go/v4/plan" 11 | 
"github.com/substrait-io/substrait-go/v4/types" 12 | ) 13 | 14 | var ( 15 | v1 = expr.PrimitiveLiteral[int32]{Value: 1, Type: &types.Int32Type{Nullability: types.NullabilityRequired}} 16 | v2 = expr.PrimitiveLiteral[int32]{Value: 2, Type: &types.Int32Type{Nullability: types.NullabilityRequired}} 17 | ) 18 | 19 | // makeAddExpr constructs expression val1 + val2. 20 | func makeAddExpr(t *testing.T, b plan.Builder, val1, val2 expr.Literal) expr.Expression { 21 | id := extensions.ID{ 22 | URI: "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml", 23 | Name: "add:i32_i32", 24 | } 25 | b.GetFunctionRef(id.URI, id.Name) 26 | scalarExpr, err := b.ScalarFn(id.URI, id.Name, nil, val1, val2) 27 | require.NoError(t, err) 28 | return scalarExpr 29 | } 30 | 31 | func buildLiteralExpressions(_ *testing.T, _ plan.Builder) []expr.VirtualTableExpressionValue { 32 | return []expr.VirtualTableExpressionValue{{&v1, &v2}} 33 | } 34 | 35 | // buildScalarAddExpression builds a scalar binary add expression 36 | func buildScalarAddExpression(t *testing.T, b plan.Builder) []expr.VirtualTableExpressionValue { 37 | s1 := makeAddExpr(t, b, &v1, &v1) 38 | s2 := makeAddExpr(t, b, &v2, &v2) 39 | return []expr.VirtualTableExpressionValue{{s1, s2}} 40 | } 41 | 42 | // TestNamedTableInsertRoundTrip verifies that generated plans match the expected JSON. 43 | func TestVirtualTableFromExprRoundTrip(t *testing.T) { 44 | for _, td := range []struct { 45 | name string 46 | fieldNames []string 47 | buildExprForTest func(t *testing.T, b plan.Builder) []expr.VirtualTableExpressionValue 48 | }{ 49 | {"value_with_literal", []string{"col0", "col1"}, buildLiteralExpressions}, 50 | {"value_with_scalar", []string{"col0", "col1"}, buildScalarAddExpression}, 51 | } { 52 | t.Run(td.name, func(t *testing.T) { 53 | // Load the expected JSON. This will be our baseline for comparison. 
54 | expectedJson, err := testdata.ReadFile(fmt.Sprintf("testdata/%s.json", td.name)) 55 | require.NoError(t, err) 56 | 57 | // build plan for Project with virtual table 58 | b := plan.NewBuilderDefault() 59 | valueExpr := td.buildExprForTest(t, b) 60 | virtualTableExpr, err := b.VirtualTableFromExpr(td.fieldNames, valueExpr...) 61 | require.NoError(t, err) 62 | virtualTablePlan, err := b.Plan(virtualTableExpr, td.fieldNames) 63 | require.NoError(t, err) 64 | 65 | // Check that the generated plan matches the expected JSON. 66 | checkRoundTrip(t, string(expectedJson), virtualTablePlan) 67 | }) 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /testcases/parser/parse.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "io/fs" 7 | 8 | "github.com/antlr4-go/antlr/v4" 9 | "github.com/substrait-io/substrait-go/v4/testcases/parser/baseparser" 10 | "github.com/substrait-io/substrait-go/v4/types/parser/util" 11 | ) 12 | 13 | func ParseTestCaseFileFromFS(fs fs.FS, s string) (*TestFile, error) { 14 | file, err := fs.Open(s) 15 | if err != nil { 16 | return nil, err 17 | } 18 | defer file.Close() 19 | return ParseTestCasesFromFile(file) 20 | } 21 | 22 | func ParseTestCasesFromFile(input fs.File) (*TestFile, error) { 23 | buf, err := io.ReadAll(input) 24 | if err != nil { 25 | return nil, err 26 | } 27 | is := antlr.NewInputStream(string(buf)) 28 | return parseTestCasesFromStream(is, fmt.Sprintf("file %s", input)) 29 | } 30 | 31 | func ParseTestCasesFromString(input string) (*TestFile, error) { 32 | is := antlr.NewInputStream(input) 33 | return parseTestCasesFromStream(is, input) 34 | } 35 | 36 | func parseTestCasesFromStream(is *antlr.InputStream, debugStr string) (*TestFile, error) { 37 | lexer := baseparser.NewFuncTestCaseLexer(is) 38 | stream := antlr.NewCommonTokenStream(lexer, 0) 39 | p := baseparser.NewFuncTestCaseParser(stream) 40 
| errorListener := util.NewSimpleErrorListener() 41 | p.AddErrorListener(errorListener) 42 | p.GetInterpreter().SetPredictionMode(antlr.PredictionModeSLL) 43 | 44 | testFile, err := parseTestCases(p, errorListener, debugStr) 45 | if errorListener.ErrorCount() > 0 { 46 | return nil, fmt.Errorf("error parsing input '%s': %s", debugStr, errorListener.GetErrors()) 47 | } 48 | return testFile, err 49 | } 50 | // parseTestCases parses the document with p, visits the resulting tree and returns the *TestFile; it fails if errorListener recorded any error. 51 | func parseTestCases(p *baseparser.FuncTestCaseParser, errorListener util.VisitErrorListener, debugStr string) (testFile *TestFile, err error) { 52 | // err is a named result so that the deferred handler, which is given &err, can change the error returned to the caller. 53 | defer util.TransformPanicToError(&err, debugStr, "ParseExpr", errorListener) 54 | 55 | visitor := &TestCaseVisitor{ErrorListener: errorListener} 56 | context := p.Doc() 57 | if errorListener.ErrorCount() > 0 { 58 | fmt.Printf("ParseTree: %v", antlr.TreesStringTree(context, []string{}, p)) 59 | return nil, fmt.Errorf("error parsing input '%s': %s", debugStr, errorListener.GetErrors()) 60 | } 61 | ret := visitor.Visit(context) 62 | if errorListener.ErrorCount() > 0 { 63 | return nil, fmt.Errorf("error parsing input '%s': %s", debugStr, errorListener.GetErrors()) 64 | } 65 | retType, ok := ret.(*TestFile) 66 | if !ok { 67 | return nil, fmt.Errorf("failed to parse %s as TestFile", debugStr) 68 | } 69 | return retType, err 70 | } 71 | -------------------------------------------------------------------------------- /types/any_type.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package types 4 | 5 | import ( 6 | "fmt" 7 | "slices" 8 | ) 9 | 10 | // AnyType to represent AnyType, this type is to indicate "any" type of argument 11 | // This type is not used in function invocation.
It is only used in function definitions. 12 | type AnyType struct { 13 | Name string 14 | TypeVariationRef uint32 15 | Nullability Nullability 16 | } 17 | 18 | func (m *AnyType) SetNullability(n Nullability) FuncDefArgType { 19 | m.Nullability = n 20 | return m 21 | } 22 | 23 | func (m *AnyType) String() string { 24 | return fmt.Sprintf("%s%s", m.Name, strFromNullability(m.Nullability)) 25 | } 26 | 27 | func (m *AnyType) HasParameterizedParam() bool { 28 | // any type doesn't have abstract parameters 29 | return false 30 | } 31 | 32 | func (m *AnyType) GetParameterizedParams() []interface{} { 33 | // any type doesn't have any abstract parameters 34 | return nil 35 | } 36 | 37 | func (m *AnyType) MatchWithNullability(ot Type) bool { 38 | return m.Nullability == ot.GetNullability() 39 | } 40 | 41 | func (m *AnyType) MatchWithoutNullability(ot Type) bool { 42 | return true 43 | } 44 | 45 | func (m *AnyType) GetNullability() Nullability { 46 | return m.Nullability 47 | } 48 | 49 | func (m *AnyType) ShortString() string { 50 | return "any" 51 | } 52 | 53 | // unwrapAnyTypeWithName searches for AnyType in p with the specified name, 54 | // and if found, returns argType. If p is a composite type, 55 | // recursively unwraps p to search for AnyType in p's parameters. 56 | // Returns nil Type if AnyType was not found.
57 | func unwrapAnyTypeWithName(name string, p FuncDefArgType, argType Type) (Type, error) { 58 | switch arg := p.(type) { 59 | case *AnyType: 60 | if arg.Name == name { 61 | return argType, nil 62 | } 63 | case *ParameterizedListType: 64 | argParams := argType.GetParameters() 65 | if len(argParams) != 1 || argParams[0] == nil { 66 | return nil, fmt.Errorf( 67 | "expected ListType to have non-nil 1 parameter, found %v", argParams) 68 | } 69 | return unwrapAnyTypeWithName(name, arg.Type, argParams[0].(Type)) 70 | case *ParameterizedMapType: 71 | argParams := argType.GetParameters() 72 | if len(argParams) != 2 || argParams[0] == nil || argParams[1] == nil { 73 | return nil, fmt.Errorf( 74 | "expected MapType to have 2 non-nil parameters, found %v", argParams) 75 | } 76 | keyType, err := unwrapAnyTypeWithName(name, arg.Key, argParams[0].(Type)) 77 | if err != nil { 78 | return nil, err 79 | } 80 | if keyType != nil { 81 | return keyType, nil 82 | } 83 | return unwrapAnyTypeWithName(name, arg.Value, argParams[1].(Type)) 84 | case *ParameterizedStructType: 85 | argParams := argType.GetParameters() 86 | if len(argParams) != len(arg.Types) || slices.Contains(argParams, nil) { 87 | return nil, fmt.Errorf("expected StructType to have %d non-nil parameters, found %v", 88 | len(arg.Types), argParams) 89 | } 90 | for i, param := range argParams { 91 | pt, err := unwrapAnyTypeWithName(name, arg.Types[i], param.(Type)) 92 | if err != nil { 93 | return nil, err 94 | } 95 | if pt != nil { 96 | return pt, nil 97 | } 98 | } 99 | } 100 | // Didn't find matching AnyType. 101 | return nil, nil 102 | } 103 | 104 | func (m *AnyType) ReturnType(funcParameters []FuncDefArgType, argumentTypes []Type) (Type, error) { 105 | // iterate through smaller of the funcParameters and argumentTypes; 106 | // argumentTypes may be larger than funcParameters due to variadic parameters. 
107 | for i := 0; i < min(len(funcParameters), len(argumentTypes)); i++ { 108 | typ, err := unwrapAnyTypeWithName(m.Name, funcParameters[i], argumentTypes[i]) 109 | if err != nil { 110 | return nil, err 111 | } 112 | if typ != nil { 113 | return typ, nil 114 | } 115 | } 116 | 117 | return nil, fmt.Errorf("no matching any type found in function parameters") 118 | } 119 | 120 | func (m *AnyType) WithParameters([]interface{}) (Type, error) { 121 | return nil, fmt.Errorf("any type doesn't have any parameters") 122 | } 123 | -------------------------------------------------------------------------------- /types/any_type_test.go: -------------------------------------------------------------------------------- 1 | package types_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/require" 7 | . "github.com/substrait-io/substrait-go/v4/types" 8 | ) 9 | 10 | func TestAnyType(t *testing.T) { 11 | decP30S9 := &DecimalType{Precision: 30, Scale: 9, Nullability: NullabilityRequired} 12 | varchar37 := &VarCharType{Length: 37} 13 | for _, td := range []struct { 14 | testName string 15 | argName string 16 | parameters []FuncDefArgType 17 | args []Type 18 | concreteReturnType Type 19 | nullability Nullability 20 | expectedString string 21 | expectedErr string 22 | }{ 23 | { 24 | testName: "any", 25 | argName: "any", 26 | parameters: []FuncDefArgType{&AnyType{Name: "any"}}, 27 | args: []Type{decP30S9}, 28 | concreteReturnType: decP30S9, 29 | nullability: NullabilityNullable, 30 | expectedString: "any?", 31 | }, 32 | { 33 | testName: "anyrequired", 34 | argName: "any2", 35 | parameters: []FuncDefArgType{&Int16Type{}, &AnyType{Name: "any2"}}, 36 | args: []Type{&Int16Type{}, &Int64Type{}}, 37 | concreteReturnType: &Int64Type{}, 38 | nullability: NullabilityRequired, 39 | expectedString: "any2", 40 | }, 41 | { 42 | testName: "list", 43 | argName: "any1", 44 | parameters: []FuncDefArgType{ 45 | &Int16Type{}, 46 | &ParameterizedListType{Type: &AnyType{Name: "any1"}}, 47 
| }, 48 | args: []Type{&Int16Type{}, &ListType{Type: &Int64Type{}}}, 49 | concreteReturnType: &Int64Type{}, 50 | nullability: NullabilityRequired, 51 | expectedString: "any1", 52 | }, 53 | { 54 | testName: "wrong_list", 55 | argName: "any1", 56 | parameters: []FuncDefArgType{ 57 | &Int16Type{}, 58 | &ParameterizedListType{Type: &AnyType{Name: "any1"}}, 59 | }, 60 | args: []Type{&Int16Type{}, &ListType{}}, 61 | nullability: NullabilityRequired, 62 | expectedString: "any1", 63 | expectedErr: "expected ListType to have non-nil 1 parameter, found []", 64 | }, 65 | { 66 | testName: "map", 67 | argName: "any1", 68 | parameters: []FuncDefArgType{ 69 | &Int16Type{}, 70 | &ParameterizedMapType{Key: &StringType{}, Value: &AnyType{Name: "any1"}}, 71 | }, 72 | args: []Type{ 73 | &Int16Type{}, 74 | &MapType{Key: &StringType{}, Value: &ListType{Type: &Int64Type{}}}, 75 | }, 76 | concreteReturnType: &ListType{Type: &Int64Type{}}, 77 | nullability: NullabilityRequired, 78 | expectedString: "any1", 79 | }, 80 | { 81 | testName: "map", 82 | argName: "any2", 83 | parameters: []FuncDefArgType{ 84 | &Int16Type{}, 85 | &ParameterizedMapType{Key: &AnyType{Name: "any2"}, Value: &StringType{}}, 86 | }, 87 | args: []Type{ 88 | &Int16Type{}, 89 | &MapType{Key: &ListType{Type: &Int64Type{}}, Value: &StringType{}}, 90 | }, 91 | concreteReturnType: &ListType{Type: &Int64Type{}}, 92 | nullability: NullabilityRequired, 93 | expectedString: "any2", 94 | }, 95 | { 96 | testName: "wrong_map", 97 | argName: "any2", 98 | parameters: []FuncDefArgType{ 99 | &Int16Type{}, 100 | &ParameterizedMapType{Key: &AnyType{Name: "any2"}, Value: &StringType{}}, 101 | }, 102 | args: []Type{ 103 | &Int16Type{}, 104 | &MapType{Value: &StringType{}}, 105 | }, 106 | nullability: NullabilityRequired, 107 | expectedString: "any2", 108 | expectedErr: "expected MapType to have 2 non-nil parameters, found [ string]", 109 | }, 110 | { 111 | testName: "struct", 112 | argName: "any1", 113 | parameters: []FuncDefArgType{ 114 | 
&ParameterizedStructType{Types: []FuncDefArgType{ 115 | &StringType{}, &AnyType{Name: "any1"}, &Int64Type{}, 116 | }}, 117 | }, 118 | args: []Type{ 119 | &StructType{Types: []Type{&StringType{}, &VarCharType{Length: 37}, &Int64Type{}}}, 120 | }, 121 | concreteReturnType: &VarCharType{Length: 37}, 122 | nullability: NullabilityRequired, 123 | expectedString: "any1", 124 | }, 125 | { 126 | testName: "wrong_struct", 127 | argName: "any1", 128 | parameters: []FuncDefArgType{ 129 | &ParameterizedStructType{Types: []FuncDefArgType{ 130 | &StringType{}, &AnyType{Name: "any1"}, &Int64Type{}, 131 | }}, 132 | }, 133 | args: []Type{ 134 | &StructType{Types: []Type{&StringType{}, &Int64Type{}, nil, &Int64Type{}}}, 135 | }, 136 | expectedString: "any1", 137 | expectedErr: "expected StructType to have 3 non-nil parameters, found [string i64 i64]", 138 | }, 139 | { 140 | testName: "anyOtherName", 141 | argName: "any1", 142 | parameters: []FuncDefArgType{&AnyType{Name: "any1"}, &Int32Type{}}, 143 | args: []Type{varchar37, &Int32Type{}}, 144 | concreteReturnType: varchar37, 145 | nullability: NullabilityNullable, 146 | expectedString: "any1?", 147 | }, 148 | { 149 | testName: "T name", 150 | argName: "T", 151 | parameters: []FuncDefArgType{&AnyType{Name: "U"}}, 152 | args: []Type{varchar37}, 153 | nullability: NullabilityNullable, 154 | expectedString: "T?", 155 | }, 156 | } { 157 | t.Run(td.testName, func(t *testing.T) { 158 | anyBase := &AnyType{ 159 | Name: td.argName, 160 | Nullability: td.nullability, 161 | } 162 | anyType := anyBase.SetNullability(td.nullability) 163 | require.Equal(t, td.nullability, anyType.GetNullability()) 164 | require.Equal(t, "any", anyType.ShortString()) 165 | require.Equal(t, td.expectedString, anyType.String()) 166 | returnType, err := anyType.ReturnType(td.parameters, td.args) 167 | if td.concreteReturnType != nil { 168 | require.NoError(t, err) 169 | require.Equal(t, td.concreteReturnType, returnType) 170 | } else { 171 | require.Error(t, err) 172 
| if td.expectedErr != "" { 173 | require.Equal(t, td.expectedErr, err.Error()) 174 | } 175 | } 176 | }) 177 | } 178 | } 179 | -------------------------------------------------------------------------------- /types/integer_parameters/concrete_int_param.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package integer_parameters 4 | 5 | import "fmt" 6 | 7 | // ConcreteIntParam represents a single integer concrete parameter for a concrete type 8 | // Example: VARCHAR(6) -> 6 is a ConcreteIntParam 9 | // DECIMAL<P, 0> --> 0 is a ConcreteIntParam but P is not 10 | type ConcreteIntParam int32 11 | 12 | func NewConcreteIntParam(v int32) IntegerParameter { 13 | m := ConcreteIntParam(v) 14 | return &m 15 | } 16 | 17 | func (m *ConcreteIntParam) IsCompatible(o IntegerParameter) bool { 18 | if t, ok := o.(*ConcreteIntParam); ok { 19 | return *t == *m 20 | } 21 | return false 22 | } 23 | 24 | func (m *ConcreteIntParam) String() string { 25 | return fmt.Sprintf("%d", *m) 26 | } 27 | -------------------------------------------------------------------------------- /types/integer_parameters/integer_parameter_type.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package integer_parameters 4 | 5 | import "fmt" 6 | 7 | // IntegerParameter represents a parameter type. 
8 | // A parameter can be of concrete (38) or abstract type (P), 9 | // or another parameterized type like VARCHAR<"L1"> 10 | type IntegerParameter interface { 11 | // IsCompatible reports whether this type is compatible with other; 12 | // it is compatible if other can be used in place of this type 13 | IsCompatible(other IntegerParameter) bool 14 | fmt.Stringer 15 | } 16 | -------------------------------------------------------------------------------- /types/integer_parameters/integer_parameter_type_test.go: --------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package integer_parameters_test 4 | 5 | import ( 6 | "testing" 7 | 8 | "github.com/stretchr/testify/require" 9 | "github.com/substrait-io/substrait-go/v4/types/integer_parameters" 10 | ) 11 | 12 | func TestConcreteParameterType(t *testing.T) { 13 | concreteType1 := integer_parameters.ConcreteIntParam(1) 14 | require.Equal(t, "1", concreteType1.String()) 15 | } 16 | 17 | func TestLeafParameterType(t *testing.T) { 18 | var concreteType1, concreteType2, abstractType1 integer_parameters.IntegerParameter 19 | 20 | concreteType1 = integer_parameters.NewConcreteIntParam(1) 21 | concreteType2 = integer_parameters.NewConcreteIntParam(2) 22 | 23 | abstractType1 = integer_parameters.NewVariableIntParam("P") 24 | 25 | // verify string val 26 | require.Equal(t, "1", concreteType1.String()) 27 | require.Equal(t, "P", abstractType1.String()) 28 | 29 | // a concrete parameter is only compatible with a concrete parameter of the same value 30 | require.True(t, concreteType1.IsCompatible(concreteType1)) 31 | require.False(t, concreteType1.IsCompatible(concreteType2)) 32 | 33 | // abstract type is compatible with both abstract and concrete type 34 | require.True(t, abstractType1.IsCompatible(abstractType1)) 35 | require.True(t, abstractType1.IsCompatible(concreteType2)) 36 | } 37 | -------------------------------------------------------------------------------- /types/integer_parameters/variable_int_param.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package integer_parameters 4 | 5 | // VariableIntParam represents an integer parameter for a parameterized type 6 | // Example: VARCHAR(L1) -> L1 is a VariableIntParam 7 | // DECIMAL<P, 0> --> P is a VariableIntParam 8 | type VariableIntParam string 9 | 10 | func NewVariableIntParam(s string) IntegerParameter { 11 | m := VariableIntParam(s) 12 | return &m 13 | } 14 | 15 | func (m *VariableIntParam) IsCompatible(o 
IntegerParameter) bool {
16 | switch o.(type) { 17 | case *VariableIntParam, *ConcreteIntParam: 18 | return true 19 | default: 20 | return false 21 | } 22 | } 23 | 24 | func (m *VariableIntParam) String() string { 25 | return string(*m) 26 | } 27 | 28 | func (m *VariableIntParam) GetAbstractParamName() string { 29 | return string(*m) 30 | } 31 | -------------------------------------------------------------------------------- /types/interval_compound_type.go: -------------------------------------------------------------------------------- 1 | package types 2 | 3 | import ( 4 | "fmt" 5 | 6 | proto "github.com/substrait-io/substrait-protobuf/go/substraitpb" 7 | ) 8 | 9 | // IntervalCompoundType this is used to represent a type of interval compound. 10 | type IntervalCompoundType struct { 11 | precision TimePrecision 12 | typeVariationRef uint32 13 | nullability Nullability 14 | } 15 | 16 | // NewIntervalCompoundType creates a type of new interval compound. 17 | func NewIntervalCompoundType() IntervalCompoundType { 18 | return IntervalCompoundType{} 19 | } 20 | 21 | func (m IntervalCompoundType) WithTypeVariationRef(typeVariationRef uint32) IntervalCompoundType { 22 | m.typeVariationRef = typeVariationRef 23 | return m 24 | } 25 | 26 | func (m IntervalCompoundType) GetPrecisionProtoVal() int32 { 27 | return m.precision.ToProtoVal() 28 | } 29 | 30 | func (IntervalCompoundType) isRootRef() {} 31 | func (m IntervalCompoundType) WithNullability(n Nullability) Type { 32 | return IntervalCompoundType{ 33 | precision: m.precision, 34 | nullability: n, 35 | } 36 | } 37 | 38 | func (m IntervalCompoundType) WithPrecision(precision TimePrecision) IntervalCompoundType { 39 | return IntervalCompoundType{ 40 | precision: precision, 41 | nullability: m.nullability, 42 | } 43 | } 44 | 45 | func (m IntervalCompoundType) GetType() Type { return m } 46 | func (m IntervalCompoundType) GetNullability() Nullability { return m.nullability } 47 | func (m IntervalCompoundType) GetTypeVariationReference() uint32 { 
return m.typeVariationRef } 48 | func (m IntervalCompoundType) Equals(rhs Type) bool { 49 | if o, ok := rhs.(IntervalCompoundType); ok { 50 | return o == m 51 | } 52 | if o, ok := rhs.(*IntervalCompoundType); ok { 53 | return *o == m 54 | } 55 | return false 56 | } 57 | 58 | func (m IntervalCompoundType) ToProtoFuncArg() *proto.FunctionArgument { 59 | return &proto.FunctionArgument{ 60 | ArgType: &proto.FunctionArgument_Type{Type: m.ToProto()}, 61 | } 62 | } 63 | 64 | func (m IntervalCompoundType) ToProto() *proto.Type { 65 | return &proto.Type{Kind: &proto.Type_IntervalCompound_{ 66 | IntervalCompound: &proto.Type_IntervalCompound{ 67 | Precision: m.precision.ToProtoVal(), 68 | Nullability: m.nullability, 69 | TypeVariationReference: m.typeVariationRef}}} 70 | } 71 | 72 | func (IntervalCompoundType) ShortString() string { return shortTypeNames[TypeNameIntervalCompound] } 73 | func (m IntervalCompoundType) String() string { 74 | return fmt.Sprintf("%s%s<%d>", TypeNameIntervalCompound, strNullable(m), m.precision.ToProtoVal()) 75 | } 76 | 77 | func (m IntervalCompoundType) GetParameters() []interface{} { 78 | return []interface{}{m.precision} 79 | } 80 | -------------------------------------------------------------------------------- /types/interval_compound_type_test.go: -------------------------------------------------------------------------------- 1 | package types 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/google/go-cmp/cmp" 8 | "github.com/stretchr/testify/assert" 9 | proto "github.com/substrait-io/substrait-protobuf/go/substraitpb" 10 | "google.golang.org/protobuf/testing/protocmp" 11 | ) 12 | 13 | func TestNewIntervalCompoundType(t *testing.T) { 14 | allPossibleTimePrecision := []TimePrecision{PrecisionSeconds, PrecisionDeciSeconds, PrecisionCentiSeconds, PrecisionMilliSeconds, 15 | PrecisionEMinus4Seconds, PrecisionEMinus5Seconds, PrecisionMicroSeconds, PrecisionEMinus7Seconds, PrecisionEMinus8Seconds, PrecisionNanoSeconds} 16 | 
allPossibleNullability := []Nullability{NullabilityUnspecified, NullabilityNullable, NullabilityRequired} 17 | 18 | for _, precision := range allPossibleTimePrecision { 19 | for _, nullability := range allPossibleNullability { 20 | expectedIntervalCompoundType := IntervalCompoundType{precision: precision, nullability: nullability} 21 | expectedFormatString := fmt.Sprintf("%s<%d>", strNullable(expectedIntervalCompoundType), precision.ToProtoVal()) 22 | 23 | parameters := expectedIntervalCompoundType.GetParameters() 24 | assert.Equal(t, parameters, []interface{}{precision}) 25 | // verify IntervalCompoundType 26 | createdIntervalCompoundTypeIfc := NewIntervalCompoundType().WithPrecision(precision).WithTypeVariationRef(0).WithNullability(nullability) 27 | createdIntervalCompoundType := createdIntervalCompoundTypeIfc.(IntervalCompoundType) 28 | assert.True(t, createdIntervalCompoundType.Equals(expectedIntervalCompoundType)) 29 | assert.Equal(t, expectedProtoValMap[precision], createdIntervalCompoundType.GetPrecisionProtoVal()) 30 | assert.Equal(t, nullability, createdIntervalCompoundType.GetNullability()) 31 | assert.Zero(t, createdIntervalCompoundType.GetTypeVariationReference()) 32 | assert.Equal(t, fmt.Sprintf("interval_compound%s", expectedFormatString), createdIntervalCompoundType.String()) 33 | assert.Equal(t, "icompound", createdIntervalCompoundType.ShortString()) 34 | assertIntervalCompoundTypeProto(t, precision, nullability, createdIntervalCompoundType) 35 | } 36 | } 37 | } 38 | 39 | func assertIntervalCompoundTypeProto(t *testing.T, expectedPrecision TimePrecision, expectedNullability Nullability, 40 | toVerifyType IntervalCompoundType) { 41 | 42 | expectedTypeProto := &proto.Type{Kind: &proto.Type_IntervalCompound_{ 43 | IntervalCompound: &proto.Type_IntervalCompound{ 44 | Precision: expectedPrecision.ToProtoVal(), 45 | Nullability: expectedNullability, 46 | }, 47 | }} 48 | if diff := cmp.Diff(toVerifyType.ToProto(), expectedTypeProto, protocmp.Transform()); 
diff != "" { 49 | t.Errorf("IntervalCompoundType proto didn't match, diff:\n%v", diff) 50 | } 51 | 52 | expectedFuncArgProto := &proto.FunctionArgument{ArgType: &proto.FunctionArgument_Type{ 53 | Type: expectedTypeProto, 54 | }} 55 | if diff := cmp.Diff(toVerifyType.ToProtoFuncArg(), expectedFuncArgProto, protocmp.Transform()); diff != "" { 56 | t.Errorf("IntervalCompoundType func arg proto didn't match, diff:\n%v", diff) 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /types/interval_day_type.go: -------------------------------------------------------------------------------- 1 | package types 2 | 3 | import ( 4 | "fmt" 5 | 6 | proto "github.com/substrait-io/substrait-protobuf/go/substraitpb" 7 | ) 8 | 9 | // IntervalDayType this is used to represent a type of interval day. 10 | type IntervalDayType struct { 11 | Precision TimePrecision 12 | TypeVariationRef uint32 13 | Nullability Nullability 14 | } 15 | 16 | func (m *IntervalDayType) GetPrecisionProtoVal() int32 { 17 | return m.Precision.ToProtoVal() 18 | } 19 | 20 | func (*IntervalDayType) isRootRef() {} 21 | func (m *IntervalDayType) WithNullability(n Nullability) Type { 22 | m.Nullability = n 23 | return m 24 | } 25 | 26 | func (m *IntervalDayType) GetType() Type { return m } 27 | func (m *IntervalDayType) GetNullability() Nullability { return m.Nullability } 28 | func (m *IntervalDayType) GetTypeVariationReference() uint32 { return m.TypeVariationRef } 29 | func (m *IntervalDayType) Equals(rhs Type) bool { 30 | if o, ok := rhs.(*IntervalDayType); ok { 31 | return *o == *m 32 | } 33 | return false 34 | } 35 | 36 | func (m *IntervalDayType) ToProtoFuncArg() *proto.FunctionArgument { 37 | return &proto.FunctionArgument{ 38 | ArgType: &proto.FunctionArgument_Type{Type: m.ToProto()}, 39 | } 40 | } 41 | 42 | func (m *IntervalDayType) ToProto() *proto.Type { 43 | precisionVal := m.Precision.ToProtoVal() 44 | return &proto.Type{Kind: &proto.Type_IntervalDay_{ 45 | 
IntervalDay: &proto.Type_IntervalDay{ 46 | Precision: &precisionVal, 47 | Nullability: m.Nullability, 48 | TypeVariationReference: m.TypeVariationRef}}} 49 | } 50 | 51 | func (*IntervalDayType) ShortString() string { return shortTypeNames[TypeNameIntervalDay] } 52 | 53 | func (m *IntervalDayType) String() string { 54 | return fmt.Sprintf("%s%s<%d>", TypeNameIntervalDay, strNullable(m), 55 | m.Precision.ToProtoVal()) 56 | } 57 | 58 | func (m *IntervalDayType) ParameterString() string { 59 | return fmt.Sprintf("%d", m.Precision.ToProtoVal()) 60 | } 61 | 62 | func (s *IntervalDayType) BaseString() string { 63 | return string(TypeNameIntervalDay) 64 | } 65 | 66 | func (m *IntervalDayType) GetPrecision() TimePrecision { 67 | return m.Precision 68 | } 69 | 70 | func (m *IntervalDayType) GetReturnType(length int32, nullability Nullability) Type { 71 | out := *m 72 | out.Precision = TimePrecision(length) 73 | out.Nullability = nullability 74 | return &out 75 | } 76 | 77 | func (m *IntervalDayType) GetParameters() []interface{} { 78 | return []interface{}{m.Precision} 79 | } 80 | -------------------------------------------------------------------------------- /types/interval_day_type_test.go: -------------------------------------------------------------------------------- 1 | package types 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/google/go-cmp/cmp" 8 | "github.com/stretchr/testify/assert" 9 | proto "github.com/substrait-io/substrait-protobuf/go/substraitpb" 10 | "google.golang.org/protobuf/testing/protocmp" 11 | ) 12 | 13 | func TestIntervalDayType(t *testing.T) { 14 | anotherType := &FixedCharType{Length: 10, Nullability: NullabilityNullable} 15 | allPossibleTimePrecision := []TimePrecision{PrecisionSeconds, PrecisionDeciSeconds, PrecisionCentiSeconds, PrecisionMilliSeconds, 16 | PrecisionEMinus4Seconds, PrecisionEMinus5Seconds, PrecisionMicroSeconds, PrecisionEMinus7Seconds, PrecisionEMinus8Seconds, PrecisionNanoSeconds} 17 | allPossibleNullability := 
[]Nullability{NullabilityUnspecified, NullabilityNullable, NullabilityRequired} 18 | 19 | for _, precision := range allPossibleTimePrecision { 20 | for _, nullability := range allPossibleNullability { 21 | expectedIntervalDayType := &IntervalDayType{Precision: precision, Nullability: nullability} 22 | expectedFormatString := fmt.Sprintf("%s<%d>", strNullable(expectedIntervalDayType), precision.ToProtoVal()) 23 | 24 | parameters := expectedIntervalDayType.GetParameters() 25 | assert.Equal(t, parameters, []interface{}{precision}) 26 | // verify IntervalDayType 27 | createdIntervalDayTypeIfc := (&IntervalDayType{Precision: precision}).WithNullability(nullability) 28 | createdIntervalDayType := createdIntervalDayTypeIfc.(*IntervalDayType) 29 | assert.True(t, createdIntervalDayType.Equals(expectedIntervalDayType)) 30 | assert.Equal(t, expectedProtoValMap[precision], createdIntervalDayType.GetPrecisionProtoVal()) 31 | assert.Equal(t, nullability, createdIntervalDayType.GetNullability()) 32 | assert.Zero(t, createdIntervalDayType.GetTypeVariationReference()) 33 | assert.Equal(t, fmt.Sprintf("interval_day%s", expectedFormatString), createdIntervalDayType.String()) 34 | assert.Equal(t, "iday", createdIntervalDayType.ShortString()) 35 | assert.Equal(t, "interval_day", createdIntervalDayType.BaseString()) 36 | assert.Equal(t, precision, createdIntervalDayType.GetPrecision()) 37 | expectedParameterString := fmt.Sprintf("%d", precision.ToProtoVal()) 38 | assert.Equal(t, expectedParameterString, createdIntervalDayType.ParameterString()) 39 | assertIntervalDayTypeProto(t, precision, nullability, createdIntervalDayType) 40 | assert.False(t, createdIntervalDayTypeIfc.Equals(anotherType)) 41 | } 42 | } 43 | } 44 | 45 | func assertIntervalDayTypeProto(t *testing.T, expectedPrecision TimePrecision, expectedNullability Nullability, 46 | toVerifyType *IntervalDayType) { 47 | 48 | expectedPrecisionProtoVal := expectedPrecision.ToProtoVal() 49 | expectedTypeProto := &proto.Type{Kind: 
&proto.Type_IntervalDay_{ 50 | IntervalDay: &proto.Type_IntervalDay{ 51 | Precision: &expectedPrecisionProtoVal, 52 | Nullability: expectedNullability, 53 | }, 54 | }} 55 | if diff := cmp.Diff(toVerifyType.ToProto(), expectedTypeProto, protocmp.Transform()); diff != "" { 56 | t.Errorf("IntervalDayType proto didn't match, diff:\n%v", diff) 57 | } 58 | 59 | expectedFuncArgProto := &proto.FunctionArgument{ArgType: &proto.FunctionArgument_Type{ 60 | Type: expectedTypeProto, 61 | }} 62 | if diff := cmp.Diff(toVerifyType.ToProtoFuncArg(), expectedFuncArgProto, protocmp.Transform()); diff != "" { 63 | t.Errorf("IntervalDayType func arg proto didn't match, diff:\n%v", diff) 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /types/interval_year_month_type.go: -------------------------------------------------------------------------------- 1 | package types 2 | 3 | import ( 4 | "fmt" 5 | 6 | proto "github.com/substrait-io/substrait-protobuf/go/substraitpb" 7 | ) 8 | 9 | // IntervalYearToMonthType this is used to represent a type of interval which represents YearToMonth. 10 | type IntervalYearToMonthType struct { 11 | typeVariationRef uint32 12 | nullability Nullability 13 | } 14 | 15 | // NewIntervalYearToMonthType creates a type of new interval YearToMonth. 
16 | // Created type has nullability as Nullable 17 | func NewIntervalYearToMonthType() IntervalYearToMonthType { 18 | return IntervalYearToMonthType{ 19 | nullability: NullabilityNullable, 20 | } 21 | } 22 | 23 | func (m IntervalYearToMonthType) WithTypeVariationRef(typeVariationRef uint32) IntervalYearToMonthType { 24 | m.typeVariationRef = typeVariationRef 25 | return m 26 | } 27 | 28 | func (IntervalYearToMonthType) isRootRef() {} 29 | func (m IntervalYearToMonthType) WithNullability(n Nullability) Type { 30 | m.nullability = n 31 | return m 32 | } 33 | 34 | func (m IntervalYearToMonthType) GetType() Type { return m } 35 | func (m IntervalYearToMonthType) GetNullability() Nullability { return m.nullability } 36 | func (m IntervalYearToMonthType) GetTypeVariationReference() uint32 { return m.typeVariationRef } 37 | func (m IntervalYearToMonthType) Equals(rhs Type) bool { 38 | if o, ok := rhs.(IntervalYearToMonthType); ok { 39 | return o == m 40 | } 41 | return false 42 | } 43 | 44 | func (m IntervalYearToMonthType) ToProtoFuncArg() *proto.FunctionArgument { 45 | return &proto.FunctionArgument{ 46 | ArgType: &proto.FunctionArgument_Type{Type: m.ToProto()}, 47 | } 48 | } 49 | 50 | func (m IntervalYearToMonthType) ToProto() *proto.Type { 51 | return &proto.Type{Kind: &proto.Type_IntervalYear_{ 52 | IntervalYear: &proto.Type_IntervalYear{ 53 | Nullability: m.nullability, 54 | TypeVariationReference: m.typeVariationRef}}} 55 | } 56 | 57 | func (IntervalYearToMonthType) ShortString() string { return shortTypeNames[TypeNameIntervalYear] } 58 | func (m IntervalYearToMonthType) String() string { 59 | return fmt.Sprintf("%s%s", TypeNameIntervalYear, strNullable(m)) 60 | } 61 | 62 | func (m IntervalYearToMonthType) GetParameters() []interface{} { 63 | return nil 64 | } 65 | -------------------------------------------------------------------------------- /types/interval_year_month_type_test.go: -------------------------------------------------------------------------------- 
1 | package types 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/google/go-cmp/cmp" 8 | "github.com/stretchr/testify/assert" 9 | proto "github.com/substrait-io/substrait-protobuf/go/substraitpb" 10 | "google.golang.org/protobuf/testing/protocmp" 11 | ) 12 | 13 | func TestNewIntervalYearToMonthType(t *testing.T) { 14 | allPossibleNullability := []Nullability{NullabilityUnspecified, NullabilityNullable, NullabilityRequired} 15 | 16 | for _, nullability := range allPossibleNullability { 17 | expectedIntervalType := IntervalYearToMonthType{nullability: nullability} 18 | 19 | parameters := expectedIntervalType.GetParameters() 20 | assert.Len(t, parameters, 0) 21 | // verify IntervalYearToMonthType 22 | createdIntervalTypeIfcType := NewIntervalYearToMonthType().WithTypeVariationRef(0).WithNullability(nullability) 23 | createdIntervalType := createdIntervalTypeIfcType.(IntervalYearToMonthType) 24 | assert.True(t, createdIntervalType.Equals(expectedIntervalType)) 25 | assert.Equal(t, nullability, createdIntervalType.GetNullability()) 26 | assert.Zero(t, createdIntervalTypeIfcType.GetTypeVariationReference()) 27 | assert.Equal(t, fmt.Sprintf("interval_year%s", strNullable(expectedIntervalType)), createdIntervalType.String()) 28 | assert.Equal(t, "iyear", createdIntervalType.ShortString()) 29 | assertIntervalYearToMonthTypeProto(t, nullability, createdIntervalType) 30 | } 31 | } 32 | 33 | func assertIntervalYearToMonthTypeProto(t *testing.T, expectedNullability Nullability, 34 | toVerifyType IntervalYearToMonthType) { 35 | 36 | expectedTypeProto := &proto.Type{Kind: &proto.Type_IntervalYear_{ 37 | IntervalYear: &proto.Type_IntervalYear{ 38 | Nullability: expectedNullability, 39 | }, 40 | }} 41 | if diff := cmp.Diff(toVerifyType.ToProto(), expectedTypeProto, protocmp.Transform()); diff != "" { 42 | t.Errorf("IntervalYearToMonthType proto didn't match, diff:\n%v", diff) 43 | } 44 | 45 | expectedFuncArgProto := &proto.FunctionArgument{ArgType: 
&proto.FunctionArgument_Type{ 46 | Type: expectedTypeProto, 47 | }} 48 | if diff := cmp.Diff(toVerifyType.ToProtoFuncArg(), expectedFuncArgProto, protocmp.Transform()); diff != "" { 49 | t.Errorf("IntervalYearToMonthType func arg proto didn't match, diff:\n%v", diff) 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /types/parameterized_decimal_type.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package types 4 | 5 | import ( 6 | "fmt" 7 | 8 | "github.com/substrait-io/substrait-go/v4/types/integer_parameters" 9 | ) 10 | 11 | // ParameterizedDecimalType is a decimal type which to hold function arguments 12 | // example: Decimal or Decimal or Decimal(10, 2) 13 | type ParameterizedDecimalType struct { 14 | Nullability Nullability 15 | TypeVariationRef uint32 16 | Precision integer_parameters.IntegerParameter 17 | Scale integer_parameters.IntegerParameter 18 | } 19 | 20 | func (m *ParameterizedDecimalType) SetNullability(n Nullability) FuncDefArgType { 21 | m.Nullability = n 22 | return m 23 | } 24 | 25 | func (m *ParameterizedDecimalType) String() string { 26 | t := DecimalType{} 27 | parameterString := fmt.Sprintf("<%s,%s>", m.Precision.String(), m.Scale.String()) 28 | return fmt.Sprintf("%s%s%s", t.BaseString(), strFromNullability(m.Nullability), parameterString) 29 | } 30 | 31 | func (m *ParameterizedDecimalType) HasParameterizedParam() bool { 32 | _, ok1 := m.Precision.(*integer_parameters.VariableIntParam) 33 | _, ok2 := m.Scale.(*integer_parameters.VariableIntParam) 34 | return ok1 || ok2 35 | } 36 | 37 | func (m *ParameterizedDecimalType) GetParameterizedParams() []interface{} { 38 | if !m.HasParameterizedParam() { 39 | return nil 40 | } 41 | var params []interface{} 42 | params = append(params, m.Precision) 43 | params = append(params, m.Scale) 44 | return params 45 | } 46 | 47 | func (m *ParameterizedDecimalType) 
MatchWithNullability(ot Type) bool { 48 | if m.Nullability != ot.GetNullability() { 49 | return false 50 | } 51 | return m.MatchWithoutNullability(ot) 52 | } 53 | 54 | func (m *ParameterizedDecimalType) MatchWithoutNullability(ot Type) bool { 55 | if odt, ok := ot.(*DecimalType); ok { 56 | concretePrecision := integer_parameters.NewConcreteIntParam(odt.Precision) 57 | concreteScale := integer_parameters.NewConcreteIntParam(odt.Scale) 58 | return m.Precision.IsCompatible(concretePrecision) && m.Scale.IsCompatible(concreteScale) 59 | } 60 | return false 61 | } 62 | 63 | func (m *ParameterizedDecimalType) GetNullability() Nullability { 64 | return m.Nullability 65 | } 66 | 67 | func (m *ParameterizedDecimalType) ShortString() string { 68 | return "dec" 69 | } 70 | 71 | func (m *ParameterizedDecimalType) ReturnType(parameters []FuncDefArgType, argumentTypes []Type) (Type, error) { 72 | precision, perr := m.Precision.(*integer_parameters.ConcreteIntParam) 73 | scale, serr := m.Scale.(*integer_parameters.ConcreteIntParam) 74 | if !perr || !serr { 75 | derivation := OutputDerivation{FinalType: m} 76 | return derivation.ReturnType(parameters, argumentTypes) 77 | } 78 | return &DecimalType{Nullability: m.Nullability, Precision: int32(*precision), Scale: int32(*scale)}, nil 79 | } 80 | 81 | func (m *ParameterizedDecimalType) WithParameters(params []interface{}) (Type, error) { 82 | if len(params) != 2 { 83 | p, pOk := m.Precision.(*integer_parameters.ConcreteIntParam) 84 | s, sOk := m.Scale.(*integer_parameters.ConcreteIntParam) 85 | if pOk && sOk { 86 | return &DecimalType{Nullability: m.Nullability, Precision: int32(*p), Scale: int32(*s)}, nil 87 | } 88 | return nil, fmt.Errorf("decimal type must have 2 parameters") 89 | } 90 | if precision, ok := params[0].(int64); ok { 91 | if scale, ok := params[1].(int64); ok { 92 | return &DecimalType{Nullability: m.Nullability, Precision: int32(precision), Scale: int32(scale)}, nil 93 | } 94 | return nil, fmt.Errorf("scale must be an 
integer") 95 | } 96 | return nil, fmt.Errorf("precision must be an integer, but got %T", params[0]) 97 | } 98 | -------------------------------------------------------------------------------- /types/parameterized_decimal_type_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package types_test 4 | 5 | import ( 6 | "testing" 7 | 8 | "github.com/stretchr/testify/require" 9 | "github.com/substrait-io/substrait-go/v4/types" 10 | "github.com/substrait-io/substrait-go/v4/types/integer_parameters" 11 | ) 12 | 13 | func TestParameterizedDecimalType(t *testing.T) { 14 | precision_P := integer_parameters.NewVariableIntParam("P") 15 | scale_S := integer_parameters.NewVariableIntParam("S") 16 | precision_38 := integer_parameters.NewConcreteIntParam(38) 17 | scale_5 := integer_parameters.NewConcreteIntParam(5) 18 | for _, td := range []struct { 19 | name string 20 | precision integer_parameters.IntegerParameter 21 | scale integer_parameters.IntegerParameter 22 | args []interface{} 23 | expectedNullableString string 24 | expectedNullableRequiredString string 25 | expectedHasParameterizedParam bool 26 | expectedParameterizedParams []interface{} 27 | expectedReturnType types.Type 28 | }{ 29 | {"both parameterized", precision_P, scale_S, []any{int64(30), int64(13)}, "decimal?", "decimal", true, []interface{}{precision_P, scale_S}, &types.DecimalType{Precision: 30, Scale: 13, Nullability: types.NullabilityRequired}}, 30 | {"precision concrete", precision_38, scale_S, []any{int64(38), int64(6)}, "decimal?<38,S>", "decimal<38,S>", true, []interface{}{precision_38, scale_S}, &types.DecimalType{Precision: 38, Scale: 6, Nullability: types.NullabilityRequired}}, 31 | {"scale concrete", precision_P, scale_5, []any{int64(30), int64(5)}, "decimal?", "decimal", true, []interface{}{precision_P, scale_5}, &types.DecimalType{Precision: 30, Scale: 5, Nullability: types.NullabilityRequired}}, 32 | {"both 
concrete", precision_38, scale_5, []any{}, "decimal?<38,5>", "decimal<38,5>", false, nil, &types.DecimalType{Precision: 38, Scale: 5, Nullability: types.NullabilityRequired}}, 33 | } { 34 | t.Run(td.name, func(t *testing.T) { 35 | pd := &types.ParameterizedDecimalType{Precision: td.precision, Scale: td.scale} 36 | require.Equal(t, td.expectedNullableString, pd.SetNullability(types.NullabilityNullable).String()) 37 | require.Equal(t, types.NullabilityNullable, pd.GetNullability()) 38 | require.Equal(t, td.expectedNullableRequiredString, pd.SetNullability(types.NullabilityRequired).String()) 39 | require.Equal(t, types.NullabilityRequired, pd.GetNullability()) 40 | require.Equal(t, td.expectedHasParameterizedParam, pd.HasParameterizedParam()) 41 | require.Equal(t, td.expectedParameterizedParams, pd.GetParameterizedParams()) 42 | require.Equal(t, "dec", pd.ShortString()) 43 | retType, err := pd.ReturnType(nil, nil) 44 | if td.expectedHasParameterizedParam { 45 | require.Error(t, err) 46 | require.True(t, pd.HasParameterizedParam()) 47 | retType, err = pd.ReturnType([]types.FuncDefArgType{pd}, []types.Type{td.expectedReturnType}) 48 | require.NoError(t, err) 49 | require.Equal(t, td.expectedReturnType, retType) 50 | } else { 51 | require.Nil(t, err) 52 | require.Equal(t, td.expectedReturnType, retType) 53 | } 54 | resultType, err := pd.WithParameters(td.args) 55 | require.Nil(t, err) 56 | require.Equal(t, td.expectedReturnType, resultType) 57 | }) 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /types/parameterized_list_type.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package types 4 | 5 | import ( 6 | "fmt" 7 | ) 8 | 9 | // ParameterizedListType is a list type having parameter of ParameterizedAbstractType 10 | // basically a list of which type is another abstract parameter 11 | // example: List. 
Kindly note concrete types List is not represented by this type 12 | // Concrete type is represented by ListType 13 | type ParameterizedListType struct { 14 | Nullability Nullability 15 | TypeVariationRef uint32 16 | Type FuncDefArgType 17 | } 18 | 19 | func (m *ParameterizedListType) SetNullability(n Nullability) FuncDefArgType { 20 | m.Nullability = n 21 | return m 22 | } 23 | 24 | func (m *ParameterizedListType) String() string { 25 | t := ListType{} 26 | parameterString := fmt.Sprintf("<%s>", m.Type) 27 | return fmt.Sprintf("%s%s%s", t.BaseString(), strFromNullability(m.Nullability), parameterString) 28 | } 29 | 30 | func (m *ParameterizedListType) HasParameterizedParam() bool { 31 | return m.Type.HasParameterizedParam() 32 | } 33 | 34 | func (m *ParameterizedListType) GetParameterizedParams() []interface{} { 35 | if !m.HasParameterizedParam() { 36 | return nil 37 | } 38 | return []interface{}{m.Type} 39 | } 40 | 41 | func (m *ParameterizedListType) MatchWithNullability(ot Type) bool { 42 | if m.Nullability != ot.GetNullability() { 43 | return false 44 | } 45 | if olt, ok := ot.(*ListType); ok { 46 | result := m.Type.MatchWithNullability(olt.Type) 47 | return result 48 | } 49 | return false 50 | } 51 | 52 | func (m *ParameterizedListType) MatchWithoutNullability(ot Type) bool { 53 | if olt, ok := ot.(*ListType); ok { 54 | return m.Type.MatchWithoutNullability(olt.Type) 55 | } 56 | return false 57 | } 58 | 59 | func (m *ParameterizedListType) GetNullability() Nullability { 60 | return m.Nullability 61 | } 62 | 63 | func (m *ParameterizedListType) ShortString() string { 64 | return "list" 65 | } 66 | 67 | func (m *ParameterizedListType) ReturnType( 68 | funcParams []FuncDefArgType, argTypes []Type, 69 | ) (Type, error) { 70 | elemType, err := m.Type.ReturnType(funcParams, argTypes) 71 | if err != nil { 72 | return nil, err 73 | } 74 | return &ListType{Nullability: m.Nullability, Type: elemType}, nil 75 | } 76 | 77 | func (m *ParameterizedListType) 
WithParameters(params []interface{}) (Type, error) { 78 | if len(params) != 1 { 79 | return nil, fmt.Errorf("expected 1 parameter, got %d", len(params)) 80 | } 81 | if t, ok := params[0].(Type); ok { 82 | return &ListType{Nullability: m.Nullability, Type: t}, nil 83 | } 84 | return nil, fmt.Errorf("expected parameter to be of type Type, got %T", params[0]) 85 | } 86 | -------------------------------------------------------------------------------- /types/parameterized_list_type_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package types_test 4 | 5 | import ( 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | "github.com/stretchr/testify/require" 10 | "github.com/substrait-io/substrait-go/v4/types" 11 | "github.com/substrait-io/substrait-go/v4/types/integer_parameters" 12 | ) 13 | 14 | func TestParameterizedListType(t *testing.T) { 15 | decimalType := &types.ParameterizedDecimalType{ 16 | Precision: integer_parameters.NewVariableIntParam("P"), 17 | Scale: integer_parameters.NewVariableIntParam("S"), 18 | Nullability: types.NullabilityRequired, 19 | } 20 | int8Type := &types.Int8Type{} 21 | dec30PS5 := &types.DecimalType{Precision: 30, Scale: 5, Nullability: types.NullabilityRequired} 22 | for _, td := range []struct { 23 | name string 24 | param types.FuncDefArgType 25 | args []interface{} 26 | expectedNullableString string 27 | expectedNullableRequiredString string 28 | expectedHasParameterizedParam bool 29 | expectedParameterizedParams []interface{} 30 | expectedReturnType types.Type 31 | }{ 32 | {"parameterized param", decimalType, []any{dec30PS5}, "list?>", "list>", true, []interface{}{decimalType}, &types.ListType{Nullability: types.NullabilityRequired, Type: dec30PS5}}, 33 | {"concrete param", int8Type, []any{int8Type}, "list?", "list", false, nil, &types.ListType{Nullability: types.NullabilityRequired, Type: int8Type}}, 34 | {"list", &types.AnyType{Name: 
"any"}, []any{int8Type}, "list?", "list", false, nil, &types.ListType{Nullability: types.NullabilityRequired, Type: int8Type}}, 35 | } { 36 | t.Run(td.name, func(t *testing.T) { 37 | pd := &types.ParameterizedListType{Type: td.param} 38 | assert.Equal(t, types.NullabilityUnspecified, pd.GetNullability()) 39 | require.Equal(t, td.expectedNullableString, pd.SetNullability(types.NullabilityNullable).String()) 40 | require.Equal(t, td.expectedNullableRequiredString, pd.SetNullability(types.NullabilityRequired).String()) 41 | assert.Equal(t, types.NullabilityRequired, pd.GetNullability()) 42 | require.Equal(t, td.expectedHasParameterizedParam, pd.HasParameterizedParam()) 43 | require.Equal(t, td.expectedParameterizedParams, pd.GetParameterizedParams()) 44 | assert.Equal(t, "list", pd.ShortString()) 45 | retType, err := pd.ReturnType([]types.FuncDefArgType{td.param}, []types.Type{td.args[0].(types.Type)}) 46 | if td.expectedReturnType == nil { 47 | assert.Error(t, err) 48 | require.True(t, pd.HasParameterizedParam()) 49 | } else { 50 | require.Nil(t, err) 51 | require.Equal(t, td.expectedReturnType, retType) 52 | resultType, err := pd.WithParameters(td.args) 53 | require.Nil(t, err) 54 | require.Equal(t, td.expectedReturnType, resultType) 55 | } 56 | }) 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /types/parameterized_map_type.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package types 4 | 5 | import ( 6 | "fmt" 7 | ) 8 | 9 | // ParameterizedMapType is a struct having at least one of key or value of type ParameterizedAbstractType 10 | // If All arguments are concrete they are represented by MapType 11 | type ParameterizedMapType struct { 12 | Nullability Nullability 13 | TypeVariationRef uint32 14 | Key FuncDefArgType 15 | Value FuncDefArgType 16 | } 17 | 18 | func (m *ParameterizedMapType) SetNullability(n Nullability) 
FuncDefArgType { 19 | m.Nullability = n 20 | return m 21 | } 22 | 23 | func (m *ParameterizedMapType) String() string { 24 | t := MapType{} 25 | parameterString := fmt.Sprintf("<%s, %s>", m.Key.String(), m.Value.String()) 26 | return fmt.Sprintf("%s%s%s", t.BaseString(), strFromNullability(m.Nullability), parameterString) 27 | } 28 | 29 | func (m *ParameterizedMapType) HasParameterizedParam() bool { 30 | return m.Key.HasParameterizedParam() || m.Value.HasParameterizedParam() 31 | } 32 | 33 | func (m *ParameterizedMapType) GetParameterizedParams() []interface{} { 34 | if !m.HasParameterizedParam() { 35 | return nil 36 | } 37 | var abstractParams []interface{} 38 | if m.Key.HasParameterizedParam() { 39 | abstractParams = append(abstractParams, m.Key) 40 | } 41 | if m.Value.HasParameterizedParam() { 42 | abstractParams = append(abstractParams, m.Value) 43 | } 44 | return abstractParams 45 | } 46 | 47 | func (m *ParameterizedMapType) MatchWithNullability(ot Type) bool { 48 | if m.Nullability != ot.GetNullability() { 49 | return false 50 | } 51 | if omt, ok := ot.(*MapType); ok { 52 | return m.Key.MatchWithNullability(omt.Key) && m.Value.MatchWithNullability(omt.Value) 53 | } 54 | return false 55 | } 56 | 57 | func (m *ParameterizedMapType) MatchWithoutNullability(ot Type) bool { 58 | if omt, ok := ot.(*MapType); ok { 59 | return m.Key.MatchWithoutNullability(omt.Key) && m.Value.MatchWithoutNullability(omt.Value) 60 | } 61 | return false 62 | } 63 | 64 | func (m *ParameterizedMapType) GetNullability() Nullability { 65 | return m.Nullability 66 | } 67 | 68 | func (m *ParameterizedMapType) ShortString() string { 69 | return "map" 70 | } 71 | 72 | func (m *ParameterizedMapType) ReturnType([]FuncDefArgType, []Type) (Type, error) { 73 | keyType, kerr := m.Key.ReturnType(nil, nil) 74 | if kerr != nil { 75 | return nil, fmt.Errorf("error in getting key type: %w", kerr) 76 | } 77 | valueType, verr := m.Value.ReturnType(nil, nil) 78 | if verr != nil { 79 | return nil, 
fmt.Errorf("error in getting value type: %w", verr) 80 | } 81 | 82 | return &MapType{Nullability: m.Nullability, Key: keyType, Value: valueType}, nil 83 | } 84 | 85 | func (m *ParameterizedMapType) WithParameters(params []interface{}) (Type, error) { 86 | if len(params) != 2 { 87 | if m.Key.HasParameterizedParam() || m.Value.HasParameterizedParam() { 88 | return nil, fmt.Errorf("map type must have 2 parameters") 89 | } 90 | return m.ReturnType(nil, nil) 91 | } 92 | if key, ok := params[0].(Type); ok { 93 | if value, ok := params[1].(Type); ok { 94 | return &MapType{Nullability: m.Nullability, Key: key, Value: value}, nil 95 | } 96 | return nil, fmt.Errorf("value must be a Type") 97 | } 98 | return nil, fmt.Errorf("key must be a Type") 99 | } 100 | -------------------------------------------------------------------------------- /types/parameterized_map_type_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package types_test 4 | 5 | import ( 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | "github.com/stretchr/testify/require" 10 | "github.com/substrait-io/substrait-go/v4/types" 11 | "github.com/substrait-io/substrait-go/v4/types/integer_parameters" 12 | ) 13 | 14 | func TestParameterizedMapType(t *testing.T) { 15 | decimalType := &types.ParameterizedDecimalType{ 16 | Precision: integer_parameters.NewVariableIntParam("P"), 17 | Scale: integer_parameters.NewVariableIntParam("S"), 18 | Nullability: types.NullabilityRequired, 19 | } 20 | int8Type := &types.Int8Type{Nullability: types.NullabilityNullable} 21 | listParametrizedType := &types.ParameterizedListType{Type: decimalType, Nullability: types.NullabilityNullable} 22 | dec30PS5 := &types.DecimalType{Precision: 30, Scale: 5, Nullability: types.NullabilityRequired} 23 | dec30PS9 := &types.DecimalType{Precision: 30, Scale: 9, Nullability: types.NullabilityRequired} 24 | listType := &types.ListType{Type: dec30PS9, 
Nullability: types.NullabilityRequired} 25 | for _, td := range []struct { 26 | name string 27 | Key types.FuncDefArgType 28 | Value types.FuncDefArgType 29 | args []interface{} 30 | expectedNullableString string 31 | expectedNullableRequiredString string 32 | expectedHasParameterizedParam bool 33 | expectedParameterizedParams []interface{} 34 | expectedReturnType types.Type 35 | }{ 36 | {"parameterized kv", decimalType, listParametrizedType, []any{dec30PS5, listType}, "map?, list?>>", "map, list?>>", true, []interface{}{decimalType, listParametrizedType}, &types.MapType{Nullability: types.NullabilityRequired, Key: dec30PS5, Value: listType}}, 37 | {"concrete key", int8Type, listParametrizedType, []any{int8Type, listType}, "map?>>", "map>>", true, []interface{}{listParametrizedType}, &types.MapType{Nullability: types.NullabilityRequired, Key: int8Type, Value: listType}}, 38 | {"concrete value", decimalType, int8Type, []any{dec30PS9, int8Type}, "map?, i8?>", "map, i8?>", true, []interface{}{decimalType}, &types.MapType{Nullability: types.NullabilityRequired, Key: dec30PS9, Value: int8Type}}, 39 | {"no parameterized param", int8Type, int8Type, []any{}, "map?", "map", false, nil, &types.MapType{Nullability: types.NullabilityRequired, Key: int8Type, Value: int8Type}}, 40 | } { 41 | t.Run(td.name, func(t *testing.T) { 42 | pd := &types.ParameterizedMapType{Key: td.Key, Value: td.Value} 43 | assert.Equal(t, types.NullabilityUnspecified, pd.GetNullability()) 44 | require.Equal(t, td.expectedNullableString, pd.SetNullability(types.NullabilityNullable).String()) 45 | assert.Equal(t, types.NullabilityNullable, pd.GetNullability()) 46 | require.Equal(t, td.expectedNullableRequiredString, pd.SetNullability(types.NullabilityRequired).String()) 47 | assert.Equal(t, types.NullabilityRequired, pd.GetNullability()) 48 | require.Equal(t, td.expectedHasParameterizedParam, pd.HasParameterizedParam()) 49 | require.Equal(t, td.expectedParameterizedParams, pd.GetParameterizedParams()) 50 
| assert.Equal(t, "map", pd.ShortString()) 51 | retType, err := pd.ReturnType(nil, nil) 52 | if td.expectedHasParameterizedParam { 53 | assert.Error(t, err) 54 | require.True(t, pd.HasParameterizedParam()) 55 | } else { 56 | require.Nil(t, err) 57 | require.Equal(t, td.expectedReturnType, retType) 58 | } 59 | resultType, err := pd.WithParameters(td.args) 60 | require.Nil(t, err) 61 | require.Equal(t, td.expectedReturnType, resultType) 62 | }) 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /types/parameterized_single_integer_param_type.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package types 4 | 5 | import ( 6 | "fmt" 7 | "reflect" 8 | 9 | "github.com/substrait-io/substrait-go/v4/types/integer_parameters" 10 | ) 11 | 12 | type singleIntegerParamType interface { 13 | BaseString() string 14 | ShortString() string 15 | GetReturnType(length int32, nullability Nullability) Type 16 | } 17 | 18 | // parameterizedTypeSingleIntegerParam This is a generic type to represent parameterized type with a single integer parameter 19 | type parameterizedTypeSingleIntegerParam[T singleIntegerParamType] struct { 20 | Nullability Nullability 21 | TypeVariationRef uint32 22 | IntegerOption integer_parameters.IntegerParameter 23 | } 24 | 25 | func (m *parameterizedTypeSingleIntegerParam[T]) SetNullability(n Nullability) FuncDefArgType { 26 | m.Nullability = n 27 | return m 28 | } 29 | 30 | func (m *parameterizedTypeSingleIntegerParam[T]) String() string { 31 | return fmt.Sprintf("%s%s%s", m.baseString(), strFromNullability(m.Nullability), m.parameterString()) 32 | } 33 | 34 | func (m *parameterizedTypeSingleIntegerParam[T]) parameterString() string { 35 | return fmt.Sprintf("<%s>", m.IntegerOption.String()) 36 | } 37 | 38 | func (m *parameterizedTypeSingleIntegerParam[T]) baseString() string { 39 | var t T 40 | tType := reflect.TypeOf(t) 41 | if 
tType.Kind() == reflect.Ptr { 42 | tType = tType.Elem() 43 | } 44 | newInstance := reflect.New(tType).Interface().(T) 45 | return newInstance.BaseString() 46 | } 47 | 48 | func (m *parameterizedTypeSingleIntegerParam[T]) HasParameterizedParam() bool { 49 | _, ok1 := m.IntegerOption.(*integer_parameters.VariableIntParam) 50 | return ok1 51 | } 52 | 53 | func (m *parameterizedTypeSingleIntegerParam[T]) GetParameterizedParams() []interface{} { 54 | if !m.HasParameterizedParam() { 55 | return nil 56 | } 57 | return []interface{}{m.IntegerOption} 58 | } 59 | 60 | func (m *parameterizedTypeSingleIntegerParam[T]) MatchWithNullability(ot Type) bool { 61 | if m.Nullability != ot.GetNullability() { 62 | return false 63 | } 64 | return m.MatchWithoutNullability(ot) 65 | } 66 | 67 | func (m *parameterizedTypeSingleIntegerParam[T]) MatchWithoutNullability(ot Type) bool { 68 | if reflect.TypeFor[T]() != reflect.TypeOf(ot) { 69 | return false 70 | } 71 | if odt, ok := ot.(FixedType); ok { 72 | concreteLength := integer_parameters.NewConcreteIntParam(odt.GetLength()) 73 | return m.IntegerOption.IsCompatible(concreteLength) 74 | } 75 | if odt, ok := ot.(timestampPrecisionType); ok { 76 | concreteLength := integer_parameters.NewConcreteIntParam(odt.GetPrecision().ToProtoVal()) 77 | return m.IntegerOption.IsCompatible(concreteLength) 78 | } 79 | return false 80 | } 81 | 82 | func (m *parameterizedTypeSingleIntegerParam[T]) GetNullability() Nullability { 83 | return m.Nullability 84 | } 85 | 86 | func (m *parameterizedTypeSingleIntegerParam[T]) ShortString() string { 87 | newInstance := m.getNewInstance() 88 | return newInstance.ShortString() 89 | } 90 | 91 | func (m *parameterizedTypeSingleIntegerParam[T]) getNewInstance() T { 92 | var t T 93 | tType := reflect.TypeOf(t) 94 | if tType.Kind() == reflect.Ptr { 95 | tType = tType.Elem() 96 | } 97 | return reflect.New(tType).Interface().(T) 98 | } 99 | 100 | func (m *parameterizedTypeSingleIntegerParam[T]) ReturnType(params 
[]FuncDefArgType, argumentTypes []Type) (Type, error) { 101 | concreteIntParam, ok := m.IntegerOption.(*integer_parameters.ConcreteIntParam) 102 | if !ok { 103 | derivation := OutputDerivation{FinalType: m} 104 | return derivation.ReturnType(params, argumentTypes) 105 | } 106 | t := m.getNewInstance() 107 | return t.GetReturnType(int32(*concreteIntParam), m.Nullability), nil 108 | } 109 | 110 | func (m *parameterizedTypeSingleIntegerParam[T]) WithParameters(params []interface{}) (Type, error) { 111 | if len(params) != 1 { 112 | if concreteIntParam, ok := m.IntegerOption.(*integer_parameters.ConcreteIntParam); ok { 113 | return m.getNewInstance().GetReturnType(int32(*concreteIntParam), m.Nullability), nil 114 | } 115 | return nil, fmt.Errorf("type must have 1 parameter") 116 | } 117 | switch params[0].(type) { 118 | case int64: 119 | return m.getNewInstance().GetReturnType(int32(params[0].(int64)), m.Nullability), nil 120 | case int32: 121 | return m.getNewInstance().GetReturnType(params[0].(int32), m.Nullability), nil 122 | case TimePrecision: 123 | return m.getNewInstance().GetReturnType(int32(params[0].(TimePrecision)), m.Nullability), nil 124 | default: 125 | return nil, fmt.Errorf("unknown parameter type for integer parameter") 126 | } 127 | } 128 | -------------------------------------------------------------------------------- /types/parameterized_single_integer_param_type_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package types_test 4 | 5 | import ( 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | "github.com/stretchr/testify/require" 10 | "github.com/substrait-io/substrait-go/v4/types" 11 | "github.com/substrait-io/substrait-go/v4/types/integer_parameters" 12 | ) 13 | 14 | func TestParameterizedSingleIntegerType(t *testing.T) { 15 | abstractLeafParam_L1 := integer_parameters.NewVariableIntParam("L1") 16 | concreteLeafParam_38 := 
integer_parameters.NewConcreteIntParam(38) 17 | concreteLeafParam_5 := integer_parameters.NewConcreteIntParam(5) 18 | for _, td := range []struct { 19 | name string 20 | typ types.FuncDefArgType 21 | typeParams []interface{} 22 | expectedNullableString string 23 | expectedNullableRequiredString string 24 | expectedShortString string 25 | expectedIsParameterized bool 26 | expectedAbstractParams []interface{} 27 | expectedReturnType types.Type 28 | }{ 29 | {"nullable parameterized varchar", &types.ParameterizedVarCharType{IntegerOption: abstractLeafParam_L1}, []any{int64(11)}, "varchar?", "varchar", "vchar", true, []interface{}{abstractLeafParam_L1}, &types.VarCharType{Length: 11, Nullability: types.NullabilityRequired}}, 30 | {"nullable concrete varchar", &types.ParameterizedVarCharType{IntegerOption: concreteLeafParam_38}, []any{}, "varchar?<38>", "varchar<38>", "vchar", false, nil, &types.VarCharType{Length: 38, Nullability: types.NullabilityRequired}}, 31 | {"nullable fixChar", &types.ParameterizedFixedCharType{IntegerOption: abstractLeafParam_L1}, []any{int64(13)}, "fixedchar?", "fixedchar", "fchar", true, []interface{}{abstractLeafParam_L1}, &types.FixedCharType{Length: 13, Nullability: types.NullabilityRequired}}, 32 | {"nullable concrete fixChar", &types.ParameterizedFixedCharType{IntegerOption: concreteLeafParam_38}, []any{}, "fixedchar?<38>", "fixedchar<38>", "fchar", false, nil, &types.FixedCharType{Length: 38, Nullability: types.NullabilityRequired}}, 33 | {"nullable fixBinary", &types.ParameterizedFixedBinaryType{IntegerOption: abstractLeafParam_L1}, []any{int64(17)}, "fixedbinary?", "fixedbinary", "fbin", true, []interface{}{abstractLeafParam_L1}, &types.FixedBinaryType{Length: 17, Nullability: types.NullabilityRequired}}, 34 | {"nullable concrete fixBinary", &types.ParameterizedFixedBinaryType{IntegerOption: concreteLeafParam_38}, []any{}, "fixedbinary?<38>", "fixedbinary<38>", "fbin", false, nil, &types.FixedBinaryType{Length: 38, Nullability: 
types.NullabilityRequired}}, 35 | {"nullable precisionTimeStamp", &types.ParameterizedPrecisionTimestampType{IntegerOption: abstractLeafParam_L1}, []any{int64(7)}, "precision_timestamp?", "precision_timestamp", "pts", true, []interface{}{abstractLeafParam_L1}, &types.PrecisionTimestampType{Precision: 7, Nullability: types.NullabilityRequired}}, 36 | {"nullable concrete precisionTimeStamp", &types.ParameterizedPrecisionTimestampType{IntegerOption: concreteLeafParam_38}, []any{}, "precision_timestamp?<38>", "precision_timestamp<38>", "pts", false, nil, &types.PrecisionTimestampType{Precision: 38, Nullability: types.NullabilityRequired}}, 37 | {"nullable precisionTimeStampTz", &types.ParameterizedPrecisionTimestampTzType{IntegerOption: abstractLeafParam_L1}, []any{int64(5)}, "precision_timestamp_tz?", "precision_timestamp_tz", "ptstz", true, []interface{}{abstractLeafParam_L1}, &types.PrecisionTimestampTzType{PrecisionTimestampType: types.PrecisionTimestampType{Precision: 5, Nullability: types.NullabilityRequired}}}, 38 | {"nullable concrete precisionTimeStampTz", &types.ParameterizedPrecisionTimestampTzType{IntegerOption: concreteLeafParam_38}, []any{}, "precision_timestamp_tz?<38>", "precision_timestamp_tz<38>", "ptstz", false, nil, &types.PrecisionTimestampTzType{PrecisionTimestampType: types.PrecisionTimestampType{Precision: 38, Nullability: types.NullabilityRequired}}}, 39 | {"nullable interval day", &types.ParameterizedIntervalDayType{IntegerOption: abstractLeafParam_L1}, []any{int64(3)}, "interval_day?", "interval_day", "iday", true, []interface{}{abstractLeafParam_L1}, &types.IntervalDayType{Precision: 3, Nullability: types.NullabilityRequired}}, 40 | {"nullable concrete interval day", &types.ParameterizedIntervalDayType{IntegerOption: concreteLeafParam_5}, []any{}, "interval_day?<5>", "interval_day<5>", "iday", false, nil, &types.IntervalDayType{Precision: 5, Nullability: types.NullabilityRequired}}, 41 | } { 42 | t.Run(td.name, func(t *testing.T) { 43 | 
require.Equal(t, td.expectedNullableString, td.typ.SetNullability(types.NullabilityNullable).String()) 44 | require.Equal(t, td.expectedNullableRequiredString, td.typ.SetNullability(types.NullabilityRequired).String()) 45 | require.Equal(t, td.expectedIsParameterized, td.typ.HasParameterizedParam()) 46 | require.Equal(t, td.expectedAbstractParams, td.typ.GetParameterizedParams()) 47 | assert.Equal(t, td.expectedShortString, td.typ.ShortString()) 48 | retType, err := td.typ.ReturnType(nil, nil) 49 | if td.expectedIsParameterized { 50 | require.Error(t, err) 51 | require.True(t, td.typ.HasParameterizedParam()) 52 | retType, err = td.typ.ReturnType([]types.FuncDefArgType{td.typ}, []types.Type{td.expectedReturnType}) 53 | require.NoError(t, err) 54 | require.Equal(t, td.expectedReturnType, retType) 55 | } else { 56 | require.Nil(t, err) 57 | require.Equal(t, td.expectedReturnType, retType) 58 | } 59 | resultType, err := td.typ.WithParameters(td.typeParams) 60 | require.Nil(t, err) 61 | require.Equal(t, td.expectedReturnType, resultType) 62 | }) 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /types/parameterized_struct_type.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package types 4 | 5 | import ( 6 | "fmt" 7 | "strings" 8 | ) 9 | 10 | // ParameterizedStructType is a parameter type struct 11 | // example: Struct or Struct. 
12 | type ParameterizedStructType struct { 13 | Nullability Nullability 14 | TypeVariationRef uint32 15 | Types []FuncDefArgType 16 | } 17 | 18 | func (m *ParameterizedStructType) SetNullability(n Nullability) FuncDefArgType { 19 | m.Nullability = n 20 | return m 21 | } 22 | 23 | func (m *ParameterizedStructType) String() string { 24 | sb := strings.Builder{} 25 | for i, typ := range m.Types { 26 | if i != 0 { 27 | sb.WriteString(", ") 28 | } 29 | sb.WriteString(typ.String()) 30 | } 31 | t := StructType{} 32 | parameterString := fmt.Sprintf("<%s>", sb.String()) 33 | return fmt.Sprintf("%s%s%s", t.BaseString(), strFromNullability(m.Nullability), parameterString) 34 | } 35 | 36 | func (m *ParameterizedStructType) HasParameterizedParam() bool { 37 | for _, typ := range m.Types { 38 | if typ.HasParameterizedParam() { 39 | return true 40 | } 41 | } 42 | return false 43 | } 44 | 45 | func (m *ParameterizedStructType) GetParameterizedParams() []interface{} { 46 | if !m.HasParameterizedParam() { 47 | return nil 48 | } 49 | var abstractParams []interface{} 50 | for _, typ := range m.Types { 51 | if typ.HasParameterizedParam() { 52 | abstractParams = append(abstractParams, typ) 53 | } 54 | } 55 | return abstractParams 56 | } 57 | 58 | func (m *ParameterizedStructType) MatchWithNullability(ot Type) bool { 59 | if m.Nullability != ot.GetNullability() { 60 | return false 61 | } 62 | if omt, ok := ot.(*StructType); ok { 63 | if len(m.Types) != len(omt.Types) { 64 | return false 65 | } 66 | for i, typ := range m.Types { 67 | if !typ.MatchWithNullability(omt.Types[i]) { 68 | return false 69 | } 70 | } 71 | return true 72 | } 73 | return false 74 | } 75 | 76 | func (m *ParameterizedStructType) MatchWithoutNullability(ot Type) bool { 77 | if omt, ok := ot.(*StructType); ok { 78 | if len(m.Types) != len(omt.Types) { 79 | return false 80 | } 81 | for i, typ := range m.Types { 82 | if !typ.MatchWithoutNullability(omt.Types[i]) { 83 | return false 84 | } 85 | } 86 | return true 87 | } 
88 | return false 89 | } 90 | 91 | func (m *ParameterizedStructType) GetNullability() Nullability { 92 | return m.Nullability 93 | } 94 | 95 | func (m *ParameterizedStructType) ShortString() string { 96 | return "struct" 97 | } 98 | 99 | func (m *ParameterizedStructType) ReturnType([]FuncDefArgType, []Type) (Type, error) { 100 | var types []Type 101 | for _, typ := range m.Types { 102 | retType, err := typ.ReturnType(nil, nil) 103 | if err != nil { 104 | return nil, fmt.Errorf("error in struct field type: %w", err) 105 | } 106 | types = append(types, retType) 107 | 108 | } 109 | return &StructType{Nullability: m.Nullability, Types: types}, nil 110 | } 111 | 112 | func (m *ParameterizedStructType) WithParameters(params []interface{}) (Type, error) { 113 | if len(params) != len(m.Types) { 114 | return nil, fmt.Errorf("expected %d parameters, got %d", len(m.Types), len(params)) 115 | } 116 | var types []Type 117 | for i, typ := range m.Types { 118 | t, ok := params[i].(Type) 119 | if !ok { 120 | return nil, fmt.Errorf("expected parameter to be of type Type, got %T", params[i]) 121 | } 122 | itype, err := typ.WithParameters(t.GetParameters()) 123 | if err != nil { 124 | return nil, err 125 | } 126 | types = append(types, itype) 127 | } 128 | return &StructType{Nullability: m.Nullability, Types: types}, nil 129 | } 130 | -------------------------------------------------------------------------------- /types/parameterized_struct_type_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package types_test 4 | 5 | import ( 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | "github.com/stretchr/testify/require" 10 | "github.com/substrait-io/substrait-go/v4/types" 11 | "github.com/substrait-io/substrait-go/v4/types/integer_parameters" 12 | ) 13 | 14 | func TestParameterizedStructType(t *testing.T) { 15 | decimalType := &types.ParameterizedDecimalType{ 16 | Precision: 
integer_parameters.NewVariableIntParam("P"), 17 | Scale: integer_parameters.NewVariableIntParam("S"), 18 | Nullability: types.NullabilityRequired, 19 | } 20 | int8Type := &types.Int8Type{Nullability: types.NullabilityNullable} 21 | listParameterizedType := &types.ParameterizedListType{Type: decimalType, Nullability: types.NullabilityNullable} 22 | dec30PS5 := &types.DecimalType{Precision: 30, Scale: 5, Nullability: types.NullabilityRequired} 23 | dec30PS9 := &types.DecimalType{Precision: 30, Scale: 9, Nullability: types.NullabilityRequired} 24 | listType := &types.ListType{Type: dec30PS9, Nullability: types.NullabilityRequired} 25 | for _, td := range []struct { 26 | name string 27 | params []types.FuncDefArgType 28 | args []interface{} 29 | expectedNullableString string 30 | expectedNullableRequiredString string 31 | expectedHasParameterizedParam bool 32 | expectedParameterizedParams []interface{} 33 | expectedReturnType types.Type 34 | }{ 35 | {"all parameterized param", []types.FuncDefArgType{decimalType, listParameterizedType}, []any{dec30PS5, listType}, "struct?, list?>>", "struct, list?>>", true, []interface{}{decimalType, listParameterizedType}, nil}, 36 | {"mix parameterized concrete param", []types.FuncDefArgType{decimalType, int8Type, listParameterizedType}, []any{dec30PS9, int8Type, listType}, "struct?, i8?, list?>>", "struct, i8?, list?>>", true, []interface{}{decimalType, listParameterizedType}, nil}, 37 | {"all concrete param", []types.FuncDefArgType{int8Type, int8Type, int8Type}, []any{int8Type, int8Type, int8Type}, "struct?", "struct", false, nil, &types.StructType{Nullability: types.NullabilityRequired, Types: []types.Type{int8Type, int8Type, int8Type}}}, 38 | } { 39 | t.Run(td.name, func(t *testing.T) { 40 | pd := &types.ParameterizedStructType{Types: td.params} 41 | require.Equal(t, td.expectedNullableString, pd.SetNullability(types.NullabilityNullable).String()) 42 | require.Equal(t, td.expectedNullableRequiredString, 
pd.SetNullability(types.NullabilityRequired).String()) 43 | require.Equal(t, td.expectedHasParameterizedParam, pd.HasParameterizedParam()) 44 | require.Equal(t, td.expectedParameterizedParams, pd.GetParameterizedParams()) 45 | assert.Equal(t, "struct", pd.ShortString()) 46 | retType, err := pd.ReturnType(nil, nil) 47 | if td.expectedReturnType == nil { 48 | assert.Error(t, err) 49 | require.True(t, pd.HasParameterizedParam()) 50 | } else { 51 | require.Nil(t, err) 52 | require.Equal(t, td.expectedReturnType, retType) 53 | got, err := pd.WithParameters(td.args) 54 | require.Nil(t, err) 55 | require.Equal(t, td.expectedReturnType, got) 56 | } 57 | }) 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /types/parameterized_user_defined_type.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package types 4 | 5 | import ( 6 | "fmt" 7 | "strings" 8 | ) 9 | 10 | type UDTParameter interface { 11 | isTypeParameter() bool 12 | String() string 13 | toTypeParam() (TypeParam, error) 14 | MatchWithoutNullability(param TypeParam) bool 15 | MatchWithNullability(param TypeParam) bool 16 | } 17 | 18 | type DataTypeUDTParam struct { 19 | Type FuncDefArgType 20 | } 21 | 22 | func (d *DataTypeUDTParam) isTypeParameter() bool { return true } 23 | 24 | func (d *DataTypeUDTParam) String() string { 25 | return d.Type.String() 26 | } 27 | 28 | func (d *DataTypeUDTParam) toTypeParam() (TypeParam, error) { 29 | typ, err := d.Type.ReturnType(nil, nil) 30 | if err != nil { 31 | return nil, err 32 | } 33 | return &DataTypeParameter{Type: typ}, nil 34 | } 35 | 36 | func (d *DataTypeUDTParam) MatchWithNullability(param TypeParam) bool { 37 | if d.MatchWithoutNullability(param) { 38 | if dataParam, ok := param.(*DataTypeParameter); ok { 39 | return d.Type.GetNullability() == dataParam.Type.GetNullability() 40 | } 41 | } 42 | return false 43 | } 44 | 45 | func (d 
*DataTypeUDTParam) MatchWithoutNullability(param TypeParam) bool { 46 | if dataParam, ok := param.(*DataTypeParameter); ok { 47 | return d.Type.MatchWithoutNullability(dataParam.Type) 48 | } 49 | return false 50 | } 51 | 52 | type IntegerUDTParam struct { 53 | Integer int32 54 | } 55 | 56 | func (i *IntegerUDTParam) isTypeParameter() bool { return true } 57 | 58 | func (i *IntegerUDTParam) String() string { 59 | return fmt.Sprintf("%d", i.Integer) 60 | } 61 | 62 | func (i *IntegerUDTParam) toTypeParam() (TypeParam, error) { 63 | return IntegerParameter(i.Integer), nil 64 | } 65 | 66 | func (i *IntegerUDTParam) MatchWithoutNullability(param TypeParam) bool { 67 | if intParam, ok := param.(IntegerParameter); ok { 68 | return i.Integer == int32(intParam) 69 | } 70 | return false 71 | } 72 | 73 | func (i *IntegerUDTParam) MatchWithNullability(param TypeParam) bool { 74 | return i.MatchWithoutNullability(param) 75 | } 76 | 77 | type StringUDTParam struct { 78 | StringVal string 79 | } 80 | 81 | func (s *StringUDTParam) isTypeParameter() bool { return true } 82 | 83 | func (s *StringUDTParam) String() string { 84 | return s.StringVal 85 | } 86 | 87 | func (s *StringUDTParam) toTypeParam() (TypeParam, error) { 88 | return StringParameter(s.StringVal), nil 89 | } 90 | 91 | func (s *StringUDTParam) MatchWithoutNullability(param TypeParam) bool { 92 | if strParam, ok := param.(StringParameter); ok { 93 | return s.StringVal == string(strParam) 94 | } 95 | return false 96 | } 97 | 98 | func (s *StringUDTParam) MatchWithNullability(param TypeParam) bool { 99 | return s.MatchWithoutNullability(param) 100 | } 101 | 102 | // ParameterizedUserDefinedType is a parameter type struct 103 | // example: U!point or U!square. 
104 | type ParameterizedUserDefinedType struct { 105 | Nullability Nullability 106 | TypeVariationRef uint32 107 | TypeParameters []UDTParameter 108 | Name string 109 | } 110 | 111 | func (m *ParameterizedUserDefinedType) SetNullability(n Nullability) FuncDefArgType { 112 | m.Nullability = n 113 | return m 114 | } 115 | 116 | func (m *ParameterizedUserDefinedType) String() string { 117 | var parameterString string 118 | if len(m.TypeParameters) > 0 { 119 | sb := strings.Builder{} 120 | for i, typ := range m.TypeParameters { 121 | if i != 0 { 122 | sb.WriteString(", ") 123 | } 124 | sb.WriteString(typ.String()) 125 | } 126 | parameterString = fmt.Sprintf("<%s>", sb.String()) 127 | } 128 | return fmt.Sprintf("u!%s%s%s", m.Name, strFromNullability(m.Nullability), parameterString) 129 | } 130 | 131 | func (m *ParameterizedUserDefinedType) HasParameterizedParam() bool { 132 | for _, typ := range m.TypeParameters { 133 | if param, ok := typ.(*DataTypeUDTParam); ok { 134 | return param.Type.HasParameterizedParam() 135 | } 136 | } 137 | return false 138 | } 139 | 140 | func (m *ParameterizedUserDefinedType) GetParameterizedParams() []interface{} { 141 | if !m.HasParameterizedParam() { 142 | return nil 143 | } 144 | var abstractParams []interface{} 145 | for _, typ := range m.TypeParameters { 146 | if param, ok := typ.(*DataTypeUDTParam); ok && param.Type.HasParameterizedParam() { 147 | abstractParams = append(abstractParams, typ) 148 | } 149 | } 150 | return abstractParams 151 | } 152 | 153 | func (m *ParameterizedUserDefinedType) MatchWithNullability(ot Type) bool { 154 | if m.Nullability != ot.GetNullability() { 155 | return false 156 | } 157 | if udt, ok := ot.(*UserDefinedType); ok { 158 | if len(m.TypeParameters) != len(udt.TypeParameters) { 159 | return false 160 | } 161 | for i, param := range m.TypeParameters { 162 | if !param.MatchWithNullability(udt.TypeParameters[i]) { 163 | return false 164 | } 165 | } 166 | return true 167 | } 168 | return false 169 | } 170 | 
171 | func (m *ParameterizedUserDefinedType) MatchWithoutNullability(ot Type) bool { 172 | if udt, ok := ot.(*UserDefinedType); ok { 173 | if len(m.TypeParameters) != len(udt.TypeParameters) { 174 | return false 175 | } 176 | for i, param := range m.TypeParameters { 177 | if !param.MatchWithoutNullability(udt.TypeParameters[i]) { 178 | return false 179 | } 180 | } 181 | return true 182 | } 183 | return false 184 | } 185 | 186 | func (m *ParameterizedUserDefinedType) GetNullability() Nullability { 187 | return m.Nullability 188 | } 189 | 190 | func (m *ParameterizedUserDefinedType) ShortString() string { 191 | return fmt.Sprintf("u!%s", m.Name) 192 | } 193 | 194 | func (m *ParameterizedUserDefinedType) ReturnType([]FuncDefArgType, []Type) (Type, error) { 195 | var types []TypeParam 196 | for _, udtParam := range m.TypeParameters { 197 | param, err := udtParam.toTypeParam() 198 | if err != nil { 199 | return nil, err 200 | } 201 | types = append(types, param) 202 | } 203 | return &UserDefinedType{Nullability: m.Nullability, TypeParameters: types}, nil 204 | } 205 | 206 | func (m *ParameterizedUserDefinedType) WithParameters(params []interface{}) (Type, error) { 207 | if len(params) != len(m.TypeParameters) { 208 | return nil, fmt.Errorf("expected %d parameters, got %d", len(m.TypeParameters), len(params)) 209 | } 210 | var typeParams []TypeParam 211 | for i, param := range params { 212 | if p, ok := param.(TypeParam); ok { 213 | typeParams = append(typeParams, p) 214 | continue 215 | } 216 | return nil, fmt.Errorf("unexpected type %T for parameter %d", param, i) 217 | } 218 | return &UserDefinedType{ 219 | Nullability: m.Nullability, 220 | TypeVariationRef: m.TypeVariationRef, 221 | TypeParameters: typeParams, 222 | }, nil 223 | } 224 | -------------------------------------------------------------------------------- /types/parameterized_user_defined_type_test.go: -------------------------------------------------------------------------------- 1 | package types_test 2 
| 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/require" 9 | . "github.com/substrait-io/substrait-go/v4/types" 10 | "github.com/substrait-io/substrait-go/v4/types/integer_parameters" 11 | ) 12 | 13 | func TestParameterizedUserDefinedType(t *testing.T) { 14 | decimalType := &ParameterizedDecimalType{ 15 | Precision: integer_parameters.NewVariableIntParam("P"), 16 | Scale: integer_parameters.NewVariableIntParam("S"), 17 | Nullability: NullabilityRequired, 18 | } 19 | int8Type := &Int8Type{Nullability: NullabilityNullable} 20 | //userDefineType := &ParameterizedUserDefinedType{TypeParameters: []UDTParameter{}, Nullability: NullabilityNullable} 21 | for _, td := range []struct { 22 | name string 23 | Params []UDTParameter 24 | Args []interface{} 25 | expectedNullableString string 26 | expectedNullableRequiredString string 27 | expectedHasParameterizedParam bool 28 | expectedParameterizedParams []any 29 | expectedReturnType Type 30 | }{ 31 | {"udt_noparam", []UDTParameter{}, []any{}, "u!udt_noparam?", "u!udt_noparam", false, nil, &UserDefinedType{Nullability: NullabilityRequired}}, 32 | {"concrete_udt", []UDTParameter{&DataTypeUDTParam{int8Type}}, []any{&DataTypeParameter{int8Type}}, "u!concrete_udt?", "u!concrete_udt", false, nil, &UserDefinedType{Nullability: NullabilityRequired, TypeParameters: []TypeParam{&DataTypeParameter{Type: int8Type}}}}, 33 | {"variable_udt", []UDTParameter{&DataTypeUDTParam{decimalType}}, []any{}, "u!variable_udt?>", "u!variable_udt>", true, []any{&DataTypeUDTParam{decimalType}}, nil}, 34 | {"udt_with_int", []UDTParameter{&IntegerUDTParam{Integer: 10}}, []any{IntegerParameter(10)}, "u!udt_with_int?<10>", "u!udt_with_int<10>", false, nil, &UserDefinedType{Nullability: NullabilityRequired, TypeParameters: []TypeParam{IntegerParameter(10)}}}, 35 | {"udt_with_str", []UDTParameter{&StringUDTParam{StringVal: "test"}}, []any{StringParameter("test")}, "u!udt_with_str?", 
"u!udt_with_str", false, nil, &UserDefinedType{Nullability: NullabilityRequired, TypeParameters: []TypeParam{StringParameter("test")}}}, 36 | {"udt_with_int_and_str", []UDTParameter{&IntegerUDTParam{Integer: 10}, &StringUDTParam{StringVal: "test"}}, []any{IntegerParameter(10), StringParameter("test")}, "u!udt_with_int_and_str?<10, test>", "u!udt_with_int_and_str<10, test>", false, nil, &UserDefinedType{Nullability: NullabilityRequired, TypeParameters: []TypeParam{IntegerParameter(10), StringParameter("test")}}}, 37 | } { 38 | t.Run(td.name, func(t *testing.T) { 39 | pd := &ParameterizedUserDefinedType{TypeParameters: td.Params, Name: td.name} 40 | assert.Equal(t, NullabilityUnspecified, pd.GetNullability()) 41 | require.Equal(t, td.expectedNullableString, pd.SetNullability(NullabilityNullable).String()) 42 | assert.Equal(t, NullabilityNullable, pd.GetNullability()) 43 | require.Equal(t, td.expectedNullableRequiredString, pd.SetNullability(NullabilityRequired).String()) 44 | assert.Equal(t, NullabilityRequired, pd.GetNullability()) 45 | require.Equal(t, td.expectedHasParameterizedParam, pd.HasParameterizedParam()) 46 | require.Equal(t, td.expectedParameterizedParams, pd.GetParameterizedParams()) 47 | assert.Equal(t, fmt.Sprintf("u!%s", td.name), pd.ShortString()) 48 | 49 | retType, err := pd.ReturnType(nil, nil) 50 | if td.expectedReturnType == nil { 51 | assert.Error(t, err) 52 | require.True(t, pd.HasParameterizedParam()) 53 | } else { 54 | require.Nil(t, err) 55 | require.Equal(t, td.expectedReturnType, retType) 56 | resultType, err := pd.WithParameters(td.Args) 57 | require.Nil(t, err) 58 | require.Equal(t, td.expectedReturnType, resultType) 59 | } 60 | }) 61 | } 62 | 63 | } 64 | -------------------------------------------------------------------------------- /types/parser/baseparser/README.md: -------------------------------------------------------------------------------- 1 | ## Type Parsing 2 | 3 | This folder contains the parser code for the Substrait types. 
4 | The parser is generated using the ANTLR4 tool. The parser reads a Substrait type string and returns a `Type` object. 5 | The parser is generated from the SubstraitLexer.g4 and SubstraitType.g4 grammar files at https://github.com/substrait-io/substrait/blob/main/grammar. 6 | 7 | ### Steps to regenerate the parser code 8 | 9 | Whenever the grammar files are updated in the Substrait repo, the parser code must be regenerated. To do this, follow the steps below: 10 | 11 | #### Step 1: 12 | Update the `generate.go` file in the `grammar` folder at the root of the repository to pull the new grammar files. 13 | 14 | ``` 15 | //go:generate wget https://raw.githubusercontent.com/substrait-io/substrait/<commit_hash>/grammar/SubstraitLexer.g4 16 | //go:generate wget https://raw.githubusercontent.com/substrait-io/substrait/<commit_hash>/grammar/SubstraitType.g4 17 | ``` 18 | Replace `<commit_hash>` with the commit hash of the Substrait repo that contains the updated grammar files. 19 | 20 | #### Step 2: 21 | To generate the parser code, run the following command in the root of the repository: 22 | 23 | ``` 24 | go generate ./grammar/... 25 | ``` 26 | -------------------------------------------------------------------------------- /types/parser/baseparser/substraittype_base_visitor.go: -------------------------------------------------------------------------------- 1 | // Code generated from SubstraitType.g4 by ANTLR 4.13.2. DO NOT EDIT. 
2 | 3 | package baseparser // SubstraitType 4 | import "github.com/antlr4-go/antlr/v4" 5 | 6 | type BaseSubstraitTypeVisitor struct { 7 | *antlr.BaseParseTreeVisitor 8 | } 9 | 10 | func (v *BaseSubstraitTypeVisitor) VisitStartRule(ctx *StartRuleContext) interface{} { 11 | return v.VisitChildren(ctx) 12 | } 13 | 14 | func (v *BaseSubstraitTypeVisitor) VisitTypeStatement(ctx *TypeStatementContext) interface{} { 15 | return v.VisitChildren(ctx) 16 | } 17 | 18 | func (v *BaseSubstraitTypeVisitor) VisitBoolean(ctx *BooleanContext) interface{} { 19 | return v.VisitChildren(ctx) 20 | } 21 | 22 | func (v *BaseSubstraitTypeVisitor) VisitI8(ctx *I8Context) interface{} { 23 | return v.VisitChildren(ctx) 24 | } 25 | 26 | func (v *BaseSubstraitTypeVisitor) VisitI16(ctx *I16Context) interface{} { 27 | return v.VisitChildren(ctx) 28 | } 29 | 30 | func (v *BaseSubstraitTypeVisitor) VisitI32(ctx *I32Context) interface{} { 31 | return v.VisitChildren(ctx) 32 | } 33 | 34 | func (v *BaseSubstraitTypeVisitor) VisitI64(ctx *I64Context) interface{} { 35 | return v.VisitChildren(ctx) 36 | } 37 | 38 | func (v *BaseSubstraitTypeVisitor) VisitFp32(ctx *Fp32Context) interface{} { 39 | return v.VisitChildren(ctx) 40 | } 41 | 42 | func (v *BaseSubstraitTypeVisitor) VisitFp64(ctx *Fp64Context) interface{} { 43 | return v.VisitChildren(ctx) 44 | } 45 | 46 | func (v *BaseSubstraitTypeVisitor) VisitString(ctx *StringContext) interface{} { 47 | return v.VisitChildren(ctx) 48 | } 49 | 50 | func (v *BaseSubstraitTypeVisitor) VisitBinary(ctx *BinaryContext) interface{} { 51 | return v.VisitChildren(ctx) 52 | } 53 | 54 | func (v *BaseSubstraitTypeVisitor) VisitTimestamp(ctx *TimestampContext) interface{} { 55 | return v.VisitChildren(ctx) 56 | } 57 | 58 | func (v *BaseSubstraitTypeVisitor) VisitTimestampTz(ctx *TimestampTzContext) interface{} { 59 | return v.VisitChildren(ctx) 60 | } 61 | 62 | func (v *BaseSubstraitTypeVisitor) VisitDate(ctx *DateContext) interface{} { 63 | return v.VisitChildren(ctx) 
64 | } 65 | 66 | func (v *BaseSubstraitTypeVisitor) VisitTime(ctx *TimeContext) interface{} { 67 | return v.VisitChildren(ctx) 68 | } 69 | 70 | func (v *BaseSubstraitTypeVisitor) VisitIntervalYear(ctx *IntervalYearContext) interface{} { 71 | return v.VisitChildren(ctx) 72 | } 73 | 74 | func (v *BaseSubstraitTypeVisitor) VisitUuid(ctx *UuidContext) interface{} { 75 | return v.VisitChildren(ctx) 76 | } 77 | 78 | func (v *BaseSubstraitTypeVisitor) VisitFixedChar(ctx *FixedCharContext) interface{} { 79 | return v.VisitChildren(ctx) 80 | } 81 | 82 | func (v *BaseSubstraitTypeVisitor) VisitVarChar(ctx *VarCharContext) interface{} { 83 | return v.VisitChildren(ctx) 84 | } 85 | 86 | func (v *BaseSubstraitTypeVisitor) VisitFixedBinary(ctx *FixedBinaryContext) interface{} { 87 | return v.VisitChildren(ctx) 88 | } 89 | 90 | func (v *BaseSubstraitTypeVisitor) VisitDecimal(ctx *DecimalContext) interface{} { 91 | return v.VisitChildren(ctx) 92 | } 93 | 94 | func (v *BaseSubstraitTypeVisitor) VisitPrecisionIntervalDay(ctx *PrecisionIntervalDayContext) interface{} { 95 | return v.VisitChildren(ctx) 96 | } 97 | 98 | func (v *BaseSubstraitTypeVisitor) VisitPrecisionTimestamp(ctx *PrecisionTimestampContext) interface{} { 99 | return v.VisitChildren(ctx) 100 | } 101 | 102 | func (v *BaseSubstraitTypeVisitor) VisitPrecisionTimestampTZ(ctx *PrecisionTimestampTZContext) interface{} { 103 | return v.VisitChildren(ctx) 104 | } 105 | 106 | func (v *BaseSubstraitTypeVisitor) VisitStruct(ctx *StructContext) interface{} { 107 | return v.VisitChildren(ctx) 108 | } 109 | 110 | func (v *BaseSubstraitTypeVisitor) VisitNStruct(ctx *NStructContext) interface{} { 111 | return v.VisitChildren(ctx) 112 | } 113 | 114 | func (v *BaseSubstraitTypeVisitor) VisitList(ctx *ListContext) interface{} { 115 | return v.VisitChildren(ctx) 116 | } 117 | 118 | func (v *BaseSubstraitTypeVisitor) VisitMap(ctx *MapContext) interface{} { 119 | return v.VisitChildren(ctx) 120 | } 121 | 122 | func (v 
*BaseSubstraitTypeVisitor) VisitUserDefined(ctx *UserDefinedContext) interface{} { 123 | return v.VisitChildren(ctx) 124 | } 125 | 126 | func (v *BaseSubstraitTypeVisitor) VisitNumericLiteral(ctx *NumericLiteralContext) interface{} { 127 | return v.VisitChildren(ctx) 128 | } 129 | 130 | func (v *BaseSubstraitTypeVisitor) VisitNumericParameterName(ctx *NumericParameterNameContext) interface{} { 131 | return v.VisitChildren(ctx) 132 | } 133 | 134 | func (v *BaseSubstraitTypeVisitor) VisitNumericExpression(ctx *NumericExpressionContext) interface{} { 135 | return v.VisitChildren(ctx) 136 | } 137 | 138 | func (v *BaseSubstraitTypeVisitor) VisitAnyType(ctx *AnyTypeContext) interface{} { 139 | return v.VisitChildren(ctx) 140 | } 141 | 142 | func (v *BaseSubstraitTypeVisitor) VisitTypeDef(ctx *TypeDefContext) interface{} { 143 | return v.VisitChildren(ctx) 144 | } 145 | 146 | func (v *BaseSubstraitTypeVisitor) VisitIfExpr(ctx *IfExprContext) interface{} { 147 | return v.VisitChildren(ctx) 148 | } 149 | 150 | func (v *BaseSubstraitTypeVisitor) VisitTypeLiteral(ctx *TypeLiteralContext) interface{} { 151 | return v.VisitChildren(ctx) 152 | } 153 | 154 | func (v *BaseSubstraitTypeVisitor) VisitMultilineDefinition(ctx *MultilineDefinitionContext) interface{} { 155 | return v.VisitChildren(ctx) 156 | } 157 | 158 | func (v *BaseSubstraitTypeVisitor) VisitTernary(ctx *TernaryContext) interface{} { 159 | return v.VisitChildren(ctx) 160 | } 161 | 162 | func (v *BaseSubstraitTypeVisitor) VisitBinaryExpr(ctx *BinaryExprContext) interface{} { 163 | return v.VisitChildren(ctx) 164 | } 165 | 166 | func (v *BaseSubstraitTypeVisitor) VisitParenExpression(ctx *ParenExpressionContext) interface{} { 167 | return v.VisitChildren(ctx) 168 | } 169 | 170 | func (v *BaseSubstraitTypeVisitor) VisitParameterName(ctx *ParameterNameContext) interface{} { 171 | return v.VisitChildren(ctx) 172 | } 173 | 174 | func (v *BaseSubstraitTypeVisitor) VisitFunctionCall(ctx *FunctionCallContext) interface{} { 
175 | return v.VisitChildren(ctx) 176 | } 177 | 178 | func (v *BaseSubstraitTypeVisitor) VisitNotExpr(ctx *NotExprContext) interface{} { 179 | return v.VisitChildren(ctx) 180 | } 181 | 182 | func (v *BaseSubstraitTypeVisitor) VisitLiteralNumber(ctx *LiteralNumberContext) interface{} { 183 | return v.VisitChildren(ctx) 184 | } 185 | -------------------------------------------------------------------------------- /types/parser/baseparser/substraittype_visitor.go: -------------------------------------------------------------------------------- 1 | // Code generated from SubstraitType.g4 by ANTLR 4.13.2. DO NOT EDIT. 2 | 3 | package baseparser // SubstraitType 4 | import "github.com/antlr4-go/antlr/v4" 5 | 6 | // A complete Visitor for a parse tree produced by SubstraitTypeParser. 7 | type SubstraitTypeVisitor interface { 8 | antlr.ParseTreeVisitor 9 | 10 | // Visit a parse tree produced by SubstraitTypeParser#startRule. 11 | VisitStartRule(ctx *StartRuleContext) interface{} 12 | 13 | // Visit a parse tree produced by SubstraitTypeParser#typeStatement. 14 | VisitTypeStatement(ctx *TypeStatementContext) interface{} 15 | 16 | // Visit a parse tree produced by SubstraitTypeParser#boolean. 17 | VisitBoolean(ctx *BooleanContext) interface{} 18 | 19 | // Visit a parse tree produced by SubstraitTypeParser#i8. 20 | VisitI8(ctx *I8Context) interface{} 21 | 22 | // Visit a parse tree produced by SubstraitTypeParser#i16. 23 | VisitI16(ctx *I16Context) interface{} 24 | 25 | // Visit a parse tree produced by SubstraitTypeParser#i32. 26 | VisitI32(ctx *I32Context) interface{} 27 | 28 | // Visit a parse tree produced by SubstraitTypeParser#i64. 29 | VisitI64(ctx *I64Context) interface{} 30 | 31 | // Visit a parse tree produced by SubstraitTypeParser#fp32. 32 | VisitFp32(ctx *Fp32Context) interface{} 33 | 34 | // Visit a parse tree produced by SubstraitTypeParser#fp64. 35 | VisitFp64(ctx *Fp64Context) interface{} 36 | 37 | // Visit a parse tree produced by SubstraitTypeParser#string. 
38 | VisitString(ctx *StringContext) interface{} 39 | 40 | // Visit a parse tree produced by SubstraitTypeParser#binary. 41 | VisitBinary(ctx *BinaryContext) interface{} 42 | 43 | // Visit a parse tree produced by SubstraitTypeParser#timestamp. 44 | VisitTimestamp(ctx *TimestampContext) interface{} 45 | 46 | // Visit a parse tree produced by SubstraitTypeParser#timestampTz. 47 | VisitTimestampTz(ctx *TimestampTzContext) interface{} 48 | 49 | // Visit a parse tree produced by SubstraitTypeParser#date. 50 | VisitDate(ctx *DateContext) interface{} 51 | 52 | // Visit a parse tree produced by SubstraitTypeParser#time. 53 | VisitTime(ctx *TimeContext) interface{} 54 | 55 | // Visit a parse tree produced by SubstraitTypeParser#intervalYear. 56 | VisitIntervalYear(ctx *IntervalYearContext) interface{} 57 | 58 | // Visit a parse tree produced by SubstraitTypeParser#uuid. 59 | VisitUuid(ctx *UuidContext) interface{} 60 | 61 | // Visit a parse tree produced by SubstraitTypeParser#fixedChar. 62 | VisitFixedChar(ctx *FixedCharContext) interface{} 63 | 64 | // Visit a parse tree produced by SubstraitTypeParser#varChar. 65 | VisitVarChar(ctx *VarCharContext) interface{} 66 | 67 | // Visit a parse tree produced by SubstraitTypeParser#fixedBinary. 68 | VisitFixedBinary(ctx *FixedBinaryContext) interface{} 69 | 70 | // Visit a parse tree produced by SubstraitTypeParser#decimal. 71 | VisitDecimal(ctx *DecimalContext) interface{} 72 | 73 | // Visit a parse tree produced by SubstraitTypeParser#precisionIntervalDay. 74 | VisitPrecisionIntervalDay(ctx *PrecisionIntervalDayContext) interface{} 75 | 76 | // Visit a parse tree produced by SubstraitTypeParser#precisionTimestamp. 77 | VisitPrecisionTimestamp(ctx *PrecisionTimestampContext) interface{} 78 | 79 | // Visit a parse tree produced by SubstraitTypeParser#precisionTimestampTZ. 80 | VisitPrecisionTimestampTZ(ctx *PrecisionTimestampTZContext) interface{} 81 | 82 | // Visit a parse tree produced by SubstraitTypeParser#struct. 
83 | VisitStruct(ctx *StructContext) interface{} 84 | 85 | // Visit a parse tree produced by SubstraitTypeParser#nStruct. 86 | VisitNStruct(ctx *NStructContext) interface{} 87 | 88 | // Visit a parse tree produced by SubstraitTypeParser#list. 89 | VisitList(ctx *ListContext) interface{} 90 | 91 | // Visit a parse tree produced by SubstraitTypeParser#map. 92 | VisitMap(ctx *MapContext) interface{} 93 | 94 | // Visit a parse tree produced by SubstraitTypeParser#userDefined. 95 | VisitUserDefined(ctx *UserDefinedContext) interface{} 96 | 97 | // Visit a parse tree produced by SubstraitTypeParser#numericLiteral. 98 | VisitNumericLiteral(ctx *NumericLiteralContext) interface{} 99 | 100 | // Visit a parse tree produced by SubstraitTypeParser#numericParameterName. 101 | VisitNumericParameterName(ctx *NumericParameterNameContext) interface{} 102 | 103 | // Visit a parse tree produced by SubstraitTypeParser#numericExpression. 104 | VisitNumericExpression(ctx *NumericExpressionContext) interface{} 105 | 106 | // Visit a parse tree produced by SubstraitTypeParser#anyType. 107 | VisitAnyType(ctx *AnyTypeContext) interface{} 108 | 109 | // Visit a parse tree produced by SubstraitTypeParser#typeDef. 110 | VisitTypeDef(ctx *TypeDefContext) interface{} 111 | 112 | // Visit a parse tree produced by SubstraitTypeParser#IfExpr. 113 | VisitIfExpr(ctx *IfExprContext) interface{} 114 | 115 | // Visit a parse tree produced by SubstraitTypeParser#TypeLiteral. 116 | VisitTypeLiteral(ctx *TypeLiteralContext) interface{} 117 | 118 | // Visit a parse tree produced by SubstraitTypeParser#MultilineDefinition. 119 | VisitMultilineDefinition(ctx *MultilineDefinitionContext) interface{} 120 | 121 | // Visit a parse tree produced by SubstraitTypeParser#Ternary. 122 | VisitTernary(ctx *TernaryContext) interface{} 123 | 124 | // Visit a parse tree produced by SubstraitTypeParser#BinaryExpr. 
125 | VisitBinaryExpr(ctx *BinaryExprContext) interface{} 126 | 127 | // Visit a parse tree produced by SubstraitTypeParser#ParenExpression. 128 | VisitParenExpression(ctx *ParenExpressionContext) interface{} 129 | 130 | // Visit a parse tree produced by SubstraitTypeParser#ParameterName. 131 | VisitParameterName(ctx *ParameterNameContext) interface{} 132 | 133 | // Visit a parse tree produced by SubstraitTypeParser#FunctionCall. 134 | VisitFunctionCall(ctx *FunctionCallContext) interface{} 135 | 136 | // Visit a parse tree produced by SubstraitTypeParser#NotExpr. 137 | VisitNotExpr(ctx *NotExprContext) interface{} 138 | 139 | // Visit a parse tree produced by SubstraitTypeParser#LiteralNumber. 140 | VisitLiteralNumber(ctx *LiteralNumberContext) interface{} 141 | } 142 | -------------------------------------------------------------------------------- /types/parser/parse.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/antlr4-go/antlr/v4" 7 | substraitgo "github.com/substrait-io/substrait-go/v4" 8 | "github.com/substrait-io/substrait-go/v4/types" 9 | baseparser2 "github.com/substrait-io/substrait-go/v4/types/parser/baseparser" 10 | "github.com/substrait-io/substrait-go/v4/types/parser/util" 11 | ) 12 | 13 | type TypeExpression struct { 14 | ValueType types.FuncDefArgType 15 | } 16 | 17 | func (t *TypeExpression) MarshalYAML() (interface{}, error) { 18 | return t.ValueType.String(), nil 19 | } 20 | 21 | func (t *TypeExpression) UnmarshalYAML(fn func(interface{}) error) error { 22 | type Alias any 23 | var alias Alias 24 | if err := fn(&alias); err != nil { 25 | return err 26 | } 27 | 28 | switch v := alias.(type) { 29 | case string: 30 | exp, err := ParseType(v) 31 | if err != nil { 32 | return err 33 | } 34 | t.ValueType = exp 35 | return nil 36 | } 37 | 38 | return substraitgo.ErrNotImplemented 39 | } 40 | 41 | func ParseType(input string) (types.FuncDefArgType, 
error) { 42 | is := antlr.NewInputStream(input) 43 | lexer := baseparser2.NewSubstraitTypeLexer(is) 44 | stream := antlr.NewCommonTokenStream(lexer, 0) 45 | p := baseparser2.NewSubstraitTypeParser(stream) 46 | errorListener := util.NewSimpleErrorListener() 47 | p.AddErrorListener(errorListener) 48 | p.GetInterpreter().SetPredictionMode(antlr.PredictionModeSLL) 49 | 50 | visitor := &TypeVisitor{} 51 | ret, err := parseType(input, p, errorListener, visitor) 52 | if err != nil { 53 | return nil, err 54 | } 55 | if errorListener.ErrorCount() > 0 { 56 | return nil, fmt.Errorf("error parsing input '%s': %s", input, errorListener.GetErrors()) 57 | } 58 | retType, ok := ret.(types.FuncDefArgType) 59 | if !ok { 60 | return nil, fmt.Errorf("failed to parse %s as FuncDefArgType", input) 61 | } 62 | return retType, nil 63 | } 64 | // parseType runs the start rule and visits the resulting tree; syntax errors and recovered panics come back as err. 65 | func parseType(input string, p *baseparser2.SubstraitTypeParser, errorListener *util.SimpleErrorListener, visitor *TypeVisitor) (ret any, err error) { 66 | // err must be a named result: the deferred recovery writes to it after the 67 | // function body has stopped running, which an ordinary local cannot report. 68 | defer util.TransformPanicToError(&err, input, "ParseExpr", errorListener) 69 | context := p.StartRule() 70 | if errorListener.ErrorCount() > 0 { 71 | return nil, fmt.Errorf("error parsing input '%s': %s", input, errorListener.GetErrors()) 72 | } 73 | ret = visitor.Visit(context) 74 | 75 | return ret, err 76 | } 77 | -------------------------------------------------------------------------------- /types/parser/util/error_listener.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/antlr4-go/antlr/v4" 7 | ) 8 | 9 | type VisitErrorListener interface { 10 | ReportVisitError(ctx antlr.ParserRuleContext, err error) 11 | ReportPanicError(err error) 12 | ErrorCount() int 13 | GetErrors() []string 14 | } 15 | 16 | type SimpleErrorListener struct { 17 | errorCount int 18 | errors []string 19 | } 20 | 21 | func (l
*SimpleErrorListener) ReportVisitError(ctx antlr.ParserRuleContext, err error) { 22 | l.errorCount++ 23 | l.errors = append(l.errors, fmt.Sprintf("Visit error at line %d: %s", ctx.GetStart().GetLine(), err)) 24 | } 25 | 26 | func (l *SimpleErrorListener) ReportPanicError(err error) { 27 | l.errorCount++ 28 | l.errors = append(l.errors, fmt.Sprintf("Tree Visit panic error %s", err)) 29 | } 30 | 31 | func (l *SimpleErrorListener) SyntaxError(recognizer antlr.Recognizer, offendingSymbol interface{}, line, column int, msg string, e antlr.RecognitionException) { 32 | l.errorCount++ 33 | l.errors = append(l.errors, fmt.Sprintf("Syntax error at line %d:%d: %s ", line, column, msg)) 34 | } 35 | 36 | func (l *SimpleErrorListener) ReportAmbiguity(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex int, exact bool, ambigAlts *antlr.BitSet, configs *antlr.ATNConfigSet) { 37 | } 38 | 39 | func (l *SimpleErrorListener) ReportAttemptingFullContext(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex int, conflictingAlts *antlr.BitSet, configs *antlr.ATNConfigSet) { 40 | } 41 | 42 | func (l *SimpleErrorListener) ReportContextSensitivity(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex, prediction int, configs *antlr.ATNConfigSet) { 43 | } 44 | 45 | func (l *SimpleErrorListener) ErrorCount() int { 46 | return l.errorCount 47 | } 48 | 49 | func (l *SimpleErrorListener) GetErrors() []string { 50 | return l.errors 51 | } 52 | 53 | func NewSimpleErrorListener() *SimpleErrorListener { 54 | return new(SimpleErrorListener) 55 | } 56 | 57 | func TransformPanicToError(err *error, input, ctxStr string, errorListener VisitErrorListener) { 58 | if r := recover(); r != nil { 59 | switch t := r.(type) { 60 | case string: 61 | *err = fmt.Errorf("failed %s %s with error: %s", ctxStr, input, t) 62 | case error: 63 | *err = t 64 | default: 65 | *err = fmt.Errorf("failed %s %s with unknown panic", ctxStr, input) 66 | } 67 | if errorListener != nil { 68 | 
errorListener.ReportPanicError(*err) 69 | } 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /types/precison_timestamp_types.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package types 4 | 5 | import ( 6 | "fmt" 7 | "reflect" 8 | "time" 9 | 10 | proto "github.com/substrait-io/substrait-protobuf/go/substraitpb" 11 | ) 12 | 13 | // TimePrecision is used to represent the precision of a timestamp 14 | type TimePrecision int32 15 | 16 | const ( 17 | PrecisionUnknown TimePrecision = -1 18 | // below precision values are proto values 19 | PrecisionSeconds TimePrecision = 0 20 | PrecisionDeciSeconds TimePrecision = 1 21 | PrecisionCentiSeconds TimePrecision = 2 22 | PrecisionMilliSeconds TimePrecision = 3 23 | PrecisionEMinus4Seconds TimePrecision = 4 // 10^-4 of seconds 24 | PrecisionEMinus5Seconds TimePrecision = 5 // 10^-5 of seconds 25 | PrecisionMicroSeconds TimePrecision = 6 26 | PrecisionEMinus7Seconds TimePrecision = 7 // 10^-7 of seconds 27 | PrecisionEMinus8Seconds TimePrecision = 8 // 10^-8 of seconds 28 | PrecisionNanoSeconds TimePrecision = 9 29 | ) 30 | 31 | func (m TimePrecision) ToProtoVal() int32 { 32 | return int32(m) 33 | } 34 | 35 | func SubSecondsToDuration(subSeconds int64, precision TimePrecision) time.Duration { 36 | switch precision { 37 | case PrecisionSeconds: 38 | return time.Duration(subSeconds) * time.Second 39 | case PrecisionDeciSeconds: 40 | return time.Duration(subSeconds) * time.Second / 10 41 | case PrecisionCentiSeconds: 42 | return time.Duration(subSeconds) * time.Second / 100 43 | case PrecisionMilliSeconds: 44 | return time.Duration(subSeconds) * time.Millisecond 45 | case PrecisionEMinus4Seconds: 46 | return time.Duration(subSeconds) * 100 * time.Microsecond 47 | case PrecisionEMinus5Seconds: 48 | return time.Duration(subSeconds) * 10 * time.Microsecond 49 | case PrecisionMicroSeconds: 50 | 
return time.Duration(subSeconds) * time.Microsecond 51 | case PrecisionEMinus7Seconds: 52 | return time.Duration(subSeconds) * 100 * time.Nanosecond 53 | case PrecisionEMinus8Seconds: 54 | return time.Duration(subSeconds) * 10 * time.Nanosecond 55 | case PrecisionNanoSeconds: 56 | return time.Duration(subSeconds) * time.Nanosecond 57 | default: 58 | panic(fmt.Sprintf("invalid precision %d", precision)) 59 | } 60 | } 61 | 62 | func ProtoToTimePrecision(val int32) (TimePrecision, error) { 63 | if val < PrecisionSeconds.ToProtoVal() || val > PrecisionNanoSeconds.ToProtoVal() { 64 | return PrecisionUnknown, fmt.Errorf("invalid TimePrecision value %d", val) 65 | } 66 | return TimePrecision(val), nil 67 | } 68 | 69 | // PrecisionTimestampType this is used to represent a type of Precision timestamp 70 | type PrecisionTimestampType struct { 71 | Precision TimePrecision 72 | TypeVariationRef uint32 73 | Nullability Nullability 74 | } 75 | 76 | // NewPrecisionTimestampType creates a type of new Precision timestamp. 
77 | // Created type has Nullability as Nullable 78 | func NewPrecisionTimestampType(precision TimePrecision) *PrecisionTimestampType { 79 | return &PrecisionTimestampType{ 80 | Precision: precision, 81 | Nullability: NullabilityNullable, 82 | } 83 | } 84 | 85 | func (m *PrecisionTimestampType) GetPrecisionProtoVal() int32 { 86 | return m.Precision.ToProtoVal() 87 | } 88 | 89 | func (*PrecisionTimestampType) isRootRef() {} 90 | func (m *PrecisionTimestampType) WithNullability(n Nullability) Type { 91 | return m.withNullability(n) 92 | } 93 | 94 | func (m *PrecisionTimestampType) GetParameters() []interface{} { 95 | return []interface{}{m.Precision} 96 | } 97 | // withNullability returns a copy of m with Nullability set to n; all other fields are preserved. 98 | func (m *PrecisionTimestampType) withNullability(n Nullability) *PrecisionTimestampType { 99 | // Copy the whole value so TypeVariationRef is carried over instead of reset to zero. 100 | out := *m 101 | out.Nullability = n 102 | return &out 103 | } 104 | 105 | func (m *PrecisionTimestampType) GetType() Type { return m } 106 | func (m *PrecisionTimestampType) GetNullability() Nullability { return m.Nullability } 107 | func (m *PrecisionTimestampType) GetTypeVariationReference() uint32 { return m.TypeVariationRef } 108 | func (m *PrecisionTimestampType) Equals(rhs Type) bool { 109 | if o, ok := rhs.(*PrecisionTimestampType); ok { 110 | return *o == *m 111 | } 112 | return false 113 | } 114 | 115 | func (m *PrecisionTimestampType) ToProtoFuncArg() *proto.FunctionArgument { 116 | return &proto.FunctionArgument{ 117 | ArgType: &proto.FunctionArgument_Type{Type: m.ToProto()}, 118 | } 119 | } 120 | 121 | func (m *PrecisionTimestampType) ToProto() *proto.Type { 122 | return &proto.Type{Kind: &proto.Type_PrecisionTimestamp_{ 123 | PrecisionTimestamp: &proto.Type_PrecisionTimestamp{ 124 | Precision: m.Precision.ToProtoVal(), 125 | Nullability: m.Nullability, 126 | TypeVariationReference: m.TypeVariationRef}}} 127 | } 128 | 129 | func (*PrecisionTimestampType) ShortString() string { 130 | return GetShortTypeName(TypeNamePrecisionTimestamp) 131 | } 132 | func (m
*PrecisionTimestampType) String() string { 133 | return fmt.Sprintf("%s%s<%d>", TypeNamePrecisionTimestamp, strNullable(m), 134 | m.Precision.ToProtoVal()) 135 | } 136 | 137 | func (m *PrecisionTimestampType) ParameterString() string { 138 | return fmt.Sprintf("%d", m.Precision.ToProtoVal()) 139 | } 140 | 141 | func (m *PrecisionTimestampType) BaseString() string { 142 | return typeNames[reflect.TypeOf(m)] 143 | } 144 | 145 | func (m *PrecisionTimestampType) GetPrecision() TimePrecision { 146 | return m.Precision 147 | } 148 | 149 | func (m *PrecisionTimestampType) GetReturnType(length int32, nullability Nullability) Type { 150 | out := *m 151 | out.Precision = TimePrecision(length) 152 | out.Nullability = nullability 153 | return &out 154 | } 155 | 156 | // PrecisionTimestampTzType this is used to represent a type of Precision timestamp with TimeZone 157 | type PrecisionTimestampTzType struct { 158 | PrecisionTimestampType 159 | } 160 | 161 | // NewPrecisionTimestampTzType creates a type of new Precision timestamp with TimeZone. 
162 | // Created type has Nullability as Nullable 163 | func NewPrecisionTimestampTzType(precision TimePrecision) *PrecisionTimestampTzType { 164 | return &PrecisionTimestampTzType{ 165 | PrecisionTimestampType: PrecisionTimestampType{ 166 | Precision: precision, 167 | Nullability: NullabilityNullable, 168 | }, 169 | } 170 | } 171 | 172 | func (m *PrecisionTimestampTzType) ToProtoFuncArg() *proto.FunctionArgument { 173 | return &proto.FunctionArgument{ 174 | ArgType: &proto.FunctionArgument_Type{Type: m.ToProto()}, 175 | } 176 | } 177 | 178 | func (m *PrecisionTimestampTzType) ToProto() *proto.Type { 179 | return &proto.Type{Kind: &proto.Type_PrecisionTimestampTz{ 180 | PrecisionTimestampTz: &proto.Type_PrecisionTimestampTZ{ 181 | Precision: m.Precision.ToProtoVal(), 182 | Nullability: m.Nullability, 183 | TypeVariationReference: m.TypeVariationRef}}} 184 | } 185 | 186 | func (m *PrecisionTimestampTzType) String() string { 187 | return fmt.Sprintf("%s%s<%d>", TypeNamePrecisionTimestampTz, strNullable(m), 188 | m.Precision.ToProtoVal()) 189 | } 190 | 191 | func (m *PrecisionTimestampTzType) WithNullability(n Nullability) Type { 192 | return &PrecisionTimestampTzType{ 193 | PrecisionTimestampType: *m.PrecisionTimestampType.withNullability(n), 194 | } 195 | } 196 | 197 | func (m *PrecisionTimestampTzType) GetParameters() []interface{} { 198 | return []interface{}{m.Precision} 199 | } 200 | 201 | func (m *PrecisionTimestampTzType) Equals(rhs Type) bool { 202 | if o, ok := rhs.(*PrecisionTimestampTzType); ok { 203 | return *o == *m 204 | } 205 | return false 206 | } 207 | 208 | func (m *PrecisionTimestampTzType) GetNullability() Nullability { 209 | return m.Nullability 210 | } 211 | 212 | func (*PrecisionTimestampTzType) ShortString() string { 213 | return GetShortTypeName(TypeNamePrecisionTimestampTz) 214 | } 215 | 216 | func (m *PrecisionTimestampTzType) BaseString() string { 217 | return typeNames[reflect.TypeOf(m)] 218 | } 219 | 220 | func (m 
*PrecisionTimestampTzType) GetReturnType(length int32, nullability Nullability) Type { 221 | out := *m 222 | out.Precision = TimePrecision(length) 223 | out.Nullability = nullability 224 | return &out 225 | } 226 | -------------------------------------------------------------------------------- /types/precison_timestamp_types_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package types 4 | 5 | import ( 6 | "fmt" 7 | "testing" 8 | "time" 9 | 10 | "github.com/google/go-cmp/cmp" 11 | "github.com/stretchr/testify/assert" 12 | proto "github.com/substrait-io/substrait-protobuf/go/substraitpb" 13 | "google.golang.org/protobuf/testing/protocmp" 14 | ) 15 | 16 | var expectedProtoValMap = map[TimePrecision]int32{ 17 | PrecisionSeconds: 0, 18 | PrecisionDeciSeconds: 1, 19 | PrecisionCentiSeconds: 2, 20 | PrecisionMilliSeconds: 3, 21 | PrecisionEMinus4Seconds: 4, 22 | PrecisionEMinus5Seconds: 5, 23 | PrecisionMicroSeconds: 6, 24 | PrecisionEMinus7Seconds: 7, 25 | PrecisionEMinus8Seconds: 8, 26 | PrecisionNanoSeconds: 9, 27 | } 28 | // TestProtoToTimePrecision checks that every valid proto value maps to its TimePrecision and that -1 and 10 are both rejected with PrecisionUnknown. 29 | func TestProtoToTimePrecision(t *testing.T) { 30 | for expectedTimePrecision, expectedProtoVal := range expectedProtoValMap { 31 | got, err := ProtoToTimePrecision(expectedProtoVal) 32 | assert.NoError(t, err) 33 | assert.Equal(t, expectedTimePrecision, got) 34 | } 35 | 36 | got, err := ProtoToTimePrecision(-1) 37 | assert.Error(t, err) 38 | assert.Equal(t, PrecisionUnknown, got) 39 | got, err = ProtoToTimePrecision(10) 40 | assert.Error(t, err) 41 | assert.Equal(t, PrecisionUnknown, got) 42 | } 43 | 44 | func TestNewPrecisionTimestampType(t *testing.T) { 45 | allPossibleTimePrecision := []TimePrecision{PrecisionSeconds, PrecisionDeciSeconds, PrecisionCentiSeconds, PrecisionMilliSeconds, 46 | PrecisionEMinus4Seconds, PrecisionEMinus5Seconds, PrecisionMicroSeconds, PrecisionEMinus7Seconds, PrecisionEMinus8Seconds, PrecisionNanoSeconds} 47 | 
allPossibleNullability := []Nullability{NullabilityUnspecified, NullabilityNullable, NullabilityRequired} 48 | 49 | for _, precision := range allPossibleTimePrecision { 50 | for _, nullability := range allPossibleNullability { 51 | expectedPrecisionTimeStampType := PrecisionTimestampType{Precision: precision, Nullability: nullability} 52 | expectedPrecisionTimeStampTzType := PrecisionTimestampTzType{PrecisionTimestampType: expectedPrecisionTimeStampType} 53 | expectedFormatString := fmt.Sprintf("%s<%d>", strNullable(&expectedPrecisionTimeStampType), precision.ToProtoVal()) 54 | 55 | parameters := expectedPrecisionTimeStampType.GetParameters() 56 | assert.Equal(t, parameters, []interface{}{precision}) 57 | parameters = expectedPrecisionTimeStampTzType.GetParameters() 58 | assert.Equal(t, parameters, []interface{}{precision}) 59 | // verify PrecisionTimestampType 60 | createdPrecTimeStampType := NewPrecisionTimestampType(precision).WithNullability(nullability) 61 | createdPrecTimeStamp := createdPrecTimeStampType.(*PrecisionTimestampType) 62 | assert.True(t, createdPrecTimeStamp.Equals(&expectedPrecisionTimeStampType)) 63 | assert.Equal(t, expectedProtoValMap[precision], createdPrecTimeStamp.GetPrecisionProtoVal()) 64 | assert.Equal(t, nullability, createdPrecTimeStamp.GetNullability()) 65 | assert.Zero(t, createdPrecTimeStamp.GetTypeVariationReference()) 66 | assert.Equal(t, fmt.Sprintf("precision_timestamp%s", expectedFormatString), createdPrecTimeStamp.String()) 67 | assert.Equal(t, "pts", createdPrecTimeStamp.ShortString()) 68 | assertPrecisionTimeStampProto(t, precision, nullability, *createdPrecTimeStamp) 69 | 70 | // verify PrecisionTimestampTzType 71 | createdPrecTimeStampTzType := NewPrecisionTimestampTzType(precision).WithNullability(nullability) 72 | createdPrecTimeStampTz := createdPrecTimeStampTzType.(*PrecisionTimestampTzType) 73 | assert.True(t, createdPrecTimeStampTz.Equals(&expectedPrecisionTimeStampTzType)) 74 | assert.Equal(t, 
expectedProtoValMap[precision], createdPrecTimeStampTz.GetPrecisionProtoVal()) 75 | assert.Equal(t, nullability, createdPrecTimeStampTz.GetNullability()) 76 | assert.Zero(t, createdPrecTimeStampTz.GetTypeVariationReference()) 77 | assert.Equal(t, fmt.Sprintf("precision_timestamp_tz%s", expectedFormatString), createdPrecTimeStampTz.String()) 78 | assert.Equal(t, "ptstz", createdPrecTimeStampTz.ShortString()) 79 | assertPrecisionTimeStampTzProto(t, precision, nullability, *createdPrecTimeStampTz) 80 | 81 | // assert that both types are not equal 82 | assert.False(t, createdPrecTimeStampType.Equals(createdPrecTimeStampTzType)) 83 | assert.False(t, createdPrecTimeStampTzType.Equals(createdPrecTimeStampType)) 84 | } 85 | } 86 | } 87 | 88 | func assertPrecisionTimeStampProto(t *testing.T, expectedPrecision TimePrecision, expectedNullability Nullability, 89 | toVerifyType PrecisionTimestampType) { 90 | 91 | expectedTypeProto := &proto.Type{Kind: &proto.Type_PrecisionTimestamp_{ 92 | PrecisionTimestamp: &proto.Type_PrecisionTimestamp{ 93 | Precision: expectedPrecision.ToProtoVal(), 94 | Nullability: expectedNullability, 95 | }, 96 | }} 97 | if diff := cmp.Diff(toVerifyType.ToProto(), expectedTypeProto, protocmp.Transform()); diff != "" { 98 | t.Errorf("precisionTimeStamp proto didn't match, diff:\n%v", diff) 99 | } 100 | 101 | expectedFuncArgProto := &proto.FunctionArgument{ArgType: &proto.FunctionArgument_Type{ 102 | Type: expectedTypeProto, 103 | }} 104 | if diff := cmp.Diff(toVerifyType.ToProtoFuncArg(), expectedFuncArgProto, protocmp.Transform()); diff != "" { 105 | t.Errorf("precisionTimeStamp proto didn't match, diff:\n%v", diff) 106 | } 107 | } 108 | 109 | func assertPrecisionTimeStampTzProto(t *testing.T, expectedPrecision TimePrecision, expectedNullability Nullability, toVerifyType PrecisionTimestampTzType) { 110 | expectedTypeProto := &proto.Type{Kind: &proto.Type_PrecisionTimestampTz{ 111 | PrecisionTimestampTz: &proto.Type_PrecisionTimestampTZ{ 112 | Precision: 
expectedPrecision.ToProtoVal(), 113 | Nullability: expectedNullability, 114 | }, 115 | }} 116 | if diff := cmp.Diff(toVerifyType.ToProto(), expectedTypeProto, protocmp.Transform()); diff != "" { 117 | t.Errorf("precisionTimeStampTz proto didn't match, diff:\n%v", diff) 118 | } 119 | expectedFuncArgProto := &proto.FunctionArgument{ArgType: &proto.FunctionArgument_Type{ 120 | Type: expectedTypeProto, 121 | }} 122 | if diff := cmp.Diff(toVerifyType.ToProtoFuncArg(), expectedFuncArgProto, protocmp.Transform()); diff != "" { 123 | t.Errorf("precisionTimeStampTz proto didn't match, diff:\n%v", diff) 124 | } 125 | } 126 | 127 | func TestSubSecondsToDuration(t *testing.T) { 128 | tests := []struct { 129 | name string 130 | subSeconds int64 131 | precision TimePrecision 132 | want time.Duration 133 | }{ 134 | {"0.000000001s", 1, PrecisionNanoSeconds, time.Nanosecond}, 135 | {"0.00000001s", 1, PrecisionEMinus8Seconds, time.Nanosecond * 10}, 136 | {"0.0000001s", 1, PrecisionEMinus7Seconds, time.Nanosecond * 100}, 137 | {"0.000001s", 1, PrecisionMicroSeconds, time.Microsecond}, 138 | {"0.00001s", 1, PrecisionEMinus5Seconds, time.Microsecond * 10}, 139 | {"0.0001s", 1, PrecisionEMinus4Seconds, time.Microsecond * 100}, 140 | {"0.001s", 1, PrecisionMilliSeconds, time.Millisecond}, 141 | {"0.01s", 1, PrecisionCentiSeconds, time.Millisecond * 10}, 142 | {"0.1s", 1, PrecisionDeciSeconds, time.Millisecond * 100}, 143 | {"1s", 1, PrecisionSeconds, time.Second}, 144 | } 145 | for _, tt := range tests { 146 | t.Run(tt.name, func(t *testing.T) { 147 | assert.Equalf(t, tt.want, SubSecondsToDuration(tt.subSeconds, tt.precision), "SubSecondsToDuration(%v, %v)", tt.subSeconds, tt.precision) 148 | }) 149 | } 150 | } 151 | --------------------------------------------------------------------------------