├── .devcontainer ├── Dockerfile └── devcontainer.json ├── .gitattributes ├── .github ├── dependabot.yml └── workflows │ ├── pr_title.yml │ ├── release.yml │ ├── ruff.yml │ └── test.yml ├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── LICENSE ├── Makefile ├── README.md ├── buf.gen.yaml ├── buf.yaml ├── environment.yml ├── examples ├── adbc_example.py ├── builder_example.py ├── duckdb_example.py └── pyarrow_example.py ├── gen_proto.sh ├── pyproject.toml ├── src └── substrait │ ├── __init__.py │ ├── builders │ ├── __init__.py │ ├── extended_expression.py │ ├── plan.py │ └── type.py │ ├── derivation_expression.py │ ├── extension_registry.py │ ├── extensions │ ├── __init__.py │ ├── extension_types.yaml │ ├── functions_aggregate_approx.yaml │ ├── functions_aggregate_decimal_output.yaml │ ├── functions_aggregate_generic.yaml │ ├── functions_arithmetic.yaml │ ├── functions_arithmetic_decimal.yaml │ ├── functions_boolean.yaml │ ├── functions_comparison.yaml │ ├── functions_datetime.yaml │ ├── functions_geometry.yaml │ ├── functions_logarithmic.yaml │ ├── functions_rounding.yaml │ ├── functions_rounding_decimal.yaml │ ├── functions_set.yaml │ ├── functions_string.yaml │ ├── type_variations.yaml │ └── unknown.yaml │ ├── gen │ ├── __init__.py │ ├── __init__.pyi │ ├── antlr │ │ ├── SubstraitTypeLexer.py │ │ ├── SubstraitTypeListener.py │ │ ├── SubstraitTypeParser.py │ │ ├── __init__.py │ │ └── __init__.pyi │ ├── json │ │ └── simple_extensions.py │ └── proto │ │ ├── __init__.py │ │ ├── __init__.pyi │ │ ├── algebra_pb2.py │ │ ├── algebra_pb2.pyi │ │ ├── capabilities_pb2.py │ │ ├── capabilities_pb2.pyi │ │ ├── extended_expression_pb2.py │ │ ├── extended_expression_pb2.pyi │ │ ├── extensions │ │ ├── __init__.py │ │ ├── __init__.pyi │ │ ├── extensions_pb2.py │ │ └── extensions_pb2.pyi │ │ ├── function_pb2.py │ │ ├── function_pb2.pyi │ │ ├── parameterized_types_pb2.py │ │ ├── parameterized_types_pb2.pyi │ │ ├── plan_pb2.py │ │ ├── plan_pb2.pyi │ │ 
├── type_expressions_pb2.py │ │ ├── type_expressions_pb2.pyi │ │ ├── type_pb2.py │ │ └── type_pb2.pyi │ ├── json.py │ ├── proto.py │ ├── simple_extension_utils.py │ ├── sql │ └── sql_to_substrait.py │ ├── type_inference.py │ └── utils.py ├── tests ├── builders │ ├── extended_expression │ │ ├── test_aggregate_function.py │ │ ├── test_cast.py │ │ ├── test_column.py │ │ ├── test_if_then.py │ │ ├── test_literal.py │ │ ├── test_multi_or_list.py │ │ ├── test_scalar_function.py │ │ ├── test_singular_or_list.py │ │ ├── test_switch.py │ │ └── test_window_function.py │ └── plan │ │ ├── test_aggregate.py │ │ ├── test_cross.py │ │ ├── test_fetch.py │ │ ├── test_filter.py │ │ ├── test_join.py │ │ ├── test_project.py │ │ ├── test_read.py │ │ ├── test_set.py │ │ └── test_sort.py ├── sql │ └── test_sql_to_substrait.py ├── test_derivation_expression.py ├── test_extension_registry.py ├── test_json.py ├── test_literal_type_inference.py ├── test_proto.py ├── test_type_inference.py └── test_utils.py ├── update_cpp.sh ├── update_proto.sh └── uv.lock /.devcontainer/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM mcr.microsoft.com/vscode/devcontainers/python:3.9-buster 2 | USER vscode 3 | RUN curl -s "https://get.sdkman.io" | bash 4 | SHELL ["/bin/bash", "-c"] 5 | RUN source "/home/vscode/.sdkman/bin/sdkman-init.sh" && sdk install java 20.0.2-graalce 6 | RUN mkdir -p ~/lib && cd ~/lib && curl -L -O http://www.antlr.org/download/antlr-4.13.1-complete.jar 7 | ENV ANTLR_JAR="~/lib/antlr-4.13.1-complete.jar" 8 | RUN cd ~ && curl -LO https://github.com/protocolbuffers/protobuf/releases/download/v25.1/protoc-25.1-linux-x86_64.zip && \ 9 | unzip protoc-25.1-linux-x86_64.zip -d ~/.local && \ 10 | rm protoc-25.1-linux-x86_64.zip 11 | RUN curl -sSL "https://github.com/bufbuild/buf/releases/download/v1.50.0/buf-$(uname -s)-$(uname -m)" -o ~/.local/bin/buf && chmod +x ~/.local/bin/buf 12 | RUN curl -LsSf https://astral.sh/uv/0.7.11/install.sh | sh 
13 | USER root -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "substrait-python-devcontainer", 3 | "build": { 4 | "context": "..", 5 | "dockerfile": "Dockerfile" 6 | }, 7 | 8 | // Features to add to the dev container. More info: https://containers.dev/features. 9 | // "features": { 10 | // "ghcr.io/devcontainers/features/nix:1": {} 11 | // }, 12 | 13 | // Use 'forwardPorts' to make a list of ports inside the container available locally. 14 | // "forwardPorts": [], 15 | 16 | // Use 'postCreateCommand' to run commands after the container is created. 17 | // "postCreateCommand": "poetry install" 18 | 19 | // Configure tool-specific properties. 20 | // "customizations": {}, 21 | 22 | // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. 23 | // "remoteUser": "root" 24 | } -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | src/substrait/gen/** linguist-generated=true 2 | src/substrait/extensions/** linguist-generated=true 3 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "github-actions" 4 | directory: "/" 5 | schedule: 6 | interval: "weekly" 7 | -------------------------------------------------------------------------------- /.github/workflows/pr_title.yml: -------------------------------------------------------------------------------- 1 | name: PR Title Check 2 | 3 | on: 4 | pull_request_target: 5 | types: [opened, edited, synchronize, reopened] 6 | jobs: 7 | commitlint: 8 | name: PR title / description conforms to semantic-release 9 | runs-on: 
ubuntu-latest 10 | steps: 11 | - uses: actions/setup-node@v4 12 | with: 13 | node-version: "18" 14 | - run: npm install @commitlint/config-conventional 15 | - run: > 16 | echo 'module.exports = { 17 | // Workaround for https://github.com/dependabot/dependabot-core/issues/5923 18 | "ignores": [(message) => /^Bumps \[.+]\(.+\) from .+ to .+\.$/m.test(message)] 19 | }' > .commitlintrc.js 20 | - run: npx commitlint --extends @commitlint/config-conventional --verbose <<< $COMMIT_MSG 21 | env: 22 | COMMIT_MSG: > 23 | ${{ github.event.pull_request.title }} 24 | 25 | ${{ github.event.pull_request.body }} 26 | - if: failure() 27 | uses: actions/github-script@v7 28 | with: 29 | script: | 30 | const message = `**ACTION NEEDED** 31 | 32 | Substrait follows the [Conventional Commits 33 | specification](https://www.conventionalcommits.org/en/v1.0.0/) for 34 | release automation. 35 | 36 | The PR title and description are used as the merge commit message.\ 37 | Please update your PR title and description to match the specification. 
38 | ` 39 | // Get list of current comments 40 | const comments = await github.paginate(github.rest.issues.listComments, { 41 | owner: context.repo.owner, 42 | repo: context.repo.repo, 43 | issue_number: context.issue.number 44 | }); 45 | // Check if this job already commented 46 | for (const comment of comments) { 47 | if (comment.body === message) { 48 | return // Already commented 49 | } 50 | } 51 | // Post the comment about Conventional Commits 52 | github.rest.issues.createComment({ 53 | owner: context.repo.owner, 54 | repo: context.repo.repo, 55 | issue_number: context.issue.number, 56 | body: message 57 | }) 58 | core.setFailed(message) 59 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release and Publish 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | tags: [ 'v[0-9]+.[0-9]+.[0-9]+' ] 7 | 8 | permissions: 9 | contents: write 10 | 11 | jobs: 12 | build: 13 | name: Build 14 | runs-on: ubuntu-latest 15 | if: ${{ github.repository == 'substrait-io/substrait-python' }} 16 | steps: 17 | - name: Checkout code 18 | uses: actions/checkout@v4 19 | with: 20 | submodules: recursive 21 | - name: Set up Python 22 | uses: actions/setup-python@v5 23 | with: 24 | python-version: "3.x" 25 | - name: Install build dependencies 26 | run: | 27 | python -m pip install build --user 28 | - name: Build package 29 | run: | 30 | python -m build 31 | - name: Upload package 32 | uses: actions/upload-artifact@v4 33 | with: 34 | name: dist 35 | path: dist/ 36 | retention-days: 1 37 | release: 38 | name: Release to GitHub 39 | runs-on: ubuntu-latest 40 | needs: build 41 | if: startsWith(github.ref, 'refs/tags/v') 42 | steps: 43 | - name: Download artifact 44 | uses: actions/download-artifact@v4 45 | with: 46 | name: dist 47 | path: dist/ 48 | - name: Publish to GitHub release page 49 | uses: softprops/action-gh-release@v2 50 | with: 
51 | files: | 52 | ./dist/*.whl 53 | ./dist/*.tar.gz 54 | publish: 55 | name: Publish to PyPI 56 | runs-on: ubuntu-latest 57 | needs: release 58 | steps: 59 | - name: Download artifact 60 | uses: actions/download-artifact@v4 61 | with: 62 | name: dist 63 | path: dist/ 64 | - name: Publish package to PyPI 65 | uses: pypa/gh-action-pypi-publish@release/v1 66 | with: 67 | password: ${{ secrets.PYPI_API_TOKEN }} 68 | -------------------------------------------------------------------------------- /.github/workflows/ruff.yml: -------------------------------------------------------------------------------- 1 | name: Run linter and formatter 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [ main ] 7 | tags: [ 'v*.*.*' ] 8 | 9 | permissions: 10 | contents: read 11 | 12 | jobs: 13 | test: 14 | name: Lint and Format 15 | strategy: 16 | matrix: 17 | os: [ubuntu-latest] 18 | python: ["3.9"] 19 | runs-on: ${{ matrix.os }} 20 | steps: 21 | - name: Checkout code 22 | uses: actions/checkout@v4 23 | with: 24 | submodules: recursive 25 | - name: Install uv with python 26 | uses: astral-sh/setup-uv@v6 27 | with: 28 | python-version: ${{ matrix.python }} 29 | - name: Run ruff linter 30 | run: | 31 | uvx ruff@0.11.11 check 32 | - name: Run ruff formatter 33 | run: | 34 | uvx ruff@0.11.11 format -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Run tests 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [ main ] 7 | tags: [ 'v*.*.*' ] 8 | 9 | permissions: 10 | contents: read 11 | 12 | jobs: 13 | test: 14 | name: Test 15 | strategy: 16 | matrix: 17 | os: [macos-latest, ubuntu-latest, windows-latest] 18 | python: ["3.9", "3.10", "3.11", "3.12", "3.13"] 19 | runs-on: ${{ matrix.os }} 20 | steps: 21 | - name: Checkout code 22 | uses: actions/checkout@v4 23 | with: 24 | submodules: recursive 25 | - name: Set up Python 26 | uses: 
actions/setup-python@v5 27 | with: 28 | python-version: ${{ matrix.python }} 29 | - name: Install package and test dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | python -m pip install ".[test]" 33 | - name: Run tests 34 | run: | 35 | python -m pytest tests 36 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # setuptools_scm dynamic versioning 132 | src/substrait/_version.py 133 | 134 | # Buf working directory 135 | ./buf_work_dir 136 | 137 | # Editor files 138 | .idea 139 | .vscode 140 | 141 | # OS generated files 142 | .directory 143 | .gdb_history 144 | .DS_Store 145 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/substrait"] 2 | path = third_party/substrait 3 | url = https://github.com/substrait-io/substrait 4 | [submodule "third_party/substrait-cpp"] 5 | path = third_party/substrait-cpp 6 | url = https://github.com/substrait-io/substrait-cpp 7 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | ci: 2 | autofix_commit_msg: "style: [pre-commit.ci] autofix" 3 | autoupdate_commit_msg: "chore(deps): [pre-commit.ci] autoupdate" 4 | 5 | repos: 6 | - repo: https://github.com/astral-sh/ruff-pre-commit 7 | rev: v0.1.6 8 | hooks: 9 | - id: ruff 10 | args: [ --fix ] 11 | - id: ruff-format 12 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Getting Started 2 
| ## Get the repo 3 | Fork and clone the repo. 4 | ``` 5 | git clone --recursive https://github.com/<your-fork>/substrait-python.git 6 | cd substrait-python 7 | ``` 8 | 9 | ## Conda env 10 | Create a conda environment with developer dependencies. 11 | ``` 12 | conda env create -f environment.yml 13 | conda activate substrait-python-env 14 | ``` 15 | 16 | ## Update the substrait submodule locally 17 | This might be necessary if you are updating an existing checkout. 18 | ``` 19 | git submodule sync --recursive 20 | git submodule update --init --recursive 21 | ``` 22 | 23 | 24 | # Upgrade the substrait protocol definition 25 | 26 | ## a) Use the upgrade script 27 | 28 | Run the upgrade script to upgrade the submodule and regenerate the protobuf stubs. 29 | 30 | ``` 31 | ./update_proto.sh 32 | ``` 33 | 34 | ## b) Manual upgrade 35 | 36 | ### Upgrade the Substrait submodule 37 | 38 | ``` 39 | cd third_party/substrait 40 | git checkout <version> 41 | cd - 42 | git commit . -m "Use submodule <version>" 43 | ``` 44 | 45 | ### Generate protocol buffers 46 | Generate the protobuf files manually. Requires protobuf `v3.20.1`. 47 | ``` 48 | ./gen_proto.sh 49 | ``` 50 | 51 | 52 | # Build 53 | ## Python package 54 | Editable installation. 55 | ``` 56 | pip install -e . 57 | ``` 58 | 59 | # Test 60 | Run tests in the project's root dir. 
61 | ``` 62 | pytest 63 | ``` 64 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | antlr: 2 | cd third_party/substrait/grammar \ 3 | && java -jar ${ANTLR_JAR} -o ../../../src/substrait/gen/antlr -Dlanguage=Python3 SubstraitType.g4 \ 4 | && rm ../../../src/substrait/gen/antlr/*.tokens \ 5 | && rm ../../../src/substrait/gen/antlr/*.interp 6 | 7 | codegen-extensions: 8 | uv run --with datamodel-code-generator datamodel-codegen \ 9 | --input-file-type jsonschema \ 10 | --input third_party/substrait/text/simple_extensions_schema.yaml \ 11 | --output src/substrait/gen/json/simple_extensions.py \ 12 | --output-model-type dataclasses.dataclass 13 | 14 | lint: 15 | uvx ruff@0.11.11 check 16 | 17 | format: 18 | uvx ruff@0.11.11 format 19 | -------------------------------------------------------------------------------- /buf.gen.yaml: -------------------------------------------------------------------------------- 1 | version: v2 2 | plugins: 3 | - protoc_builtin: python 4 | out: src/substrait/gen 5 | - remote: buf.build/community/nipunn1313-mypy:v3.5.0 6 | out: src/substrait/gen 7 | -------------------------------------------------------------------------------- /buf.yaml: -------------------------------------------------------------------------------- 1 | version: v2 2 | modules: 3 | - path: buf_work_dir 4 | lint: 5 | use: 6 | - DEFAULT 7 | except: 8 | - FIELD_NOT_REQUIRED 9 | - PACKAGE_NO_IMPORT_CYCLE 10 | disallow_comment_ignores: true 11 | breaking: 12 | use: 13 | - FILE 14 | except: 15 | - EXTENSION_NO_DELETE 16 | - FIELD_SAME_DEFAULT 17 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: substrait-python-env 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - buf 6 | - pip 7 | - pre-commit 8 | - 
protobuf = 3.20.1 # protobuf==3.20 C extensions aren't compatible with 3.19.4 9 | - protoletariat >= 2.0.0 10 | - mypy-protobuf 11 | - pytest >= 7.0.0 12 | - python >= 3.8.1 13 | - setuptools >= 61.0.0 14 | - setuptools_scm >= 6.2.0 15 | -------------------------------------------------------------------------------- /examples/adbc_example.py: -------------------------------------------------------------------------------- 1 | # Install pyarrow, adbc-driver-manager and duckdb before running this example 2 | # This example currently can be run only with duckdb<=1.1.3, later versions of duckdb no longer support substrait in adbc 3 | # /// script 4 | # dependencies = [ 5 | # "pyarrow==20.0.0", 6 | # "adbc-driver-manager==1.5.0", 7 | # "duckdb==1.1.3", 8 | # "substrait[extensions] @ file:///${PROJECT_ROOT}/" 9 | # ] 10 | # /// 11 | 12 | 13 | import adbc_driver_duckdb.dbapi 14 | import pyarrow 15 | from substrait.builders.plan import read_named_table, filter 16 | from substrait.builders.extended_expression import scalar_function, column, literal 17 | from substrait.builders.type import i64 18 | from substrait.extension_registry import ExtensionRegistry 19 | import pyarrow.substrait as pa_substrait 20 | 21 | registry = ExtensionRegistry() 22 | 23 | data = pyarrow.record_batch( 24 | [[1, 2, 3, 4], ["a", "b", "c", "d"]], 25 | names=["ints", "strs"], 26 | ) 27 | 28 | 29 | def read_adbc_named_table(name: str, conn): 30 | pa_schema = conn.adbc_get_table_schema(name) 31 | substrait_schema = ( 32 | pa_substrait.serialize_schema(pa_schema).to_pysubstrait().base_schema 33 | ) 34 | return read_named_table(name, substrait_schema) 35 | 36 | 37 | with adbc_driver_duckdb.dbapi.connect(":memory:") as conn: 38 | with conn.cursor() as cur: 39 | cur.adbc_ingest("AnswerToEverything", data) 40 | 41 | cur.executescript("INSTALL substrait;") 42 | cur.executescript("LOAD substrait;") 43 | 44 | table = read_adbc_named_table("AnswerToEverything", conn) 45 | table = filter( 46 | table, 47 | 
expression=scalar_function( 48 | "functions_comparison.yaml", 49 | "gte", 50 | expressions=[column("ints"), literal(3, i64())], 51 | ), 52 | ) 53 | 54 | cur.execute(table(registry).SerializeToString()) 55 | print(cur.fetch_arrow_table()) 56 | -------------------------------------------------------------------------------- /examples/builder_example.py: -------------------------------------------------------------------------------- 1 | from substrait.builders.plan import read_named_table, project, filter 2 | from substrait.builders.extended_expression import column, scalar_function, literal 3 | from substrait.builders.type import i64, boolean, struct, named_struct 4 | from substrait.extension_registry import ExtensionRegistry 5 | 6 | registry = ExtensionRegistry(load_default_extensions=True) 7 | 8 | ns = named_struct( 9 | names=["id", "is_applicable"], struct=struct(types=[i64(nullable=False), boolean()]) 10 | ) 11 | 12 | table = read_named_table("example_table", ns) 13 | table = filter(table, expression=column("is_applicable")) 14 | table = filter( 15 | table, 16 | expression=scalar_function( 17 | "functions_comparison.yaml", 18 | "lt", 19 | expressions=[column("id"), literal(100, i64())], 20 | ), 21 | ) 22 | table = project(table, expressions=[column("id")]) 23 | 24 | print(table(registry)) 25 | 26 | """ 27 | extension_uris { 28 | extension_uri_anchor: 13 29 | uri: "functions_comparison.yaml" 30 | } 31 | extensions { 32 | extension_function { 33 | extension_uri_reference: 13 34 | function_anchor: 495 35 | name: "lt" 36 | } 37 | } 38 | relations { 39 | root { 40 | input { 41 | project { 42 | common { 43 | emit { 44 | output_mapping: 2 45 | } 46 | } 47 | input { 48 | filter { 49 | input { 50 | filter { 51 | input { 52 | read { 53 | common { 54 | direct { 55 | } 56 | } 57 | base_schema { 58 | names: "id" 59 | names: "is_applicable" 60 | struct { 61 | types { 62 | i64 { 63 | nullability: NULLABILITY_REQUIRED 64 | } 65 | } 66 | types { 67 | bool { 68 | nullability: 
NULLABILITY_NULLABLE 69 | } 70 | } 71 | nullability: NULLABILITY_NULLABLE 72 | } 73 | } 74 | named_table { 75 | names: "example_table" 76 | } 77 | } 78 | } 79 | condition { 80 | selection { 81 | direct_reference { 82 | struct_field { 83 | field: 1 84 | } 85 | } 86 | root_reference { 87 | } 88 | } 89 | } 90 | } 91 | } 92 | condition { 93 | scalar_function { 94 | function_reference: 495 95 | output_type { 96 | bool { 97 | nullability: NULLABILITY_NULLABLE 98 | } 99 | } 100 | arguments { 101 | value { 102 | selection { 103 | direct_reference { 104 | struct_field { 105 | } 106 | } 107 | root_reference { 108 | } 109 | } 110 | } 111 | } 112 | arguments { 113 | value { 114 | literal { 115 | i64: 100 116 | nullable: true 117 | } 118 | } 119 | } 120 | } 121 | } 122 | } 123 | } 124 | expressions { 125 | selection { 126 | direct_reference { 127 | struct_field { 128 | } 129 | } 130 | root_reference { 131 | } 132 | } 133 | } 134 | } 135 | } 136 | names: "id" 137 | } 138 | } 139 | """ 140 | -------------------------------------------------------------------------------- /examples/duckdb_example.py: -------------------------------------------------------------------------------- 1 | # Install duckdb and pyarrow before running this example 2 | # /// script 3 | # dependencies = [ 4 | # "pyarrow==20.0.0", 5 | # "duckdb==1.2.1", 6 | # "substrait[extensions] @ file:///${PROJECT_ROOT}/" 7 | # ] 8 | # /// 9 | 10 | 11 | import duckdb 12 | from substrait.builders.plan import read_named_table, project, filter 13 | from substrait.builders.extended_expression import column, scalar_function, literal 14 | from substrait.builders.type import i32 15 | from substrait.extension_registry import ExtensionRegistry 16 | from substrait.json import dump_json 17 | import pyarrow.substrait as pa_substrait 18 | 19 | try: 20 | duckdb.install_extension("substrait") 21 | except duckdb.duckdb.HTTPException: 22 | duckdb.install_extension("substrait", repository="community") 23 | 
duckdb.load_extension("substrait") 24 | 25 | duckdb.install_extension("tpch") 26 | duckdb.load_extension("tpch") 27 | 28 | duckdb.sql("CALL dbgen(sf = 1);") 29 | 30 | registry = ExtensionRegistry(load_default_extensions=True) 31 | 32 | 33 | def read_duckdb_named_table(name: str, conn): 34 | pa_schema = conn.sql(f"SELECT * FROM {name} LIMIT 0").arrow().schema 35 | substrait_schema = ( 36 | pa_substrait.serialize_schema(pa_schema).to_pysubstrait().base_schema 37 | ) 38 | return read_named_table(name, substrait_schema) 39 | 40 | 41 | table = read_duckdb_named_table("customer", duckdb) 42 | table = filter( 43 | table, 44 | expression=scalar_function( 45 | "functions_comparison.yaml", 46 | "equal", 47 | expressions=[column("c_nationkey"), literal(3, i32())], 48 | ), 49 | ) 50 | table = project( 51 | table, expressions=[column("c_name"), column("c_address"), column("c_nationkey")] 52 | ) 53 | 54 | sql = f"CALL from_substrait_json('{dump_json(table(registry))}')" 55 | print(duckdb.sql(sql)) 56 | -------------------------------------------------------------------------------- /examples/pyarrow_example.py: -------------------------------------------------------------------------------- 1 | # Install pyarrow before running this example 2 | # /// script 3 | # dependencies = [ 4 | # "pyarrow==20.0.0", 5 | # "substrait[extensions] @ file:///${PROJECT_ROOT}/" 6 | # ] 7 | # /// 8 | import pyarrow as pa 9 | import pyarrow.compute as pc 10 | import pyarrow.substrait as pa_substrait 11 | import substrait 12 | from substrait.builders.plan import project, read_named_table 13 | 14 | arrow_schema = pa.schema([pa.field("x", pa.int32()), pa.field("y", pa.int32())]) 15 | 16 | substrait_schema = ( 17 | pa_substrait.serialize_schema(arrow_schema).to_pysubstrait().base_schema 18 | ) 19 | 20 | substrait_expr = pa_substrait.serialize_expressions( 21 | exprs=[pc.field("x") + pc.field("y")], names=["total"], schema=arrow_schema 22 | ) 23 | 24 | pysubstrait_expr = 
substrait.proto.ExtendedExpression.FromString(bytes(substrait_expr)) 25 | 26 | table = read_named_table("example", substrait_schema) 27 | table = project(table, expressions=[pysubstrait_expr])(None) 28 | print(table) 29 | -------------------------------------------------------------------------------- /gen_proto.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -eou pipefail 4 | 5 | namespace=proto 6 | submodule_dir=./third_party/substrait 7 | src_dir="$submodule_dir"/proto 8 | tmp_dir=./buf_work_dir 9 | dest_dir=./src/substrait/gen 10 | extension_dir=./src/substrait/extensions 11 | 12 | # Prefix the protobuf files with a unique configuration to prevent namespace conflicts 13 | # with other substrait packages. Save output to the work dir. 14 | python "$submodule_dir"/tools/proto_prefix.py "$tmp_dir" "$namespace" "$src_dir" 15 | 16 | # Remove the old python protobuf files 17 | rm -rf "$dest_dir/proto" 18 | 19 | # Generate the new python protobuf files 20 | buf generate 21 | protol --in-place --create-package --python-out "$dest_dir" buf 22 | 23 | # Remove the old extension files 24 | rm -rf "$extension_dir" 25 | 26 | # Copy over new yaml files 27 | cp -fr "$submodule_dir"/extensions "$extension_dir" 28 | find "$extension_dir" -type f -exec chmod u+rw {} + 29 | 30 | # Ensure there's an __init__.py file in the extension directory 31 | touch $extension_dir/__init__.py 32 | 33 | # Remove the temporary work dir 34 | rm -rf "$tmp_dir" 35 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "substrait" 3 | description = "A python package for Substrait." 
4 | authors = [{name = "Substrait contributors", email = "substrait@googlegroups.com"}] 5 | license = {text = "Apache-2.0"} 6 | readme = "README.md" 7 | requires-python = ">=3.9" 8 | dependencies = ["protobuf >= 3.20"] 9 | dynamic = ["version"] 10 | 11 | [tool.setuptools_scm] 12 | write_to = "src/substrait/_version.py" 13 | 14 | [project.optional-dependencies] 15 | extensions = ["antlr4-python3-runtime", "pyyaml"] 16 | gen_proto = ["protobuf == 3.20.1", "protoletariat >= 2.0.0"] 17 | sql = ["sqloxide", "deepdiff"] 18 | test = ["pytest >= 7.0.0", "antlr4-python3-runtime", "pyyaml", "sqloxide", "deepdiff", "duckdb<=1.2.2", "datafusion"] 19 | 20 | [tool.pytest.ini_options] 21 | pythonpath = "src" 22 | testpaths = "tests" 23 | 24 | [build-system] 25 | requires = ["setuptools>=61.0.0", "setuptools_scm[toml]>=6.2.0"] 26 | build-backend = "setuptools.build_meta" 27 | 28 | [tool.ruff] 29 | respect-gitignore = true 30 | # should target minimum supported version 31 | target-version = "py39" 32 | # never autoformat upstream or generated code 33 | exclude = ["third_party/", "src/substrait/gen"] 34 | -------------------------------------------------------------------------------- /src/substrait/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from ._version import __version__ # noqa: F401 3 | except ImportError: 4 | pass 5 | 6 | __substrait_version__ = "0.74.0" 7 | __substrait_hash__ = "793c64b" 8 | __minimum_substrait_version__ = "0.30.0" 9 | -------------------------------------------------------------------------------- /src/substrait/builders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/substrait-io/substrait-python/0be6fc9371f4d4a2201e1dfb12825daadc03bc31/src/substrait/builders/__init__.py -------------------------------------------------------------------------------- /src/substrait/builders/type.py: 
# ---- src/substrait/builders/type.py ----
from typing import Iterable
import substrait.gen.proto.type_pb2 as stt


def _nullability(nullable):
    """Translate a truthy/falsy ``nullable`` flag into the protobuf enum value."""
    if nullable:
        return stt.Type.NULLABILITY_NULLABLE
    return stt.Type.NULLABILITY_REQUIRED


def boolean(nullable=True) -> stt.Type:
    """Build a boolean type."""
    return stt.Type(bool=stt.Type.Boolean(nullability=_nullability(nullable)))


def i8(nullable=True) -> stt.Type:
    """Build an ``i8`` type."""
    return stt.Type(i8=stt.Type.I8(nullability=_nullability(nullable)))


def i16(nullable=True) -> stt.Type:
    """Build an ``i16`` type."""
    return stt.Type(i16=stt.Type.I16(nullability=_nullability(nullable)))


def i32(nullable=True) -> stt.Type:
    """Build an ``i32`` type."""
    return stt.Type(i32=stt.Type.I32(nullability=_nullability(nullable)))


def i64(nullable=True) -> stt.Type:
    """Build an ``i64`` type."""
    return stt.Type(i64=stt.Type.I64(nullability=_nullability(nullable)))


def fp32(nullable=True) -> stt.Type:
    """Build an ``fp32`` type."""
    return stt.Type(fp32=stt.Type.FP32(nullability=_nullability(nullable)))


def fp64(nullable=True) -> stt.Type:
    """Build an ``fp64`` type."""
    return stt.Type(fp64=stt.Type.FP64(nullability=_nullability(nullable)))


def string(nullable=True) -> stt.Type:
    """Build a string type."""
    return stt.Type(string=stt.Type.String(nullability=_nullability(nullable)))


def binary(nullable=True) -> stt.Type:
    """Build a binary type."""
    return stt.Type(binary=stt.Type.Binary(nullability=_nullability(nullable)))


def date(nullable=True) -> stt.Type:
    """Build a date type."""
    return stt.Type(date=stt.Type.Date(nullability=_nullability(nullable)))


def interval_year(nullable=True) -> stt.Type:
    """Build an ``interval_year`` type."""
    return stt.Type(
        interval_year=stt.Type.IntervalYear(nullability=_nullability(nullable))
    )


def interval_day(precision: int, nullable=True) -> stt.Type:
    """Build an ``interval_day`` type with the given ``precision``."""
    return stt.Type(
        interval_day=stt.Type.IntervalDay(
            precision=precision, nullability=_nullability(nullable)
        )
    )


def interval_compound(precision: int, nullable=True) -> stt.Type:
    """Build an ``interval_compound`` type with the given ``precision``."""
    return stt.Type(
        interval_compound=stt.Type.IntervalCompound(
            precision=precision, nullability=_nullability(nullable)
        )
    )


def uuid(nullable=True) -> stt.Type:
    """Build a UUID type."""
    return stt.Type(uuid=stt.Type.UUID(nullability=_nullability(nullable)))


def fixed_char(length: int, nullable=True) -> stt.Type:
    """Build a ``fixed_char`` type of the given ``length``."""
    return stt.Type(
        fixed_char=stt.Type.FixedChar(
            length=length, nullability=_nullability(nullable)
        )
    )


def var_char(length: int, nullable=True) -> stt.Type:
    """Build a ``varchar`` type of the given ``length``."""
    return stt.Type(
        varchar=stt.Type.VarChar(length=length, nullability=_nullability(nullable))
    )


def fixed_binary(length: int, nullable=True) -> stt.Type:
    """Build a ``fixed_binary`` type of the given ``length``."""
    return stt.Type(
        fixed_binary=stt.Type.FixedBinary(
            length=length, nullability=_nullability(nullable)
        )
    )


def decimal(scale: int, precision: int, nullable=True) -> stt.Type:
    """Build a decimal type.

    Note the positional order: ``scale`` comes first, then ``precision``.
    """
    return stt.Type(
        decimal=stt.Type.Decimal(
            scale=scale, precision=precision, nullability=_nullability(nullable)
        )
    )


def precision_time(precision: int, nullable=True) -> stt.Type:
    """Build a ``precision_time`` type with the given ``precision``."""
    return stt.Type(
        precision_time=stt.Type.PrecisionTime(
            precision=precision, nullability=_nullability(nullable)
        )
    )


def precision_timestamp(precision: int, nullable=True) -> stt.Type:
    """Build a ``precision_timestamp`` type with the given ``precision``."""
    return stt.Type(
        precision_timestamp=stt.Type.PrecisionTimestamp(
            precision=precision, nullability=_nullability(nullable)
        )
    )


def precision_timestamp_tz(precision: int, nullable=True) -> stt.Type:
    """Build a ``precision_timestamp_tz`` type with the given ``precision``."""
    return stt.Type(
        precision_timestamp_tz=stt.Type.PrecisionTimestampTZ(
            precision=precision, nullability=_nullability(nullable)
        )
    )


def struct(types: Iterable[stt.Type], nullable=True) -> stt.Type:
    """Build a struct type whose fields have the given ``types``, in order."""
    return stt.Type(
        struct=stt.Type.Struct(types=types, nullability=_nullability(nullable))
    )


def list(type: stt.Type, nullable=True) -> stt.Type:
    """Build a list type whose elements are of ``type``."""
    return stt.Type(
        list=stt.Type.List(type=type, nullability=_nullability(nullable))
    )


def map(key: stt.Type, value: stt.Type, nullable=True) -> stt.Type:
    """Build a map type from ``key`` type to ``value`` type."""
    return stt.Type(
        map=stt.Type.Map(key=key, value=value, nullability=_nullability(nullable))
    )


def named_struct(names: Iterable[str], struct: stt.Type) -> stt.NamedStruct:
    """Attach field ``names`` to a struct type.

    ``struct`` is a full ``Type`` message (as returned by :func:`struct`); only
    its ``struct`` member is copied into the ``NamedStruct``.
    """
    return stt.NamedStruct(names=names, struct=struct.struct)


# ---- src/substrait/derivation_expression.py ----
import operator
from typing import Optional
from antlr4 import InputStream, CommonTokenStream
from substrait.gen.antlr.SubstraitTypeLexer import SubstraitTypeLexer
from substrait.gen.antlr.SubstraitTypeParser import SubstraitTypeParser
from substrait.gen.proto.type_pb2 import Type

# Binary operators understood by the evaluator, keyed by their token text.
_BINARY_OPS = {
    "+": operator.add,
    "-": operator.sub,
    "*": operator.mul,
    ">": operator.gt,
    ">=": operator.ge,
    "<": operator.lt,
    "<=": operator.le,
}

# Functions callable from a derivation expression, keyed by identifier.
_FUNCTIONS = {"min": min, "max": max}

# Scalar grammar contexts paired with the ``Type`` field they populate and the
# message class stored in that field.
_SCALAR_TYPES = (
    (SubstraitTypeParser.I8Context, "i8", Type.I8),
    (SubstraitTypeParser.I16Context, "i16", Type.I16),
    (SubstraitTypeParser.I32Context, "i32", Type.I32),
    (SubstraitTypeParser.I64Context, "i64", Type.I64),
    (SubstraitTypeParser.Fp32Context, "fp32", Type.FP32),
    (SubstraitTypeParser.Fp64Context, "fp64", Type.FP64),
    (SubstraitTypeParser.BooleanContext, "bool", Type.Boolean),
)


def _ctx_nullability(ctx):
    """Return the nullability enum matching the ``isnull`` marker on ``ctx``."""
    return Type.NULLABILITY_NULLABLE if ctx.isnull else Type.NULLABILITY_REQUIRED


def _evaluate_type_def(x, values: dict):
    """Turn a ``TypeDefContext`` into a ``Type`` message (or a bound any-type)."""
    scalar_type = x.scalarType()
    parametrized_type = x.parameterizedType()
    any_type = x.anyType()

    if scalar_type:
        # For scalar types the nullability marker lives on the typeDef node.
        nullability = _ctx_nullability(x)
        for ctx_cls, field, message_cls in _SCALAR_TYPES:
            if isinstance(scalar_type, ctx_cls):
                return Type(**{field: message_cls(nullability=nullability)})
        raise Exception(f"Unknown scalar type {type(scalar_type)}")

    if parametrized_type:
        if isinstance(parametrized_type, SubstraitTypeParser.DecimalContext):
            precision = _evaluate(parametrized_type.precision, values)
            scale = _evaluate(parametrized_type.scale, values)
            # For parameterized types the marker lives on the inner node.
            return Type(
                decimal=Type.Decimal(
                    precision=precision,
                    scale=scale,
                    nullability=_ctx_nullability(parametrized_type),
                )
            )
        raise Exception(f"Unknown parametrized type {type(parametrized_type)}")

    if any_type:
        any_var = any_type.AnyVar()
        if any_var:
            # Any-type variables (e.g. ``any1``) are looked up like parameters.
            return values[any_var.symbol.text]
        raise Exception("any type without a type variable is not supported")

    raise Exception("either scalar_type, parametrized_type or any_type is required")


def _evaluate(x, values: dict):
    """Recursively evaluate a parse-tree node of a derivation expression.

    Args:
        x: A node produced by ``SubstraitTypeParser``.
        values: Mapping from parameter names to their values. Multi-line
            definitions store their intermediate assignments in this mapping.

    Returns:
        A number, a boolean or a ``Type`` message, depending on the node.

    Raises:
        Exception: If the node kind, operator, function or type is unsupported.
        KeyError: If a referenced name is missing from ``values``.
    """
    if isinstance(x, SubstraitTypeParser.BinaryExprContext):
        left = _evaluate(x.left, values)
        right = _evaluate(x.right, values)
        op = _BINARY_OPS.get(x.op.text)
        if op is None:
            raise Exception(f"Unknown binary op {x.op.text}")
        return op(left, right)

    if isinstance(x, SubstraitTypeParser.LiteralNumberContext):
        return int(x.Number().symbol.text)

    if isinstance(
        x,
        (
            SubstraitTypeParser.ParameterNameContext,
            SubstraitTypeParser.NumericParameterNameContext,
        ),
    ):
        return values[x.Identifier().symbol.text]

    if isinstance(
        x,
        (
            SubstraitTypeParser.ParenExpressionContext,
            SubstraitTypeParser.NumericExpressionContext,
        ),
    ):
        return _evaluate(x.expr(), values)

    if isinstance(x, SubstraitTypeParser.FunctionCallContext):
        exprs = [_evaluate(e, values) for e in x.expr()]
        func_name = x.Identifier().symbol.text
        func = _FUNCTIONS.get(func_name)
        if func is None:
            raise Exception(f"Unknown function {func_name}")
        # Pass the list itself so a single-argument call such as ``min(P1)``
        # works; ``min(*[value])`` would try to iterate over the number.
        return func(exprs)

    if isinstance(x, SubstraitTypeParser.TypeDefContext):
        return _evaluate_type_def(x, values)

    if isinstance(x, SubstraitTypeParser.TernaryContext):
        condition = _evaluate(x.ifExpr, values)
        # Both branches are evaluated eagerly, as before.
        then_value = _evaluate(x.thenExpr, values)
        else_value = _evaluate(x.elseExpr, values)
        return then_value if condition else else_value

    if isinstance(x, SubstraitTypeParser.MultilineDefinitionContext):
        # Each ``name = expr`` line is evaluated in order and made visible to
        # the following lines and to the final type.
        for identifier, expr in zip(x.Identifier(), x.expr()):
            values[identifier.symbol.text] = _evaluate(expr, values)
        return _evaluate(x.finalType, values)

    if isinstance(x, SubstraitTypeParser.TypeLiteralContext):
        return _evaluate(x.typeDef(), values)

    if isinstance(x, SubstraitTypeParser.NumericLiteralContext):
        return int(str(x.Number()))

    raise Exception(f"Unknown token type {type(x)}")


def _parse(x: str):
    """Parse a derivation expression string into an ANTLR ``expr`` tree."""
    lexer = SubstraitTypeLexer(InputStream(x))
    stream = CommonTokenStream(lexer)
    parser = SubstraitTypeParser(stream)
    return parser.expr()


def evaluate(x: str, values: Optional[dict] = None):
    """Parse and evaluate a derivation expression.

    Args:
        x: The expression text, e.g. a return-type derivation from an
            extension YAML file.
        values: Optional mapping of parameter names to values. When given,
            multi-line definitions add their intermediate names to it.

    Returns:
        The evaluated result: a number, a boolean or a ``Type`` message.
    """
    if values is None:
        # The evaluator subscripts and assigns into ``values``; ``None`` would
        # raise ``TypeError`` on the first parameter or assignment.
        values = {}
    return _evaluate(_parse(x), values)


# ---- src/substrait/extensions/__init__.py ----
https://raw.githubusercontent.com/substrait-io/substrait-python/0be6fc9371f4d4a2201e1dfb12825daadc03bc31/src/substrait/extensions/__init__.py -------------------------------------------------------------------------------- /src/substrait/extensions/extension_types.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | types: 3 | - name: point 4 | structure: 5 | latitude: i32 6 | longitude: i32 7 | - name: line 8 | structure: 9 | start: point 10 | end: point 11 | -------------------------------------------------------------------------------- /src/substrait/extensions/functions_aggregate_approx.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | aggregate_functions: 4 | - name: "approx_count_distinct" 5 | description: >- 6 | Calculates the approximate number of rows that contain distinct values of the expression argument using 7 | HyperLogLog. This function provides an alternative to the COUNT (DISTINCT expression) function, which 8 | returns the exact number of rows that contain distinct values of an expression. APPROX_COUNT_DISTINCT 9 | processes large amounts of data significantly faster than COUNT, with negligible deviation from the exact 10 | result. 11 | impls: 12 | - args: 13 | - name: x 14 | value: any 15 | nullability: DECLARED_OUTPUT 16 | decomposable: MANY 17 | intermediate: binary 18 | return: i64 19 | -------------------------------------------------------------------------------- /src/substrait/extensions/functions_aggregate_decimal_output.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | aggregate_functions: 4 | - name: "count" 5 | description: Count a set of values. Result is returned as a decimal instead of i64. 
6 | impls: 7 | - args: 8 | - name: x 9 | value: any 10 | options: 11 | overflow: 12 | values: [SILENT, SATURATE, ERROR] 13 | nullability: DECLARED_OUTPUT 14 | decomposable: MANY 15 | intermediate: decimal<38,0> 16 | return: decimal<38,0> 17 | - name: "count" 18 | description: "Count a set of records (not field referenced). Result is returned as a decimal instead of i64." 19 | impls: 20 | - options: 21 | overflow: 22 | values: [SILENT, SATURATE, ERROR] 23 | nullability: DECLARED_OUTPUT 24 | decomposable: MANY 25 | intermediate: decimal<38,0> 26 | return: decimal<38,0> 27 | - name: "approx_count_distinct" 28 | description: >- 29 | Calculates the approximate number of rows that contain distinct values of the expression argument using 30 | HyperLogLog. This function provides an alternative to the COUNT (DISTINCT expression) function, which 31 | returns the exact number of rows that contain distinct values of an expression. APPROX_COUNT_DISTINCT 32 | processes large amounts of data significantly faster than COUNT, with negligible deviation from the exact 33 | result. Result is returned as a decimal instead of i64. 
34 | impls: 35 | - args: 36 | - name: x 37 | value: any 38 | nullability: DECLARED_OUTPUT 39 | decomposable: MANY 40 | intermediate: binary 41 | return: decimal<38,0> 42 | -------------------------------------------------------------------------------- /src/substrait/extensions/functions_aggregate_generic.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | aggregate_functions: 4 | - name: "count" 5 | description: Count a set of values 6 | impls: 7 | - args: 8 | - name: x 9 | value: any 10 | options: 11 | overflow: 12 | values: [SILENT, SATURATE, ERROR] 13 | nullability: DECLARED_OUTPUT 14 | decomposable: MANY 15 | intermediate: i64 16 | return: i64 17 | - name: "count" 18 | description: "Count a set of records (not field referenced)" 19 | impls: 20 | - options: 21 | overflow: 22 | values: [SILENT, SATURATE, ERROR] 23 | nullability: DECLARED_OUTPUT 24 | decomposable: MANY 25 | intermediate: i64 26 | return: i64 27 | - name: "any_value" 28 | description: > 29 | Selects an arbitrary value from a group of values. 30 | 31 | If the input is empty, the function returns null. 32 | impls: 33 | - args: 34 | - name: x 35 | value: any1 36 | options: 37 | ignore_nulls: 38 | values: [ "TRUE", "FALSE" ] 39 | nullability: DECLARED_OUTPUT 40 | decomposable: MANY 41 | intermediate: any1? 42 | return: any1? 43 | -------------------------------------------------------------------------------- /src/substrait/extensions/functions_arithmetic_decimal.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | scalar_functions: 4 | - 5 | name: "add" 6 | description: "Add two decimal values." 
7 | impls: 8 | - args: 9 | - name: x 10 | value: decimal<P1,S1> 11 | - name: y 12 | value: decimal<P2,S2> 13 | options: 14 | overflow: 15 | values: [ SILENT, SATURATE, ERROR ] 16 | return: |- 17 | init_scale = max(S1,S2) 18 | init_prec = init_scale + max(P1 - S1, P2 - S2) + 1 19 | min_scale = min(init_scale, 6) 20 | delta = init_prec - 38 21 | prec = min(init_prec, 38) 22 | scale_after_borrow = max(init_scale - delta, min_scale) 23 | scale = init_prec > 38 ? scale_after_borrow : init_scale 24 | DECIMAL<prec, scale> 25 | - 26 | name: "subtract" 27 | impls: 28 | - args: 29 | - name: x 30 | value: decimal<P1,S1> 31 | - name: y 32 | value: decimal<P2,S2> 33 | options: 34 | overflow: 35 | values: [ SILENT, SATURATE, ERROR ] 36 | return: |- 37 | init_scale = max(S1,S2) 38 | init_prec = init_scale + max(P1 - S1, P2 - S2) + 1 39 | min_scale = min(init_scale, 6) 40 | delta = init_prec - 38 41 | prec = min(init_prec, 38) 42 | scale_after_borrow = max(init_scale - delta, min_scale) 43 | scale = init_prec > 38 ? scale_after_borrow : init_scale 44 | DECIMAL<prec, scale> 45 | - 46 | name: "multiply" 47 | impls: 48 | - args: 49 | - name: x 50 | value: decimal<P1,S1> 51 | - name: y 52 | value: decimal<P2,S2> 53 | options: 54 | overflow: 55 | values: [ SILENT, SATURATE, ERROR ] 56 | return: |- 57 | init_scale = S1 + S2 58 | init_prec = P1 + P2 + 1 59 | min_scale = min(init_scale, 6) 60 | delta = init_prec - 38 61 | prec = min(init_prec, 38) 62 | scale_after_borrow = max(init_scale - delta, min_scale) 63 | scale = init_prec > 38 ? scale_after_borrow : init_scale 64 | DECIMAL<prec, scale> 65 | - 66 | name: "divide" 67 | impls: 68 | - args: 69 | - name: x 70 | value: decimal<P1,S1> 71 | - name: y 72 | value: decimal<P2,S2> 73 | options: 74 | overflow: 75 | values: [ SILENT, SATURATE, ERROR ] 76 | return: |- 77 | init_scale = max(6, S1 + P2 + 1) 78 | init_prec = P1 - S1 + P2 + init_scale 79 | min_scale = min(init_scale, 6) 80 | delta = init_prec - 38 81 | prec = min(init_prec, 38) 82 | scale_after_borrow = max(init_scale - delta, min_scale) 83 | scale = init_prec > 38 ?
scale_after_borrow : init_scale 84 | DECIMAL<prec, scale> 85 | - 86 | name: "modulus" 87 | impls: 88 | - args: 89 | - name: x 90 | value: decimal<P1,S1> 91 | - name: y 92 | value: decimal<P2,S2> 93 | options: 94 | overflow: 95 | values: [ SILENT, SATURATE, ERROR ] 96 | return: |- 97 | init_scale = max(S1,S2) 98 | init_prec = min(P1 - S1, P2 - S2) + init_scale 99 | min_scale = min(init_scale, 6) 100 | delta = init_prec - 38 101 | prec = min(init_prec, 38) 102 | scale_after_borrow = max(init_scale - delta, min_scale) 103 | scale = init_prec > 38 ? scale_after_borrow : init_scale 104 | DECIMAL<prec, scale> 105 | - 106 | name: "abs" 107 | description: Calculate the absolute value of the argument. 108 | impls: 109 | - args: 110 | - name: x 111 | value: decimal<P,S> 112 | return: decimal<P,S> 113 | - name: "bitwise_and" 114 | description: > 115 | Return the bitwise AND result for two decimal inputs. 116 | Input scale must be 0 (i.e. only integer types are allowed) 117 | impls: 118 | - args: 119 | - name: x 120 | value: "DECIMAL<P1,0>" 121 | - name: y 122 | value: "DECIMAL<P2,0>" 123 | return: |- 124 | max_precision = max(P1, P2) 125 | DECIMAL<max_precision, 0> 126 | - name: "bitwise_or" 127 | description: > 128 | Return the bitwise OR result for two given decimal inputs. 129 | Input scale must be 0 (i.e. only integer types are allowed) 130 | impls: 131 | - args: 132 | - name: x 133 | value: "DECIMAL<P1,0>" 134 | - name: y 135 | value: "DECIMAL<P2,0>" 136 | return: |- 137 | max_precision = max(P1, P2) 138 | DECIMAL<max_precision, 0> 139 | - name: "bitwise_xor" 140 | description: > 141 | Return the bitwise XOR result for two given decimal inputs. 142 | Input scale must be 0 (i.e. only integer types are allowed) 143 | impls: 144 | - args: 145 | - name: x 146 | value: "DECIMAL<P1,0>" 147 | - name: y 148 | value: "DECIMAL<P2,0>" 149 | return: |- 150 | max_precision = max(P1, P2) 151 | DECIMAL<max_precision, 0> 152 | - name: "sqrt" 153 | description: Square root of the value. Sqrt of 0 is 0 and sqrt of negative values will raise an error.
154 | impls: 155 | - args: 156 | - name: x 157 | value: "DECIMAL<P,S>" 158 | return: fp64 159 | - name: "factorial" 160 | description: > 161 | Return the factorial of a given decimal input. Scale should be 0 for factorial decimal input. 162 | The factorial of 0! is 1 by convention. Negative inputs will raise an error. 163 | Inputs which cause overflow of the result will raise an error. 164 | impls: 165 | - args: 166 | - name: "n" 167 | value: "DECIMAL<P,0>" 168 | return: "DECIMAL<38,0>" 169 | - 170 | name: "power" 171 | description: "Take the power with x as the base and y as exponent. 172 | Behavior for complex number result is indicated by option complex_number_result" 173 | impls: 174 | - args: 175 | - name: x 176 | value: "DECIMAL<P1,S1>" 177 | - name: y 178 | value: "DECIMAL<P2,S2>" 179 | options: 180 | overflow: 181 | values: [ SILENT, SATURATE, ERROR ] 182 | complex_number_result: 183 | values: [ NAN, ERROR ] 184 | return: fp64 185 | 186 | aggregate_functions: 187 | - name: "sum" 188 | description: Sum a set of values. 189 | impls: 190 | - args: 191 | - name: x 192 | value: "DECIMAL<P, S>" 193 | options: 194 | overflow: 195 | values: [ SILENT, SATURATE, ERROR ] 196 | nullability: DECLARED_OUTPUT 197 | decomposable: MANY 198 | intermediate: "DECIMAL?<38,S>" 199 | return: "DECIMAL?<38,S>" 200 | - name: "avg" 201 | description: Average a set of values. 202 | impls: 203 | - args: 204 | - name: x 205 | value: "DECIMAL<P,S>" 206 | options: 207 | overflow: 208 | values: [ SILENT, SATURATE, ERROR ] 209 | nullability: DECLARED_OUTPUT 210 | decomposable: MANY 211 | intermediate: "STRUCT<DECIMAL<38,S>,i64>" 212 | return: "DECIMAL<38,S>" 213 | - name: "min" 214 | description: Min a set of values. 215 | impls: 216 | - args: 217 | - name: x 218 | value: "DECIMAL<P, S>" 219 | nullability: DECLARED_OUTPUT 220 | decomposable: MANY 221 | intermediate: "DECIMAL?<P, S>" 222 | return: "DECIMAL?<P, S>" 223 | - name: "max" 224 | description: Max a set of values.
225 | impls: 226 | - args: 227 | - name: x 228 | value: "DECIMAL<P, S>" 229 | nullability: DECLARED_OUTPUT 230 | decomposable: MANY 231 | intermediate: "DECIMAL?<P, S>" 232 | return: "DECIMAL?<P, S>" 233 | - name: "sum0" 234 | description: > 235 | Sum a set of values. The sum of zero elements yields zero. 236 | 237 | Null values are ignored. 238 | impls: 239 | - args: 240 | - name: x 241 | value: "DECIMAL<P, S>" 242 | options: 243 | overflow: 244 | values: [ SILENT, SATURATE, ERROR ] 245 | nullability: DECLARED_OUTPUT 246 | decomposable: MANY 247 | intermediate: "DECIMAL<38,S>" 248 | return: "DECIMAL<38,S>" 249 | -------------------------------------------------------------------------------- /src/substrait/extensions/functions_boolean.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | scalar_functions: 4 | - 5 | name: or 6 | description: > 7 | The boolean `or` using Kleene logic. 8 | 9 | This function behaves as follows with nulls: 10 | 11 | true or null = true 12 | 13 | null or true = true 14 | 15 | false or null = null 16 | 17 | null or false = null 18 | 19 | null or null = null 20 | 21 | In other words, in this context a null value really means "unknown", and 22 | an unknown value `or` true is always true. 23 | 24 | Behavior for 0 or 1 inputs is as follows: 25 | or() -> false 26 | or(x) -> x 27 | impls: 28 | - args: 29 | - value: boolean? 30 | name: a 31 | variadic: 32 | min: 0 33 | return: boolean? 34 | - 35 | name: and 36 | description: > 37 | The boolean `and` using Kleene logic. 38 | 39 | This function behaves as follows with nulls: 40 | 41 | true and null = null 42 | 43 | null and true = null 44 | 45 | false and null = false 46 | 47 | null and false = false 48 | 49 | null and null = null 50 | 51 | In other words, in this context a null value really means "unknown", and 52 | an unknown value `and` false is always false.
53 | 54 | Behavior for 0 or 1 inputs is as follows: 55 | and() -> true 56 | and(x) -> x 57 | impls: 58 | - args: 59 | - value: boolean? 60 | name: a 61 | variadic: 62 | min: 0 63 | return: boolean? 64 | - 65 | name: and_not 66 | description: > 67 | The boolean `and` of one value and the negation of the other using Kleene logic. 68 | 69 | This function behaves as follows with nulls: 70 | 71 | true and not null = null 72 | 73 | null and not false = null 74 | 75 | false and not null = false 76 | 77 | null and not true = false 78 | 79 | null and not null = null 80 | 81 | In other words, in this context a null value really means "unknown", and 82 | an unknown value `and not` true is always false, as is false `and not` an 83 | unknown value. 84 | impls: 85 | - args: 86 | - value: boolean? 87 | name: a 88 | - value: boolean? 89 | name: b 90 | return: boolean? 91 | - 92 | name: xor 93 | description: > 94 | The boolean `xor` of two values using Kleene logic. 95 | 96 | When a null is encountered in either input, a null is output. 97 | impls: 98 | - args: 99 | - value: boolean? 100 | name: a 101 | - value: boolean? 102 | name: b 103 | return: boolean? 104 | - 105 | name: not 106 | description: > 107 | The `not` of a boolean value. 108 | 109 | When a null is input, a null is output. 110 | impls: 111 | - args: 112 | - value: boolean? 113 | name: a 114 | return: boolean? 115 | 116 | aggregate_functions: 117 | - 118 | name: "bool_and" 119 | description: > 120 | If any value in the input is false, false is returned. If the input is 121 | empty or only contains nulls, null is returned. Otherwise, true is 122 | returned. 123 | impls: 124 | - args: 125 | - value: boolean 126 | name: a 127 | nullability: DECLARED_OUTPUT 128 | return: boolean? 129 | - 130 | name: "bool_or" 131 | description: > 132 | If any value in the input is true, true is returned. If the input is 133 | empty or only contains nulls, null is returned. Otherwise, false is 134 | returned. 
135 | impls: 136 | - args: 137 | - value: boolean 138 | name: a 139 | nullability: DECLARED_OUTPUT 140 | return: boolean? 141 | -------------------------------------------------------------------------------- /src/substrait/extensions/functions_comparison.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | scalar_functions: 4 | - 5 | name: "not_equal" 6 | description: > 7 | Whether two values are not_equal. 8 | 9 | `not_equal(x, y) := (x != y)` 10 | 11 | If either/both of `x` and `y` are `null`, `null` is returned. 12 | impls: 13 | - args: 14 | - value: any1 15 | name: x 16 | - value: any1 17 | name: y 18 | return: boolean 19 | - 20 | name: "equal" 21 | description: > 22 | Whether two values are equal. 23 | 24 | `equal(x, y) := (x == y)` 25 | 26 | If either/both of `x` and `y` are `null`, `null` is returned. 27 | impls: 28 | - args: 29 | - value: any1 30 | name: x 31 | - value: any1 32 | name: y 33 | return: boolean 34 | - 35 | name: "is_not_distinct_from" 36 | description: > 37 | Whether two values are equal. 38 | 39 | This function treats `null` values as comparable, so 40 | 41 | `is_not_distinct_from(null, null) == True` 42 | 43 | This is in contrast to `equal`, in which `null` values do not compare. 44 | impls: 45 | - args: 46 | - value: any1 47 | name: x 48 | - value: any1 49 | name: y 50 | return: boolean 51 | nullability: DECLARED_OUTPUT 52 | - 53 | name: "is_distinct_from" 54 | description: > 55 | Whether two values are not equal. 56 | 57 | This function treats `null` values as comparable, so 58 | 59 | `is_distinct_from(null, null) == False` 60 | 61 | This is in contrast to `equal`, in which `null` values do not compare. 62 | impls: 63 | - args: 64 | - value: any1 65 | name: x 66 | - value: any1 67 | name: y 68 | return: boolean 69 | nullability: DECLARED_OUTPUT 70 | - 71 | name: "lt" 72 | description: > 73 | Less than. 
74 | 75 | lt(x, y) := (x < y) 76 | 77 | If either/both of `x` and `y` are `null`, `null` is returned. 78 | impls: 79 | - args: 80 | - value: any1 81 | name: x 82 | - value: any1 83 | name: y 84 | return: boolean 85 | - 86 | name: "gt" 87 | description: > 88 | Greater than. 89 | 90 | gt(x, y) := (x > y) 91 | 92 | If either/both of `x` and `y` are `null`, `null` is returned. 93 | impls: 94 | - args: 95 | - value: any1 96 | name: x 97 | - value: any1 98 | name: y 99 | return: boolean 100 | - 101 | name: "lte" 102 | description: > 103 | Less than or equal to. 104 | 105 | lte(x, y) := (x <= y) 106 | 107 | If either/both of `x` and `y` are `null`, `null` is returned. 108 | impls: 109 | - args: 110 | - value: any1 111 | name: x 112 | - value: any1 113 | name: y 114 | return: boolean 115 | - 116 | name: "gte" 117 | description: > 118 | Greater than or equal to. 119 | 120 | gte(x, y) := (x >= y) 121 | 122 | If either/both of `x` and `y` are `null`, `null` is returned. 123 | impls: 124 | - args: 125 | - value: any1 126 | name: x 127 | - value: any1 128 | name: y 129 | return: boolean 130 | - 131 | name: "between" 132 | description: >- 133 | Whether the `expression` is greater than or equal to `low` and less than or equal to `high`. 134 | 135 | `expression` BETWEEN `low` AND `high` 136 | 137 | If `low`, `high`, or `expression` are `null`, `null` is returned. 138 | impls: 139 | - args: 140 | - value: any1 141 | name: expression 142 | description: The expression to test for in the range defined by `low` and `high`. 143 | - value: any1 144 | name: low 145 | description: The value to check if greater than or equal to. 146 | - value: any1 147 | name: high 148 | description: The value to check if less than or equal to. 149 | return: boolean 150 | - name: "is_true" 151 | description: Whether a value is true. 152 | impls: 153 | - args: 154 | - value: boolean? 
155 | name: x 156 | return: BOOLEAN 157 | nullability: DECLARED_OUTPUT 158 | - name: "is_not_true" 159 | description: Whether a value is not true. 160 | impls: 161 | - args: 162 | - value: boolean? 163 | name: x 164 | return: BOOLEAN 165 | nullability: DECLARED_OUTPUT 166 | - name: "is_false" 167 | description: Whether a value is false. 168 | impls: 169 | - args: 170 | - value: boolean? 171 | name: x 172 | return: BOOLEAN 173 | nullability: DECLARED_OUTPUT 174 | - name: "is_not_false" 175 | description: Whether a value is not false. 176 | impls: 177 | - args: 178 | - value: boolean? 179 | name: x 180 | return: BOOLEAN 181 | nullability: DECLARED_OUTPUT 182 | - 183 | name: "is_null" 184 | description: Whether a value is null. NaN is not null. 185 | impls: 186 | - args: 187 | - value: any1 188 | name: x 189 | return: boolean 190 | nullability: DECLARED_OUTPUT 191 | - 192 | name: "is_not_null" 193 | description: Whether a value is not null. NaN is not null. 194 | impls: 195 | - args: 196 | - value: any1 197 | name: x 198 | return: boolean 199 | nullability: DECLARED_OUTPUT 200 | - 201 | name: "is_nan" 202 | description: > 203 | Whether a value is not a number. 204 | 205 | If `x` is `null`, `null` is returned. 206 | impls: 207 | - args: 208 | - value: fp32 209 | name: x 210 | return: boolean 211 | - args: 212 | - value: fp64 213 | name: x 214 | return: boolean 215 | - 216 | name: "is_finite" 217 | description: > 218 | Whether a value is finite (neither infinite nor NaN). 219 | 220 | If `x` is `null`, `null` is returned. 221 | impls: 222 | - args: 223 | - value: fp32 224 | name: x 225 | return: boolean 226 | - args: 227 | - value: fp64 228 | name: x 229 | return: boolean 230 | - 231 | name: "is_infinite" 232 | description: > 233 | Whether a value is infinite. 234 | 235 | If `x` is `null`, `null` is returned. 
236 | impls: 237 | - args: 238 | - value: fp32 239 | name: x 240 | return: boolean 241 | - args: 242 | - value: fp64 243 | name: x 244 | return: boolean 245 | - 246 | name: "nullif" 247 | description: If two values are equal, return null. Otherwise, return the first value. 248 | impls: 249 | - args: 250 | - value: any1 251 | name: x 252 | - value: any1 253 | name: y 254 | return: any1 255 | - 256 | name: "coalesce" 257 | description: >- 258 | Evaluate arguments from left to right and return the first argument that is not null. Once 259 | a non-null argument is found, the remaining arguments are not evaluated. 260 | 261 | If all arguments are null, return null. 262 | impls: 263 | - args: 264 | - value: any1 265 | variadic: 266 | min: 2 267 | return: any1 268 | - 269 | name: "least" 270 | description: >- 271 | Evaluates each argument and returns the smallest one. 272 | The function will return null if any argument evaluates to null. 273 | impls: 274 | - args: 275 | - value: any1 276 | variadic: 277 | min: 2 278 | return: any1 279 | nullability: MIRROR 280 | - 281 | name: "least_skip_null" 282 | description: >- 283 | Evaluates each argument and returns the smallest one. 284 | The function will return null only if all arguments evaluate to null. 285 | impls: 286 | - args: 287 | - value: any1 288 | variadic: 289 | min: 2 290 | return: any1 291 | # NOTE: The return type nullability as described above cannot be expressed currently 292 | # See https://github.com/substrait-io/substrait/issues/601 293 | # Using MIRROR for now until it can be expressed 294 | nullability: MIRROR 295 | - 296 | name: "greatest" 297 | description: >- 298 | Evaluates each argument and returns the largest one. 299 | The function will return null if any argument evaluates to null. 
300 | impls: 301 | - args: 302 | - value: any1 303 | variadic: 304 | min: 2 305 | return: any1 306 | nullability: MIRROR 307 | - 308 | name: "greatest_skip_null" 309 | description: >- 310 | Evaluates each argument and returns the largest one. 311 | The function will return null only if all arguments evaluate to null. 312 | impls: 313 | - args: 314 | - value: any1 315 | variadic: 316 | min: 2 317 | return: any1 318 | # NOTE: The return type nullability as described above cannot be expressed currently 319 | # See https://github.com/substrait-io/substrait/issues/601 320 | # Using MIRROR for now until it can be expressed 321 | nullability: MIRROR 322 | -------------------------------------------------------------------------------- /src/substrait/extensions/functions_geometry.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | types: 4 | - name: geometry 5 | structure: "BINARY" 6 | # description: | 7 | # An opaque type that can represent one or many points, lines, or shapes encompassing 8 | # 2, 3 or 4 dimension. 9 | scalar_functions: 10 | - 11 | name: "point" 12 | description: > 13 | Returns a 2D point with the given `x` and `y` coordinate values. 14 | impls: 15 | - args: 16 | - name: x 17 | value: fp64 18 | - name: y 19 | value: fp64 20 | return: u!geometry 21 | - 22 | name: "make_line" 23 | description: > 24 | Returns a linestring connecting the endpoint of geometry `geom1` to the begin point of 25 | geometry `geom2`. Repeated points at the beginning of input geometries are collapsed to a single point. 26 | 27 | A linestring can be closed or simple. A closed linestring starts and ends on the same 28 | point. A simple linestring does not cross or touch itself. 29 | impls: 30 | - args: 31 | - name: geom1 32 | value: u!geometry 33 | - name: geom2 34 | value: u!geometry 35 | return: u!geometry 36 | - 37 | name: "x_coordinate" 38 | description: > 39 | Return the x coordinate of the point. 
Return null if not available. 40 | impls: 41 | - args: 42 | - name: point 43 | value: u!geometry 44 | return: fp64 45 | - 46 | name: "y_coordinate" 47 | description: > 48 | Return the y coordinate of the point. Return null if not available. 49 | impls: 50 | - args: 51 | - name: point 52 | value: u!geometry 53 | return: fp64 54 | - 55 | name: "num_points" 56 | description: > 57 | Return the number of points in the geometry. The geometry should be a linestring 58 | or circularstring. 59 | impls: 60 | - args: 61 | - name: geom 62 | value: u!geometry 63 | return: i64 64 | - 65 | name: "is_empty" 66 | description: > 67 | Return true if the geometry is an empty geometry. 68 | impls: 69 | - args: 70 | - name: geom 71 | value: u!geometry 72 | return: boolean 73 | - 74 | name: "is_closed" 75 | description: > 76 | Return true if the geometry's start and end points are the same. 77 | impls: 78 | - args: 79 | - name: geom 80 | value: u!geometry 81 | return: boolean 82 | - 83 | name: "is_simple" 84 | description: > 85 | Return true if the geometry does not self intersect. 86 | impls: 87 | - args: 88 | - name: geom 89 | value: u!geometry 90 | return: boolean 91 | - 92 | name: "is_ring" 93 | description: > 94 | Return true if the geometry's start and end points are the same and it does not self 95 | intersect. 96 | impls: 97 | - args: 98 | - name: geom 99 | value: u!geometry 100 | return: boolean 101 | - 102 | name: "geometry_type" 103 | description: > 104 | Return the type of geometry as a string. 105 | impls: 106 | - args: 107 | - name: geom 108 | value: u!geometry 109 | return: string 110 | - 111 | name: "envelope" 112 | description: > 113 | Return the minimum bounding box for the input geometry as a geometry. 114 | 115 | The returned geometry is defined by the corner points of the bounding box. If the 116 | input geometry is a point or a line, the returned geometry can also be a point or line.
117 | impls: 118 | - args: 119 | - name: geom 120 | value: u!geometry 121 | return: u!geometry 122 | - 123 | name: "dimension" 124 | description: > 125 | Return the dimension of the input geometry. If the input is a collection of geometries, 126 | return the largest dimension from the collection. Dimensionality is determined by 127 | the complexity of the input and not the coordinate system being used. 128 | 129 | Type dimensions: 130 | POINT - 0 131 | LINE - 1 132 | POLYGON - 2 133 | impls: 134 | - args: 135 | - name: geom 136 | value: u!geometry 137 | return: i8 138 | - 139 | name: "is_valid" 140 | description: > 141 | Return true if the input geometry is a valid 2D geometry. 142 | 143 | For 3 dimensional and 4 dimensional geometries, the validity is still only tested 144 | in 2 dimensions. 145 | impls: 146 | - args: 147 | - name: geom 148 | value: u!geometry 149 | return: boolean 150 | - 151 | name: "collection_extract" 152 | description: > 153 | Given the input geometry collection, return a homogenous multi-geometry. All geometries 154 | in the multi-geometry will have the same dimension. 155 | 156 | If type is not specified, the multi-geometry will only contain geometries of the highest 157 | dimension. If type is specified, the multi-geometry will only contain geometries 158 | of that type. If there are no geometries of the specified type, an empty geometry 159 | is returned. Only points, linestrings, and polygons are supported. 160 | 161 | Type numbers: 162 | POINT - 0 163 | LINE - 1 164 | POLYGON - 2 165 | impls: 166 | - args: 167 | - name: geom_collection 168 | value: u!geometry 169 | return: u!geometry 170 | - args: 171 | - name: geom_collection 172 | value: u!geometry 173 | - name: type 174 | value: i8 175 | return: u!geometry 176 | - 177 | name: "flip_coordinates" 178 | description: > 179 | Return a version of the input geometry with the X and Y axis flipped. 180 | 181 | This operation can be performed on geometries with more than 2 dimensions. 
However, 182 | only X and Y axis will be flipped. 183 | impls: 184 | - args: 185 | - name: geom_collection 186 | value: u!geometry 187 | return: u!geometry 188 | - 189 | name: "remove_repeated_points" 190 | description: > 191 | Return a version of the input geometry with duplicate consecutive points removed. 192 | 193 | If the `tolerance` argument is provided, consecutive points within the tolerance 194 | distance of one another are considered to be duplicates. 195 | impls: 196 | - args: 197 | - name: geom 198 | value: u!geometry 199 | return: u!geometry 200 | - args: 201 | - name: geom 202 | value: u!geometry 203 | - name: tolerance 204 | value: fp64 205 | return: u!geometry 206 | - 207 | name: "buffer" 208 | description: > 209 | Compute and return an expanded version of the input geometry. All the points 210 | of the returned geometry are at a distance of `buffer_radius` away from the points 211 | of the input geometry. If a negative `buffer_radius` is provided, the geometry will 212 | shrink instead of expand. A negative `buffer_radius` may shrink the geometry completely, 213 | in which case an empty geometry is returned. For input geometries that are points or lines, 214 | a negative `buffer_radius` will always return an empty geometry. 215 | impls: 216 | - args: 217 | - name: geom 218 | value: u!geometry 219 | - name: buffer_radius 220 | value: fp64 221 | return: u!geometry 222 | - 223 | name: "centroid" 224 | description: > 225 | Return a point which is the geometric center of mass of the input geometry. 226 | impls: 227 | - args: 228 | - name: geom 229 | value: u!geometry 230 | return: u!geometry 231 | - 232 | name: "minimum_bounding_circle" 233 | description: > 234 | Return the smallest circle polygon that contains the input geometry.
235 | impls: 236 | - args: 237 | - name: geom 238 | value: u!geometry 239 | return: u!geometry 240 | -------------------------------------------------------------------------------- /src/substrait/extensions/functions_logarithmic.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | scalar_functions: 4 | - 5 | name: "ln" 6 | description: "Natural logarithm of the value" 7 | impls: 8 | - args: 9 | - name: x 10 | value: i64 11 | options: 12 | rounding: 13 | values: [ TIE_TO_EVEN, TIE_AWAY_FROM_ZERO, TRUNCATE, CEILING, FLOOR ] 14 | on_domain_error: 15 | values: [ NAN, "NULL", ERROR ] 16 | on_log_zero: 17 | values: [NAN, ERROR, MINUS_INFINITY] 18 | return: fp64 19 | - args: 20 | - name: x 21 | value: fp32 22 | options: 23 | rounding: 24 | values: [ TIE_TO_EVEN, TIE_AWAY_FROM_ZERO, TRUNCATE, CEILING, FLOOR ] 25 | on_domain_error: 26 | values: [ NAN, "NULL", ERROR ] 27 | on_log_zero: 28 | values: [NAN, ERROR, MINUS_INFINITY] 29 | return: fp32 30 | - args: 31 | - name: x 32 | value: fp64 33 | options: 34 | rounding: 35 | values: [ TIE_TO_EVEN, TIE_AWAY_FROM_ZERO, TRUNCATE, CEILING, FLOOR ] 36 | on_domain_error: 37 | values: [ NAN, "NULL", ERROR ] 38 | on_log_zero: 39 | values: [NAN, ERROR, MINUS_INFINITY] 40 | return: fp64 41 | - args: 42 | - name: x 43 | value: decimal 44 | options: 45 | rounding: 46 | values: [ TIE_TO_EVEN, TIE_AWAY_FROM_ZERO, TRUNCATE, CEILING, FLOOR ] 47 | on_domain_error: 48 | values: [ NAN, "NULL", ERROR ] 49 | on_log_zero: 50 | values: [ NAN, ERROR, MINUS_INFINITY ] 51 | return: fp64 52 | - 53 | name: "log10" 54 | description: "Logarithm to base 10 of the value" 55 | impls: 56 | - args: 57 | - name: x 58 | value: i64 59 | options: 60 | rounding: 61 | values: [ TIE_TO_EVEN, TIE_AWAY_FROM_ZERO, TRUNCATE, CEILING, FLOOR ] 62 | on_domain_error: 63 | values: [ NAN, "NULL", ERROR ] 64 | on_log_zero: 65 | values: [NAN, ERROR, MINUS_INFINITY] 66 | return: fp64 67 | - args: 68 | - name: x 69 | 
value: fp32 70 | options: 71 | rounding: 72 | values: [ TIE_TO_EVEN, TIE_AWAY_FROM_ZERO, TRUNCATE, CEILING, FLOOR ] 73 | on_domain_error: 74 | values: [ NAN, "NULL", ERROR ] 75 | on_log_zero: 76 | values: [NAN, ERROR, MINUS_INFINITY] 77 | return: fp32 78 | - args: 79 | - name: x 80 | value: fp64 81 | options: 82 | rounding: 83 | values: [ TIE_TO_EVEN, TIE_AWAY_FROM_ZERO, TRUNCATE, CEILING, FLOOR ] 84 | on_domain_error: 85 | values: [ NAN, "NULL", ERROR ] 86 | on_log_zero: 87 | values: [NAN, ERROR, MINUS_INFINITY] 88 | return: fp64 89 | - args: 90 | - name: x 91 | value: decimal 92 | options: 93 | rounding: 94 | values: [ TIE_TO_EVEN, TIE_AWAY_FROM_ZERO, TRUNCATE, CEILING, FLOOR ] 95 | on_domain_error: 96 | values: [ NAN, "NULL", ERROR ] 97 | on_log_zero: 98 | values: [ NAN, ERROR, MINUS_INFINITY ] 99 | return: fp64 100 | - 101 | name: "log2" 102 | description: "Logarithm to base 2 of the value" 103 | impls: 104 | - args: 105 | - name: x 106 | value: i64 107 | options: 108 | rounding: 109 | values: [ TIE_TO_EVEN, TIE_AWAY_FROM_ZERO, TRUNCATE, CEILING, FLOOR ] 110 | on_domain_error: 111 | values: [ NAN, "NULL", ERROR ] 112 | on_log_zero: 113 | values: [NAN, ERROR, MINUS_INFINITY] 114 | return: fp64 115 | - args: 116 | - name: x 117 | value: fp32 118 | options: 119 | rounding: 120 | values: [ TIE_TO_EVEN, TIE_AWAY_FROM_ZERO, TRUNCATE, CEILING, FLOOR ] 121 | on_domain_error: 122 | values: [ NAN, "NULL", ERROR ] 123 | on_log_zero: 124 | values: [NAN, ERROR, MINUS_INFINITY] 125 | return: fp32 126 | - args: 127 | - name: x 128 | value: fp64 129 | options: 130 | rounding: 131 | values: [ TIE_TO_EVEN, TIE_AWAY_FROM_ZERO, TRUNCATE, CEILING, FLOOR ] 132 | on_domain_error: 133 | values: [ NAN, "NULL", ERROR ] 134 | on_log_zero: 135 | values: [NAN, ERROR, MINUS_INFINITY] 136 | return: fp64 137 | - args: 138 | - name: x 139 | value: decimal 140 | options: 141 | rounding: 142 | values: [ TIE_TO_EVEN, TIE_AWAY_FROM_ZERO, TRUNCATE, CEILING, FLOOR ] 143 | on_domain_error: 144 | 
values: [ NAN, "NULL", ERROR ] 145 | on_log_zero: 146 | values: [ NAN, ERROR, MINUS_INFINITY ] 147 | return: fp64 148 | - 149 | name: "logb" 150 | description: > 151 | Logarithm of the value with the given base 152 | 153 | logb(x, b) => log_{b} (x) 154 | impls: 155 | - args: 156 | - value: i64 157 | name: "x" 158 | description: "The number `x` to compute the logarithm of" 159 | - value: i64 160 | name: "base" 161 | description: "The logarithm base `b` to use" 162 | options: 163 | rounding: 164 | values: [ TIE_TO_EVEN, TIE_AWAY_FROM_ZERO, TRUNCATE, CEILING, FLOOR ] 165 | on_domain_error: 166 | values: [ NAN, "NULL", ERROR ] 167 | on_log_zero: 168 | values: [NAN, ERROR, MINUS_INFINITY] 169 | return: fp64 170 | - args: 171 | - value: fp32 172 | name: "x" 173 | description: "The number `x` to compute the logarithm of" 174 | - value: fp32 175 | name: "base" 176 | description: "The logarithm base `b` to use" 177 | options: 178 | rounding: 179 | values: [ TIE_TO_EVEN, TIE_AWAY_FROM_ZERO, TRUNCATE, CEILING, FLOOR ] 180 | on_domain_error: 181 | values: [ NAN, "NULL", ERROR ] 182 | on_log_zero: 183 | values: [NAN, ERROR, MINUS_INFINITY] 184 | return: fp32 185 | - args: 186 | - value: fp64 187 | name: "x" 188 | description: "The number `x` to compute the logarithm of" 189 | - value: fp64 190 | name: "base" 191 | description: "The logarithm base `b` to use" 192 | options: 193 | rounding: 194 | values: [ TIE_TO_EVEN, TIE_AWAY_FROM_ZERO, TRUNCATE, CEILING, FLOOR ] 195 | on_domain_error: 196 | values: [ NAN, "NULL", ERROR ] 197 | on_log_zero: 198 | values: [NAN, ERROR, MINUS_INFINITY] 199 | return: fp64 200 | - args: 201 | - value: decimal 202 | name: "x" 203 | description: "The number `x` to compute the logarithm of" 204 | - value: decimal 205 | name: "base" 206 | description: "The logarithm base `b` to use" 207 | options: 208 | rounding: 209 | values: [ TIE_TO_EVEN, TIE_AWAY_FROM_ZERO, TRUNCATE, CEILING, FLOOR ] 210 | on_domain_error: 211 | values: [ NAN, "NULL", ERROR ] 212 | 
on_log_zero: 213 | values: [NAN, ERROR, MINUS_INFINITY] 214 | return: fp64 215 | - 216 | name: "log1p" 217 | description: > 218 | Natural logarithm (base e) of 1 + x 219 | 220 | log1p(x) => log(1+x) 221 | impls: 222 | - args: 223 | - name: x 224 | value: fp32 225 | options: 226 | rounding: 227 | values: [ TIE_TO_EVEN, TIE_AWAY_FROM_ZERO, TRUNCATE, CEILING, FLOOR ] 228 | on_domain_error: 229 | values: [ NAN, "NULL", ERROR ] 230 | on_log_zero: 231 | values: [NAN, ERROR, MINUS_INFINITY] 232 | return: fp32 233 | - args: 234 | - name: x 235 | value: fp64 236 | options: 237 | rounding: 238 | values: [ TIE_TO_EVEN, TIE_AWAY_FROM_ZERO, TRUNCATE, CEILING, FLOOR ] 239 | on_domain_error: 240 | values: [ NAN, "NULL", ERROR ] 241 | on_log_zero: 242 | values: [NAN, ERROR, MINUS_INFINITY] 243 | return: fp64 244 | - args: 245 | - name: x 246 | value: decimal 247 | options: 248 | rounding: 249 | values: [ TIE_TO_EVEN, TIE_AWAY_FROM_ZERO, TRUNCATE, CEILING, FLOOR ] 250 | on_domain_error: 251 | values: [ NAN, "NULL", ERROR ] 252 | on_log_zero: 253 | values: [NAN, ERROR, MINUS_INFINITY] 254 | return: fp64 255 | -------------------------------------------------------------------------------- /src/substrait/extensions/functions_rounding_decimal.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | scalar_functions: 4 | - 5 | name: "ceil" 6 | description: > 7 | Rounding to the ceiling of the value `x`. 8 | impls: 9 | - args: 10 | - value: decimal 11 | name: x 12 | return: |- 13 | integral_least_num_digits = P - S + 1 14 | precision = min(integral_least_num_digits, 38) 15 | decimal? 16 | - 17 | name: "floor" 18 | description: > 19 | Rounding to the floor of the value `x`. 20 | impls: 21 | - args: 22 | - value: decimal 23 | name: x 24 | return: |- 25 | integral_least_num_digits = P - S + 1 26 | precision = min(integral_least_num_digits, 38) 27 | decimal? 
28 | - 29 | name: "round" 30 | description: > 31 | Rounding the value `x` to `s` decimal places. 32 | impls: 33 | - args: 34 | - value: decimal 35 | name: x 36 | description: > 37 | Numerical expression to be rounded. 38 | - value: i32 39 | name: s 40 | description: > 41 | Number of decimal places to be rounded to. 42 | 43 | When `s` is a positive number, the rounding 44 | is performed to a `s` number of decimal places. 45 | 46 | When `s` is a negative number, the rounding is 47 | performed to the left side of the decimal point 48 | as specified by `s`. 49 | 50 | The precision of the resultant decimal type is one 51 | more than the precision of the input decimal type to 52 | allow for numbers that round up or down to the next 53 | decimal magnitude. 54 | E.g. `round(9.9, 0)` -> `10.0`. 55 | The scale of the resultant decimal type cannot be 56 | larger than the scale of the input decimal type. 57 | options: 58 | rounding: 59 | description: > 60 | When a boundary is computed to lie somewhere between two values, 61 | and this value cannot be exactly represented, this specifies how 62 | to round it. 63 | 64 | - TIE_TO_EVEN: round to nearest value; if exactly halfway, tie 65 | to the even option. 66 | - TIE_AWAY_FROM_ZERO: round to nearest value; if exactly 67 | halfway, tie away from zero. 68 | - TRUNCATE: always round toward zero. 69 | - CEILING: always round toward positive infinity. 70 | - FLOOR: always round toward negative infinity. 71 | - AWAY_FROM_ZERO: round negative values with FLOOR rule, round positive values with CEILING rule 72 | - TIE_DOWN: round ties with FLOOR rule 73 | - TIE_UP: round ties with CEILING rule 74 | - TIE_TOWARDS_ZERO: round ties with TRUNCATE rule 75 | - TIE_TO_ODD: round to nearest value; if exactly halfway, tie 76 | to the odd option. 
77 | values: [ TIE_TO_EVEN, TIE_AWAY_FROM_ZERO, TRUNCATE, CEILING, FLOOR, 78 | AWAY_FROM_ZERO, TIE_DOWN, TIE_UP, TIE_TOWARDS_ZERO, TIE_TO_ODD ] 79 | nullability: DECLARED_OUTPUT 80 | return: |- 81 | precision = min(P + 1, 38) 82 | decimal? 83 | -------------------------------------------------------------------------------- /src/substrait/extensions/functions_set.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | scalar_functions: 4 | - 5 | name: "index_in" 6 | description: > 7 | Checks the membership of a value in a list of values 8 | 9 | Returns the first 0-based index value of some input `needle` if `needle` is equal to 10 | any element in `haystack`. Returns `NULL` if not found. 11 | 12 | If `needle` is `NULL`, returns `NULL`. 13 | 14 | If `needle` is `NaN`: 15 | - Returns 0-based index of `NaN` in `haystack` (default) 16 | - Returns `NULL` (if `NAN_IS_NOT_NAN` is specified) 17 | impls: 18 | - args: 19 | - name: needle 20 | value: any1 21 | - name: haystack 22 | value: list 23 | options: 24 | nan_equality: 25 | values: [ NAN_IS_NAN, NAN_IS_NOT_NAN ] 26 | nullability: DECLARED_OUTPUT 27 | return: i64? 28 | -------------------------------------------------------------------------------- /src/substrait/extensions/type_variations.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | type_variations: 4 | - parent: string 5 | name: dict4 6 | description: a four-byte dictionary encoded string 7 | functions: INHERITS 8 | - parent: string 9 | name: bigoffset 10 | description: >- 11 | The arrow large string representation of strings, still restricted to the default string size defined in 12 | Substrait.
13 | functions: SEPARATE 14 | - parent: struct 15 | name: avro 16 | description: an avro encoded struct 17 | functions: SEPARATE 18 | - parent: struct 19 | name: cstruct 20 | description: a cstruct representation of the struct 21 | functions: SEPARATE 22 | - parent: struct 23 | name: dict2 24 | description: a 2-byte dictionary encoded string. 25 | functions: INHERITS 26 | -------------------------------------------------------------------------------- /src/substrait/extensions/unknown.yaml: -------------------------------------------------------------------------------- 1 | %YAML 1.2 2 | --- 3 | types: 4 | - name: unknown 5 | scalar_functions: 6 | - name: "add" 7 | impls: 8 | - args: 9 | - value: unknown 10 | - value: unknown 11 | return: unknown 12 | - name: "subtract" 13 | impls: 14 | - args: 15 | - value: unknown 16 | - value: unknown 17 | return: unknown 18 | - name: "multiply" 19 | impls: 20 | - args: 21 | - value: unknown 22 | - value: unknown 23 | return: unknown 24 | - name: "divide" 25 | impls: 26 | - args: 27 | - value: unknown 28 | - value: unknown 29 | return: unknown 30 | - name: "modulus" 31 | impls: 32 | - args: 33 | - value: unknown 34 | - value: unknown 35 | return: unknown 36 | aggregate_functions: 37 | - name: "sum" 38 | impls: 39 | - args: 40 | - value: unknown 41 | intermediate: unknown 42 | return: unknown 43 | - name: "avg" 44 | impls: 45 | - args: 46 | - value: unknown 47 | intermediate: unknown 48 | return: unknown 49 | - name: "min" 50 | impls: 51 | - args: 52 | - value: unknown 53 | intermediate: unknown 54 | return: unknown 55 | - name: "max" 56 | impls: 57 | - args: 58 | - value: unknown 59 | intermediate: unknown 60 | return: unknown 61 | - name: "count" 62 | impls: 63 | - args: 64 | - value: unknown 65 | intermediate: unknown 66 | return: unknown 67 | -------------------------------------------------------------------------------- /src/substrait/gen/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/substrait-io/substrait-python/0be6fc9371f4d4a2201e1dfb12825daadc03bc31/src/substrait/gen/__init__.py -------------------------------------------------------------------------------- /src/substrait/gen/__init__.pyi: -------------------------------------------------------------------------------- 1 | from . import proto 2 | from . import antlr 3 | from . import json 4 | -------------------------------------------------------------------------------- /src/substrait/gen/antlr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/substrait-io/substrait-python/0be6fc9371f4d4a2201e1dfb12825daadc03bc31/src/substrait/gen/antlr/__init__.py -------------------------------------------------------------------------------- /src/substrait/gen/antlr/__init__.pyi: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/substrait-io/substrait-python/0be6fc9371f4d4a2201e1dfb12825daadc03bc31/src/substrait/gen/antlr/__init__.pyi -------------------------------------------------------------------------------- /src/substrait/gen/json/simple_extensions.py: -------------------------------------------------------------------------------- 1 | # generated by datamodel-codegen: 2 | # filename: simple_extensions_schema.yaml 3 | # timestamp: 2025-06-06T08:43:35+00:00 4 | 5 | from __future__ import annotations 6 | 7 | from dataclasses import dataclass 8 | from enum import Enum 9 | from typing import Any, Dict, List, Optional, Union 10 | 11 | 12 | class Functions(Enum): 13 | INHERITS = 'INHERITS' 14 | SEPARATE = 'SEPARATE' 15 | 16 | 17 | Type = Union[str, Dict[str, Any]] 18 | 19 | 20 | class Type1(Enum): 21 | dataType = 'dataType' 22 | boolean = 'boolean' 23 | integer = 'integer' 24 | enumeration = 'enumeration' 25 | string = 'string' 26 | 27 | 28 | 
EnumOptions = List[str] 29 | 30 | 31 | @dataclass 32 | class EnumerationArg: 33 | options: EnumOptions 34 | name: Optional[str] = None 35 | description: Optional[str] = None 36 | 37 | 38 | @dataclass 39 | class ValueArg: 40 | value: Type 41 | name: Optional[str] = None 42 | description: Optional[str] = None 43 | constant: Optional[bool] = None 44 | 45 | 46 | @dataclass 47 | class TypeArg: 48 | type: str 49 | name: Optional[str] = None 50 | description: Optional[str] = None 51 | 52 | 53 | Arguments = List[Union[EnumerationArg, ValueArg, TypeArg]] 54 | 55 | 56 | @dataclass 57 | class Options1: 58 | values: List[str] 59 | description: Optional[str] = None 60 | 61 | 62 | Options = Optional[Dict[str, Options1]] 63 | 64 | 65 | class ParameterConsistency(Enum): 66 | CONSISTENT = 'CONSISTENT' 67 | INCONSISTENT = 'INCONSISTENT' 68 | 69 | 70 | @dataclass 71 | class VariadicBehavior: 72 | min: Optional[float] = None 73 | max: Optional[float] = None 74 | parameterConsistency: Optional[ParameterConsistency] = None 75 | 76 | 77 | Deterministic = bool 78 | 79 | 80 | SessionDependent = bool 81 | 82 | 83 | class NullabilityHandling(Enum): 84 | MIRROR = 'MIRROR' 85 | DECLARED_OUTPUT = 'DECLARED_OUTPUT' 86 | DISCRETE = 'DISCRETE' 87 | 88 | 89 | ReturnValue = Type 90 | 91 | 92 | Implementation = Optional[Dict[str, str]] 93 | 94 | 95 | Intermediate = Type 96 | 97 | 98 | class Decomposable(Enum): 99 | NONE = 'NONE' 100 | ONE = 'ONE' 101 | MANY = 'MANY' 102 | 103 | 104 | Maxset = float 105 | 106 | 107 | Ordered = bool 108 | 109 | 110 | @dataclass 111 | class Impl: 112 | return_: ReturnValue 113 | args: Optional[Arguments] = None 114 | options: Optional[Options] = None 115 | variadic: Optional[VariadicBehavior] = None 116 | sessionDependent: Optional[SessionDependent] = None 117 | deterministic: Optional[Deterministic] = None 118 | nullability: Optional[NullabilityHandling] = None 119 | implementation: Optional[Implementation] = None 120 | 121 | 122 | @dataclass 123 | class 
ScalarFunction: 124 | name: str 125 | impls: List[Impl] 126 | description: Optional[str] = None 127 | 128 | 129 | @dataclass 130 | class Impl1: 131 | return_: ReturnValue 132 | args: Optional[Arguments] = None 133 | options: Optional[Options] = None 134 | variadic: Optional[VariadicBehavior] = None 135 | sessionDependent: Optional[SessionDependent] = None 136 | deterministic: Optional[Deterministic] = None 137 | nullability: Optional[NullabilityHandling] = None 138 | implementation: Optional[Implementation] = None 139 | intermediate: Optional[Intermediate] = None 140 | ordered: Optional[Ordered] = None 141 | maxset: Optional[Maxset] = None 142 | decomposable: Optional[Decomposable] = None 143 | 144 | 145 | @dataclass 146 | class AggregateFunction: 147 | name: str 148 | impls: List[Impl1] 149 | description: Optional[str] = None 150 | 151 | 152 | class WindowType(Enum): 153 | STREAMING = 'STREAMING' 154 | PARTITION = 'PARTITION' 155 | 156 | 157 | @dataclass 158 | class Impl2: 159 | return_: ReturnValue 160 | args: Optional[Arguments] = None 161 | options: Optional[Options] = None 162 | variadic: Optional[VariadicBehavior] = None 163 | sessionDependent: Optional[SessionDependent] = None 164 | deterministic: Optional[Deterministic] = None 165 | nullability: Optional[NullabilityHandling] = None 166 | implementation: Optional[Implementation] = None 167 | intermediate: Optional[Intermediate] = None 168 | ordered: Optional[Ordered] = None 169 | maxset: Optional[Maxset] = None 170 | decomposable: Optional[Decomposable] = None 171 | window_type: Optional[WindowType] = None 172 | 173 | 174 | @dataclass 175 | class WindowFunction: 176 | name: str 177 | impls: List[Impl2] 178 | description: Optional[str] = None 179 | 180 | 181 | @dataclass 182 | class TypeVariation: 183 | parent: Type 184 | name: str 185 | description: Optional[str] = None 186 | functions: Optional[Functions] = None 187 | 188 | 189 | @dataclass 190 | class TypeParamDef: 191 | type: Type1 192 | name: 
Optional[str] = None 193 | description: Optional[str] = None 194 | min: Optional[float] = None 195 | max: Optional[float] = None 196 | options: Optional[EnumOptions] = None 197 | optional: Optional[bool] = None 198 | 199 | 200 | TypeParamDefs = List[TypeParamDef] 201 | 202 | 203 | @dataclass 204 | class TypeModel: 205 | name: str 206 | structure: Optional[Type] = None 207 | parameters: Optional[TypeParamDefs] = None 208 | variadic: Optional[bool] = None 209 | 210 | 211 | @dataclass 212 | class SimpleExtensions: 213 | dependencies: Optional[Dict[str, str]] = None 214 | types: Optional[List[TypeModel]] = None 215 | type_variations: Optional[List[TypeVariation]] = None 216 | scalar_functions: Optional[List[ScalarFunction]] = None 217 | aggregate_functions: Optional[List[AggregateFunction]] = None 218 | window_functions: Optional[List[WindowFunction]] = None 219 | -------------------------------------------------------------------------------- /src/substrait/gen/proto/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/substrait-io/substrait-python/0be6fc9371f4d4a2201e1dfb12825daadc03bc31/src/substrait/gen/proto/__init__.py -------------------------------------------------------------------------------- /src/substrait/gen/proto/__init__.pyi: -------------------------------------------------------------------------------- 1 | from . import algebra_pb2 2 | from . import capabilities_pb2 3 | from . import extended_expression_pb2 4 | from . import extensions 5 | from . import function_pb2 6 | from . import parameterized_types_pb2 7 | from . import plan_pb2 8 | from . import type_expressions_pb2 9 | from . 
import type_pb2 10 | -------------------------------------------------------------------------------- /src/substrait/gen/proto/capabilities_pb2.py: -------------------------------------------------------------------------------- 1 | """Generated protocol buffer code.""" 2 | from google.protobuf import descriptor as _descriptor 3 | from google.protobuf import descriptor_pool as _descriptor_pool 4 | from google.protobuf import symbol_database as _symbol_database 5 | from google.protobuf.internal import builder as _builder 6 | _sym_db = _symbol_database.Default() 7 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18proto/capabilities.proto\x12\x05proto"\xe8\x02\n\x0cCapabilities\x12-\n\x12substrait_versions\x18\x01 \x03(\tR\x11substraitVersions\x12?\n\x1cadvanced_extension_type_urls\x18\x02 \x03(\tR\x19advancedExtensionTypeUrls\x12P\n\x11simple_extensions\x18\x03 \x03(\x0b2#.proto.Capabilities.SimpleExtensionR\x10simpleExtensions\x1a\x95\x01\n\x0fSimpleExtension\x12\x10\n\x03uri\x18\x01 \x01(\tR\x03uri\x12#\n\rfunction_keys\x18\x02 \x03(\tR\x0cfunctionKeys\x12\x1b\n\ttype_keys\x18\x03 \x03(\tR\x08typeKeys\x12.\n\x13type_variation_keys\x18\x04 \x03(\tR\x11typeVariationKeysB#\n\x0eio.proto.protoP\x01\xaa\x02\x0eProto.Protobufb\x06proto3') 8 | _globals = globals() 9 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) 10 | _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'proto.capabilities_pb2', _globals) 11 | if _descriptor._USE_C_DESCRIPTORS == False: 12 | _globals['DESCRIPTOR']._options = None 13 | _globals['DESCRIPTOR']._serialized_options = b'\n\x0eio.proto.protoP\x01\xaa\x02\x0eProto.Protobuf' 14 | _globals['_CAPABILITIES']._serialized_start = 36 15 | _globals['_CAPABILITIES']._serialized_end = 396 16 | _globals['_CAPABILITIES_SIMPLEEXTENSION']._serialized_start = 247 17 | _globals['_CAPABILITIES_SIMPLEEXTENSION']._serialized_end = 396 -------------------------------------------------------------------------------- 
/src/substrait/gen/proto/capabilities_pb2.pyi: -------------------------------------------------------------------------------- 1 | """ 2 | @generated by mypy-protobuf. Do not edit manually! 3 | isort:skip_file 4 | SPDX-License-Identifier: Apache-2.0""" 5 | import builtins 6 | import collections.abc 7 | import google.protobuf.descriptor 8 | import google.protobuf.internal.containers 9 | import google.protobuf.message 10 | import sys 11 | if sys.version_info >= (3, 8): 12 | import typing as typing_extensions 13 | else: 14 | import typing_extensions 15 | DESCRIPTOR: google.protobuf.descriptor.FileDescriptor 16 | 17 | @typing_extensions.final 18 | class Capabilities(google.protobuf.message.Message): 19 | """Defines a set of Capabilities that a system (producer or consumer) supports.""" 20 | DESCRIPTOR: google.protobuf.descriptor.Descriptor 21 | 22 | @typing_extensions.final 23 | class SimpleExtension(google.protobuf.message.Message): 24 | DESCRIPTOR: google.protobuf.descriptor.Descriptor 25 | URI_FIELD_NUMBER: builtins.int 26 | FUNCTION_KEYS_FIELD_NUMBER: builtins.int 27 | TYPE_KEYS_FIELD_NUMBER: builtins.int 28 | TYPE_VARIATION_KEYS_FIELD_NUMBER: builtins.int 29 | uri: builtins.str 30 | 31 | @property 32 | def function_keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: 33 | ... 34 | 35 | @property 36 | def type_keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: 37 | ... 38 | 39 | @property 40 | def type_variation_keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: 41 | ... 42 | 43 | def __init__(self, *, uri: builtins.str=..., function_keys: collections.abc.Iterable[builtins.str] | None=..., type_keys: collections.abc.Iterable[builtins.str] | None=..., type_variation_keys: collections.abc.Iterable[builtins.str] | None=...) -> None: 44 | ... 
45 | 46 | def ClearField(self, field_name: typing_extensions.Literal['function_keys', b'function_keys', 'type_keys', b'type_keys', 'type_variation_keys', b'type_variation_keys', 'uri', b'uri']) -> None: 47 | ... 48 | SUBSTRAIT_VERSIONS_FIELD_NUMBER: builtins.int 49 | ADVANCED_EXTENSION_TYPE_URLS_FIELD_NUMBER: builtins.int 50 | SIMPLE_EXTENSIONS_FIELD_NUMBER: builtins.int 51 | 52 | @property 53 | def substrait_versions(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: 54 | """List of Substrait versions this system supports""" 55 | 56 | @property 57 | def advanced_extension_type_urls(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: 58 | """list of com.google.Any message types this system supports for advanced 59 | extensions. 60 | """ 61 | 62 | @property 63 | def simple_extensions(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Capabilities.SimpleExtension]: 64 | """list of simple extensions this system supports.""" 65 | 66 | def __init__(self, *, substrait_versions: collections.abc.Iterable[builtins.str] | None=..., advanced_extension_type_urls: collections.abc.Iterable[builtins.str] | None=..., simple_extensions: collections.abc.Iterable[global___Capabilities.SimpleExtension] | None=...) -> None: 67 | ... 68 | 69 | def ClearField(self, field_name: typing_extensions.Literal['advanced_extension_type_urls', b'advanced_extension_type_urls', 'simple_extensions', b'simple_extensions', 'substrait_versions', b'substrait_versions']) -> None: 70 | ... 
71 | global___Capabilities = Capabilities -------------------------------------------------------------------------------- /src/substrait/gen/proto/extended_expression_pb2.py: -------------------------------------------------------------------------------- 1 | """Generated protocol buffer code.""" 2 | from google.protobuf import descriptor as _descriptor 3 | from google.protobuf import descriptor_pool as _descriptor_pool 4 | from google.protobuf import symbol_database as _symbol_database 5 | from google.protobuf.internal import builder as _builder 6 | _sym_db = _symbol_database.Default() 7 | from ..proto import algebra_pb2 as proto_dot_algebra__pb2 8 | from ..proto.extensions import extensions_pb2 as proto_dot_extensions_dot_extensions__pb2 9 | from ..proto import plan_pb2 as proto_dot_plan__pb2 10 | from ..proto import type_pb2 as proto_dot_type__pb2 11 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1fproto/extended_expression.proto\x12\x05proto\x1a\x13proto/algebra.proto\x1a!proto/extensions/extensions.proto\x1a\x10proto/plan.proto\x1a\x10proto/type.proto"\xb0\x01\n\x13ExpressionReference\x123\n\nexpression\x18\x01 \x01(\x0b2\x11.proto.ExpressionH\x00R\nexpression\x124\n\x07measure\x18\x02 \x01(\x0b2\x18.proto.AggregateFunctionH\x00R\x07measure\x12!\n\x0coutput_names\x18\x03 \x03(\tR\x0boutputNamesB\x0b\n\texpr_type"\xd3\x03\n\x12ExtendedExpression\x12(\n\x07version\x18\x07 \x01(\x0b2\x0e.proto.VersionR\x07version\x12K\n\x0eextension_uris\x18\x01 \x03(\x0b2$.proto.extensions.SimpleExtensionURIR\rextensionUris\x12L\n\nextensions\x18\x02 \x03(\x0b2,.proto.extensions.SimpleExtensionDeclarationR\nextensions\x12?\n\rreferred_expr\x18\x03 \x03(\x0b2\x1a.proto.ExpressionReferenceR\x0creferredExpr\x123\n\x0bbase_schema\x18\x04 \x01(\x0b2\x12.proto.NamedStructR\nbaseSchema\x12T\n\x13advanced_extensions\x18\x05 \x01(\x0b2#.proto.extensions.AdvancedExtensionR\x12advancedExtensions\x12,\n\x12expected_type_urls\x18\x06 
\x03(\tR\x10expectedTypeUrlsB#\n\x0eio.proto.protoP\x01\xaa\x02\x0eProto.Protobufb\x06proto3') 12 | _globals = globals() 13 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) 14 | _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'proto.extended_expression_pb2', _globals) 15 | if _descriptor._USE_C_DESCRIPTORS == False: 16 | _globals['DESCRIPTOR']._options = None 17 | _globals['DESCRIPTOR']._serialized_options = b'\n\x0eio.proto.protoP\x01\xaa\x02\x0eProto.Protobuf' 18 | _globals['_EXPRESSIONREFERENCE']._serialized_start = 135 19 | _globals['_EXPRESSIONREFERENCE']._serialized_end = 311 20 | _globals['_EXTENDEDEXPRESSION']._serialized_start = 314 21 | _globals['_EXTENDEDEXPRESSION']._serialized_end = 781 -------------------------------------------------------------------------------- /src/substrait/gen/proto/extended_expression_pb2.pyi: -------------------------------------------------------------------------------- 1 | """ 2 | @generated by mypy-protobuf. Do not edit manually! 3 | isort:skip_file 4 | SPDX-License-Identifier: Apache-2.0""" 5 | import builtins 6 | import collections.abc 7 | import google.protobuf.descriptor 8 | import google.protobuf.internal.containers 9 | import google.protobuf.message 10 | from .. import proto 11 | import sys 12 | if sys.version_info >= (3, 8): 13 | import typing as typing_extensions 14 | else: 15 | import typing_extensions 16 | DESCRIPTOR: google.protobuf.descriptor.FileDescriptor 17 | 18 | @typing_extensions.final 19 | class ExpressionReference(google.protobuf.message.Message): 20 | DESCRIPTOR: google.protobuf.descriptor.Descriptor 21 | EXPRESSION_FIELD_NUMBER: builtins.int 22 | MEASURE_FIELD_NUMBER: builtins.int 23 | OUTPUT_NAMES_FIELD_NUMBER: builtins.int 24 | 25 | @property 26 | def expression(self) -> proto.algebra_pb2.Expression: 27 | ... 28 | 29 | @property 30 | def measure(self) -> proto.algebra_pb2.AggregateFunction: 31 | ... 
32 | 33 | @property 34 | def output_names(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: 35 | """Field names in depth-first order""" 36 | 37 | def __init__(self, *, expression: proto.algebra_pb2.Expression | None=..., measure: proto.algebra_pb2.AggregateFunction | None=..., output_names: collections.abc.Iterable[builtins.str] | None=...) -> None: 38 | ... 39 | 40 | def HasField(self, field_name: typing_extensions.Literal['expr_type', b'expr_type', 'expression', b'expression', 'measure', b'measure']) -> builtins.bool: 41 | ... 42 | 43 | def ClearField(self, field_name: typing_extensions.Literal['expr_type', b'expr_type', 'expression', b'expression', 'measure', b'measure', 'output_names', b'output_names']) -> None: 44 | ... 45 | 46 | def WhichOneof(self, oneof_group: typing_extensions.Literal['expr_type', b'expr_type']) -> typing_extensions.Literal['expression', 'measure'] | None: 47 | ... 48 | global___ExpressionReference = ExpressionReference 49 | 50 | @typing_extensions.final 51 | class ExtendedExpression(google.protobuf.message.Message): 52 | """Describe a set of operations to complete. 53 | For compactness sake, identifiers are normalized at the plan level. 54 | """ 55 | DESCRIPTOR: google.protobuf.descriptor.Descriptor 56 | VERSION_FIELD_NUMBER: builtins.int 57 | EXTENSION_URIS_FIELD_NUMBER: builtins.int 58 | EXTENSIONS_FIELD_NUMBER: builtins.int 59 | REFERRED_EXPR_FIELD_NUMBER: builtins.int 60 | BASE_SCHEMA_FIELD_NUMBER: builtins.int 61 | ADVANCED_EXTENSIONS_FIELD_NUMBER: builtins.int 62 | EXPECTED_TYPE_URLS_FIELD_NUMBER: builtins.int 63 | 64 | @property 65 | def version(self) -> proto.plan_pb2.Version: 66 | """Substrait version of the expression. Optional up to 0.17.0, required for later 67 | versions. 
68 | """ 69 | 70 | @property 71 | def extension_uris(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[proto.extensions.extensions_pb2.SimpleExtensionURI]: 72 | """a list of yaml specifications this expression may depend on""" 73 | 74 | @property 75 | def extensions(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[proto.extensions.extensions_pb2.SimpleExtensionDeclaration]: 76 | """a list of extensions this expression may depend on""" 77 | 78 | @property 79 | def referred_expr(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ExpressionReference]: 80 | """one or more expression trees with same order in plan rel""" 81 | 82 | @property 83 | def base_schema(self) -> proto.type_pb2.NamedStruct: 84 | ... 85 | 86 | @property 87 | def advanced_extensions(self) -> proto.extensions.extensions_pb2.AdvancedExtension: 88 | """additional extensions associated with this expression.""" 89 | 90 | @property 91 | def expected_type_urls(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: 92 | """A list of com.google.Any entities that this plan may use. Can be used to 93 | warn if some embedded message types are unknown. Note that this list may 94 | include message types that are ignorable (optimizations) or that are 95 | unused. In many cases, a consumer may be able to work with a plan even if 96 | one or more message types defined here are unknown. 
97 | """ 98 | 99 | def __init__(self, *, version: proto.plan_pb2.Version | None=..., extension_uris: collections.abc.Iterable[proto.extensions.extensions_pb2.SimpleExtensionURI] | None=..., extensions: collections.abc.Iterable[proto.extensions.extensions_pb2.SimpleExtensionDeclaration] | None=..., referred_expr: collections.abc.Iterable[global___ExpressionReference] | None=..., base_schema: proto.type_pb2.NamedStruct | None=..., advanced_extensions: proto.extensions.extensions_pb2.AdvancedExtension | None=..., expected_type_urls: collections.abc.Iterable[builtins.str] | None=...) -> None: 100 | ... 101 | 102 | def HasField(self, field_name: typing_extensions.Literal['advanced_extensions', b'advanced_extensions', 'base_schema', b'base_schema', 'version', b'version']) -> builtins.bool: 103 | ... 104 | 105 | def ClearField(self, field_name: typing_extensions.Literal['advanced_extensions', b'advanced_extensions', 'base_schema', b'base_schema', 'expected_type_urls', b'expected_type_urls', 'extension_uris', b'extension_uris', 'extensions', b'extensions', 'referred_expr', b'referred_expr', 'version', b'version']) -> None: 106 | ... 107 | global___ExtendedExpression = ExtendedExpression -------------------------------------------------------------------------------- /src/substrait/gen/proto/extensions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/substrait-io/substrait-python/0be6fc9371f4d4a2201e1dfb12825daadc03bc31/src/substrait/gen/proto/extensions/__init__.py -------------------------------------------------------------------------------- /src/substrait/gen/proto/extensions/__init__.pyi: -------------------------------------------------------------------------------- 1 | from . 
import extensions_pb2 2 | -------------------------------------------------------------------------------- /src/substrait/gen/proto/extensions/extensions_pb2.py: -------------------------------------------------------------------------------- 1 | """Generated protocol buffer code.""" 2 | from google.protobuf import descriptor as _descriptor 3 | from google.protobuf import descriptor_pool as _descriptor_pool 4 | from google.protobuf import symbol_database as _symbol_database 5 | from google.protobuf.internal import builder as _builder 6 | _sym_db = _symbol_database.Default() 7 | from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 8 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n!proto/extensions/extensions.proto\x12\x10proto.extensions\x1a\x19google/protobuf/any.proto"X\n\x12SimpleExtensionURI\x120\n\x14extension_uri_anchor\x18\x01 \x01(\rR\x12extensionUriAnchor\x12\x10\n\x03uri\x18\x02 \x01(\tR\x03uri"\xa7\x06\n\x1aSimpleExtensionDeclaration\x12c\n\x0eextension_type\x18\x01 \x01(\x0b2:.proto.extensions.SimpleExtensionDeclaration.ExtensionTypeH\x00R\rextensionType\x12\x7f\n\x18extension_type_variation\x18\x02 \x01(\x0b2C.proto.extensions.SimpleExtensionDeclaration.ExtensionTypeVariationH\x00R\x16extensionTypeVariation\x12o\n\x12extension_function\x18\x03 \x01(\x0b2>.proto.extensions.SimpleExtensionDeclaration.ExtensionFunctionH\x00R\x11extensionFunction\x1a|\n\rExtensionType\x126\n\x17extension_uri_reference\x18\x01 \x01(\rR\x15extensionUriReference\x12\x1f\n\x0btype_anchor\x18\x02 \x01(\rR\ntypeAnchor\x12\x12\n\x04name\x18\x03 \x01(\tR\x04name\x1a\x98\x01\n\x16ExtensionTypeVariation\x126\n\x17extension_uri_reference\x18\x01 \x01(\rR\x15extensionUriReference\x122\n\x15type_variation_anchor\x18\x02 \x01(\rR\x13typeVariationAnchor\x12\x12\n\x04name\x18\x03 \x01(\tR\x04name\x1a\x88\x01\n\x11ExtensionFunction\x126\n\x17extension_uri_reference\x18\x01 \x01(\rR\x15extensionUriReference\x12\'\n\x0ffunction_anchor\x18\x02 
\x01(\rR\x0efunctionAnchor\x12\x12\n\x04name\x18\x03 \x01(\tR\x04nameB\x0e\n\x0cmapping_type"\x85\x01\n\x11AdvancedExtension\x128\n\x0coptimization\x18\x01 \x03(\x0b2\x14.google.protobuf.AnyR\x0coptimization\x126\n\x0benhancement\x18\x02 \x01(\x0b2\x14.google.protobuf.AnyR\x0benhancementB#\n\x0eio.proto.protoP\x01\xaa\x02\x0eProto.Protobufb\x06proto3') 9 | _globals = globals() 10 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) 11 | _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'proto.extensions.extensions_pb2', _globals) 12 | if _descriptor._USE_C_DESCRIPTORS == False: 13 | _globals['DESCRIPTOR']._options = None 14 | _globals['DESCRIPTOR']._serialized_options = b'\n\x0eio.proto.protoP\x01\xaa\x02\x0eProto.Protobuf' 15 | _globals['_SIMPLEEXTENSIONURI']._serialized_start = 82 16 | _globals['_SIMPLEEXTENSIONURI']._serialized_end = 170 17 | _globals['_SIMPLEEXTENSIONDECLARATION']._serialized_start = 173 18 | _globals['_SIMPLEEXTENSIONDECLARATION']._serialized_end = 980 19 | _globals['_SIMPLEEXTENSIONDECLARATION_EXTENSIONTYPE']._serialized_start = 546 20 | _globals['_SIMPLEEXTENSIONDECLARATION_EXTENSIONTYPE']._serialized_end = 670 21 | _globals['_SIMPLEEXTENSIONDECLARATION_EXTENSIONTYPEVARIATION']._serialized_start = 673 22 | _globals['_SIMPLEEXTENSIONDECLARATION_EXTENSIONTYPEVARIATION']._serialized_end = 825 23 | _globals['_SIMPLEEXTENSIONDECLARATION_EXTENSIONFUNCTION']._serialized_start = 828 24 | _globals['_SIMPLEEXTENSIONDECLARATION_EXTENSIONFUNCTION']._serialized_end = 964 25 | _globals['_ADVANCEDEXTENSION']._serialized_start = 983 26 | _globals['_ADVANCEDEXTENSION']._serialized_end = 1116 -------------------------------------------------------------------------------- /src/substrait/gen/proto/extensions/extensions_pb2.pyi: -------------------------------------------------------------------------------- 1 | """ 2 | @generated by mypy-protobuf. Do not edit manually! 
3 | isort:skip_file 4 | SPDX-License-Identifier: Apache-2.0""" 5 | import builtins 6 | import collections.abc 7 | import google.protobuf.any_pb2 8 | import google.protobuf.descriptor 9 | import google.protobuf.internal.containers 10 | import google.protobuf.message 11 | import sys 12 | if sys.version_info >= (3, 8): 13 | import typing as typing_extensions 14 | else: 15 | import typing_extensions 16 | DESCRIPTOR: google.protobuf.descriptor.FileDescriptor 17 | 18 | @typing_extensions.final 19 | class SimpleExtensionURI(google.protobuf.message.Message): 20 | DESCRIPTOR: google.protobuf.descriptor.Descriptor 21 | EXTENSION_URI_ANCHOR_FIELD_NUMBER: builtins.int 22 | URI_FIELD_NUMBER: builtins.int 23 | extension_uri_anchor: builtins.int 24 | 'A surrogate key used in the context of a single plan used to reference the\n URI associated with an extension.\n ' 25 | uri: builtins.str 26 | 'The URI where this extension YAML can be retrieved. This is the "namespace"\n of this extension.\n ' 27 | 28 | def __init__(self, *, extension_uri_anchor: builtins.int=..., uri: builtins.str=...) -> None: 29 | ... 30 | 31 | def ClearField(self, field_name: typing_extensions.Literal['extension_uri_anchor', b'extension_uri_anchor', 'uri', b'uri']) -> None: 32 | ... 33 | global___SimpleExtensionURI = SimpleExtensionURI 34 | 35 | @typing_extensions.final 36 | class SimpleExtensionDeclaration(google.protobuf.message.Message): 37 | """Describes a mapping between a specific extension entity and the uri where 38 | that extension can be found. 
39 | """ 40 | DESCRIPTOR: google.protobuf.descriptor.Descriptor 41 | 42 | @typing_extensions.final 43 | class ExtensionType(google.protobuf.message.Message): 44 | """Describes a Type""" 45 | DESCRIPTOR: google.protobuf.descriptor.Descriptor 46 | EXTENSION_URI_REFERENCE_FIELD_NUMBER: builtins.int 47 | TYPE_ANCHOR_FIELD_NUMBER: builtins.int 48 | NAME_FIELD_NUMBER: builtins.int 49 | extension_uri_reference: builtins.int 50 | 'references the extension_uri_anchor defined for a specific extension URI.' 51 | type_anchor: builtins.int 52 | 'A surrogate key used in the context of a single plan to reference a\n specific extension type\n ' 53 | name: builtins.str 54 | 'the name of the type in the defined extension YAML.' 55 | 56 | def __init__(self, *, extension_uri_reference: builtins.int=..., type_anchor: builtins.int=..., name: builtins.str=...) -> None: 57 | ... 58 | 59 | def ClearField(self, field_name: typing_extensions.Literal['extension_uri_reference', b'extension_uri_reference', 'name', b'name', 'type_anchor', b'type_anchor']) -> None: 60 | ... 61 | 62 | @typing_extensions.final 63 | class ExtensionTypeVariation(google.protobuf.message.Message): 64 | DESCRIPTOR: google.protobuf.descriptor.Descriptor 65 | EXTENSION_URI_REFERENCE_FIELD_NUMBER: builtins.int 66 | TYPE_VARIATION_ANCHOR_FIELD_NUMBER: builtins.int 67 | NAME_FIELD_NUMBER: builtins.int 68 | extension_uri_reference: builtins.int 69 | 'references the extension_uri_anchor defined for a specific extension URI.' 70 | type_variation_anchor: builtins.int 71 | 'A surrogate key used in the context of a single plan to reference a\n specific type variation\n ' 72 | name: builtins.str 73 | 'the name of the type in the defined extension YAML.' 74 | 75 | def __init__(self, *, extension_uri_reference: builtins.int=..., type_variation_anchor: builtins.int=..., name: builtins.str=...) -> None: 76 | ... 
77 | 78 | def ClearField(self, field_name: typing_extensions.Literal['extension_uri_reference', b'extension_uri_reference', 'name', b'name', 'type_variation_anchor', b'type_variation_anchor']) -> None: 79 | ... 80 | 81 | @typing_extensions.final 82 | class ExtensionFunction(google.protobuf.message.Message): 83 | DESCRIPTOR: google.protobuf.descriptor.Descriptor 84 | EXTENSION_URI_REFERENCE_FIELD_NUMBER: builtins.int 85 | FUNCTION_ANCHOR_FIELD_NUMBER: builtins.int 86 | NAME_FIELD_NUMBER: builtins.int 87 | extension_uri_reference: builtins.int 88 | 'references the extension_uri_anchor defined for a specific extension URI.' 89 | function_anchor: builtins.int 90 | 'A surrogate key used in the context of a single plan to reference a\n specific function\n ' 91 | name: builtins.str 92 | 'A function signature compound name' 93 | 94 | def __init__(self, *, extension_uri_reference: builtins.int=..., function_anchor: builtins.int=..., name: builtins.str=...) -> None: 95 | ... 96 | 97 | def ClearField(self, field_name: typing_extensions.Literal['extension_uri_reference', b'extension_uri_reference', 'function_anchor', b'function_anchor', 'name', b'name']) -> None: 98 | ... 99 | EXTENSION_TYPE_FIELD_NUMBER: builtins.int 100 | EXTENSION_TYPE_VARIATION_FIELD_NUMBER: builtins.int 101 | EXTENSION_FUNCTION_FIELD_NUMBER: builtins.int 102 | 103 | @property 104 | def extension_type(self) -> global___SimpleExtensionDeclaration.ExtensionType: 105 | ... 106 | 107 | @property 108 | def extension_type_variation(self) -> global___SimpleExtensionDeclaration.ExtensionTypeVariation: 109 | ... 110 | 111 | @property 112 | def extension_function(self) -> global___SimpleExtensionDeclaration.ExtensionFunction: 113 | ... 
114 | 115 | def __init__(self, *, extension_type: global___SimpleExtensionDeclaration.ExtensionType | None=..., extension_type_variation: global___SimpleExtensionDeclaration.ExtensionTypeVariation | None=..., extension_function: global___SimpleExtensionDeclaration.ExtensionFunction | None=...) -> None: 116 | ... 117 | 118 | def HasField(self, field_name: typing_extensions.Literal['extension_function', b'extension_function', 'extension_type', b'extension_type', 'extension_type_variation', b'extension_type_variation', 'mapping_type', b'mapping_type']) -> builtins.bool: 119 | ... 120 | 121 | def ClearField(self, field_name: typing_extensions.Literal['extension_function', b'extension_function', 'extension_type', b'extension_type', 'extension_type_variation', b'extension_type_variation', 'mapping_type', b'mapping_type']) -> None: 122 | ... 123 | 124 | def WhichOneof(self, oneof_group: typing_extensions.Literal['mapping_type', b'mapping_type']) -> typing_extensions.Literal['extension_type', 'extension_type_variation', 'extension_function'] | None: 125 | ... 126 | global___SimpleExtensionDeclaration = SimpleExtensionDeclaration 127 | 128 | @typing_extensions.final 129 | class AdvancedExtension(google.protobuf.message.Message): 130 | """A generic object that can be used to embed additional extension information 131 | into the serialized substrait plan. 132 | """ 133 | DESCRIPTOR: google.protobuf.descriptor.Descriptor 134 | OPTIMIZATION_FIELD_NUMBER: builtins.int 135 | ENHANCEMENT_FIELD_NUMBER: builtins.int 136 | 137 | @property 138 | def optimization(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[google.protobuf.any_pb2.Any]: 139 | """An optimization is helpful information that don't influence semantics. May 140 | be ignored by a consumer. 141 | """ 142 | 143 | @property 144 | def enhancement(self) -> google.protobuf.any_pb2.Any: 145 | """An enhancement alter semantics. 
Cannot be ignored by a consumer.""" 146 | 147 | def __init__(self, *, optimization: collections.abc.Iterable[google.protobuf.any_pb2.Any] | None=..., enhancement: google.protobuf.any_pb2.Any | None=...) -> None: 148 | ... 149 | 150 | def HasField(self, field_name: typing_extensions.Literal['enhancement', b'enhancement']) -> builtins.bool: 151 | ... 152 | 153 | def ClearField(self, field_name: typing_extensions.Literal['enhancement', b'enhancement', 'optimization', b'optimization']) -> None: 154 | ... 155 | global___AdvancedExtension = AdvancedExtension -------------------------------------------------------------------------------- /src/substrait/gen/proto/function_pb2.py: -------------------------------------------------------------------------------- 1 | """Generated protocol buffer code.""" 2 | from google.protobuf import descriptor as _descriptor 3 | from google.protobuf import descriptor_pool as _descriptor_pool 4 | from google.protobuf import symbol_database as _symbol_database 5 | from google.protobuf.internal import builder as _builder 6 | _sym_db = _symbol_database.Default() 7 | from ..proto import parameterized_types_pb2 as proto_dot_parameterized__types__pb2 8 | from ..proto import type_pb2 as proto_dot_type__pb2 9 | from ..proto import type_expressions_pb2 as proto_dot_type__expressions__pb2 10 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14proto/function.proto\x12\x05proto\x1a\x1fproto/parameterized_types.proto\x1a\x10proto/type.proto\x1a\x1cproto/type_expressions.proto"\xe9\x18\n\x11FunctionSignature\x1a\xb8\x02\n\x10FinalArgVariadic\x12\x19\n\x08min_args\x18\x01 \x01(\x03R\x07minArgs\x12\x19\n\x08max_args\x18\x02 \x01(\x03R\x07maxArgs\x12`\n\x0bconsistency\x18\x03 \x01(\x0e2>.proto.FunctionSignature.FinalArgVariadic.ParameterConsistencyR\x0bconsistency"\x8b\x01\n\x14ParameterConsistency\x12%\n!PARAMETER_CONSISTENCY_UNSPECIFIED\x10\x00\x12$\n 
PARAMETER_CONSISTENCY_CONSISTENT\x10\x01\x12&\n"PARAMETER_CONSISTENCY_INCONSISTENT\x10\x02\x1a\x10\n\x0eFinalArgNormal\x1a\xb0\x04\n\x06Scalar\x12?\n\targuments\x18\x02 \x03(\x0b2!.proto.FunctionSignature.ArgumentR\targuments\x12\x12\n\x04name\x18\x03 \x03(\tR\x04name\x12F\n\x0bdescription\x18\x04 \x01(\x0b2$.proto.FunctionSignature.DescriptionR\x0bdescription\x12$\n\rdeterministic\x18\x07 \x01(\x08R\rdeterministic\x12+\n\x11session_dependent\x18\x08 \x01(\x08R\x10sessionDependent\x12<\n\x0boutput_type\x18\t \x01(\x0b2\x1b.proto.DerivationExpressionR\noutputType\x12G\n\x08variadic\x18\n \x01(\x0b2).proto.FunctionSignature.FinalArgVariadicH\x00R\x08variadic\x12A\n\x06normal\x18\x0b \x01(\x0b2\'.proto.FunctionSignature.FinalArgNormalH\x00R\x06normal\x12Q\n\x0fimplementations\x18\x0c \x03(\x0b2\'.proto.FunctionSignature.ImplementationR\x0fimplementationsB\x19\n\x17final_variable_behavior\x1a\xa0\x05\n\tAggregate\x12?\n\targuments\x18\x02 \x03(\x0b2!.proto.FunctionSignature.ArgumentR\targuments\x12\x12\n\x04name\x18\x03 \x01(\tR\x04name\x12F\n\x0bdescription\x18\x04 \x01(\x0b2$.proto.FunctionSignature.DescriptionR\x0bdescription\x12$\n\rdeterministic\x18\x07 \x01(\x08R\rdeterministic\x12+\n\x11session_dependent\x18\x08 \x01(\x08R\x10sessionDependent\x12<\n\x0boutput_type\x18\t \x01(\x0b2\x1b.proto.DerivationExpressionR\noutputType\x12G\n\x08variadic\x18\n \x01(\x0b2).proto.FunctionSignature.FinalArgVariadicH\x00R\x08variadic\x12A\n\x06normal\x18\x0b \x01(\x0b2\'.proto.FunctionSignature.FinalArgNormalH\x00R\x06normal\x12\x18\n\x07ordered\x18\x0e \x01(\x08R\x07ordered\x12\x17\n\x07max_set\x18\x0c \x01(\x04R\x06maxSet\x128\n\x11intermediate_type\x18\r \x01(\x0b2\x0b.proto.TypeR\x10intermediateType\x12Q\n\x0fimplementations\x18\x0f \x03(\x0b2\'.proto.FunctionSignature.ImplementationR\x0fimplementationsB\x19\n\x17final_variable_behavior\x1a\xdb\x06\n\x06Window\x12?\n\targuments\x18\x02 \x03(\x0b2!.proto.FunctionSignature.ArgumentR\targuments\x12\x12\n\x04name\x18\x03 
\x03(\tR\x04name\x12F\n\x0bdescription\x18\x04 \x01(\x0b2$.proto.FunctionSignature.DescriptionR\x0bdescription\x12$\n\rdeterministic\x18\x07 \x01(\x08R\rdeterministic\x12+\n\x11session_dependent\x18\x08 \x01(\x08R\x10sessionDependent\x12H\n\x11intermediate_type\x18\t \x01(\x0b2\x1b.proto.DerivationExpressionR\x10intermediateType\x12<\n\x0boutput_type\x18\n \x01(\x0b2\x1b.proto.DerivationExpressionR\noutputType\x12G\n\x08variadic\x18\x10 \x01(\x0b2).proto.FunctionSignature.FinalArgVariadicH\x00R\x08variadic\x12A\n\x06normal\x18\x11 \x01(\x0b2\'.proto.FunctionSignature.FinalArgNormalH\x00R\x06normal\x12\x18\n\x07ordered\x18\x0b \x01(\x08R\x07ordered\x12\x17\n\x07max_set\x18\x0c \x01(\x04R\x06maxSet\x12K\n\x0bwindow_type\x18\x0e \x01(\x0e2*.proto.FunctionSignature.Window.WindowTypeR\nwindowType\x12Q\n\x0fimplementations\x18\x0f \x03(\x0b2\'.proto.FunctionSignature.ImplementationR\x0fimplementations"_\n\nWindowType\x12\x1b\n\x17WINDOW_TYPE_UNSPECIFIED\x10\x00\x12\x19\n\x15WINDOW_TYPE_STREAMING\x10\x01\x12\x19\n\x15WINDOW_TYPE_PARTITION\x10\x02B\x19\n\x17final_variable_behavior\x1a=\n\x0bDescription\x12\x1a\n\x08language\x18\x01 \x01(\tR\x08language\x12\x12\n\x04body\x18\x02 \x01(\tR\x04body\x1a\xad\x01\n\x0eImplementation\x12@\n\x04type\x18\x01 \x01(\x0e2,.proto.FunctionSignature.Implementation.TypeR\x04type\x12\x10\n\x03uri\x18\x02 \x01(\tR\x03uri"G\n\x04Type\x12\x14\n\x10TYPE_UNSPECIFIED\x10\x00\x12\x15\n\x11TYPE_WEB_ASSEMBLY\x10\x01\x12\x12\n\x0eTYPE_TRINO_JAR\x10\x02\x1a\xe3\x03\n\x08Argument\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12G\n\x05value\x18\x02 \x01(\x0b2/.proto.FunctionSignature.Argument.ValueArgumentH\x00R\x05value\x12D\n\x04type\x18\x03 \x01(\x0b2..proto.FunctionSignature.Argument.TypeArgumentH\x00R\x04type\x12D\n\x04enum\x18\x04 \x01(\x0b2..proto.FunctionSignature.Argument.EnumArgumentH\x00R\x04enum\x1aY\n\rValueArgument\x12,\n\x04type\x18\x01 \x01(\x0b2\x18.proto.ParameterizedTypeR\x04type\x12\x1a\n\x08constant\x18\x02 
\x01(\x08R\x08constant\x1a<\n\x0cTypeArgument\x12,\n\x04type\x18\x01 \x01(\x0b2\x18.proto.ParameterizedTypeR\x04type\x1aD\n\x0cEnumArgument\x12\x18\n\x07options\x18\x01 \x03(\tR\x07options\x12\x1a\n\x08optional\x18\x02 \x01(\x08R\x08optionalB\x0f\n\rargument_kindB#\n\x0eio.proto.protoP\x01\xaa\x02\x0eProto.Protobufb\x06proto3') 11 | _globals = globals() 12 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) 13 | _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'proto.function_pb2', _globals) 14 | if _descriptor._USE_C_DESCRIPTORS == False: 15 | _globals['DESCRIPTOR']._options = None 16 | _globals['DESCRIPTOR']._serialized_options = b'\n\x0eio.proto.protoP\x01\xaa\x02\x0eProto.Protobuf' 17 | _globals['_FUNCTIONSIGNATURE']._serialized_start = 113 18 | _globals['_FUNCTIONSIGNATURE']._serialized_end = 3290 19 | _globals['_FUNCTIONSIGNATURE_FINALARGVARIADIC']._serialized_start = 135 20 | _globals['_FUNCTIONSIGNATURE_FINALARGVARIADIC']._serialized_end = 447 21 | _globals['_FUNCTIONSIGNATURE_FINALARGVARIADIC_PARAMETERCONSISTENCY']._serialized_start = 308 22 | _globals['_FUNCTIONSIGNATURE_FINALARGVARIADIC_PARAMETERCONSISTENCY']._serialized_end = 447 23 | _globals['_FUNCTIONSIGNATURE_FINALARGNORMAL']._serialized_start = 449 24 | _globals['_FUNCTIONSIGNATURE_FINALARGNORMAL']._serialized_end = 465 25 | _globals['_FUNCTIONSIGNATURE_SCALAR']._serialized_start = 468 26 | _globals['_FUNCTIONSIGNATURE_SCALAR']._serialized_end = 1028 27 | _globals['_FUNCTIONSIGNATURE_AGGREGATE']._serialized_start = 1031 28 | _globals['_FUNCTIONSIGNATURE_AGGREGATE']._serialized_end = 1703 29 | _globals['_FUNCTIONSIGNATURE_WINDOW']._serialized_start = 1706 30 | _globals['_FUNCTIONSIGNATURE_WINDOW']._serialized_end = 2565 31 | _globals['_FUNCTIONSIGNATURE_WINDOW_WINDOWTYPE']._serialized_start = 2443 32 | _globals['_FUNCTIONSIGNATURE_WINDOW_WINDOWTYPE']._serialized_end = 2538 33 | _globals['_FUNCTIONSIGNATURE_DESCRIPTION']._serialized_start = 2567 34 | 
_globals['_FUNCTIONSIGNATURE_DESCRIPTION']._serialized_end = 2628 35 | _globals['_FUNCTIONSIGNATURE_IMPLEMENTATION']._serialized_start = 2631 36 | _globals['_FUNCTIONSIGNATURE_IMPLEMENTATION']._serialized_end = 2804 37 | _globals['_FUNCTIONSIGNATURE_IMPLEMENTATION_TYPE']._serialized_start = 2733 38 | _globals['_FUNCTIONSIGNATURE_IMPLEMENTATION_TYPE']._serialized_end = 2804 39 | _globals['_FUNCTIONSIGNATURE_ARGUMENT']._serialized_start = 2807 40 | _globals['_FUNCTIONSIGNATURE_ARGUMENT']._serialized_end = 3290 41 | _globals['_FUNCTIONSIGNATURE_ARGUMENT_VALUEARGUMENT']._serialized_start = 3052 42 | _globals['_FUNCTIONSIGNATURE_ARGUMENT_VALUEARGUMENT']._serialized_end = 3141 43 | _globals['_FUNCTIONSIGNATURE_ARGUMENT_TYPEARGUMENT']._serialized_start = 3143 44 | _globals['_FUNCTIONSIGNATURE_ARGUMENT_TYPEARGUMENT']._serialized_end = 3203 45 | _globals['_FUNCTIONSIGNATURE_ARGUMENT_ENUMARGUMENT']._serialized_start = 3205 46 | _globals['_FUNCTIONSIGNATURE_ARGUMENT_ENUMARGUMENT']._serialized_end = 3273 -------------------------------------------------------------------------------- /src/substrait/gen/proto/plan_pb2.py: -------------------------------------------------------------------------------- 1 | """Generated protocol buffer code.""" 2 | from google.protobuf import descriptor as _descriptor 3 | from google.protobuf import descriptor_pool as _descriptor_pool 4 | from google.protobuf import symbol_database as _symbol_database 5 | from google.protobuf.internal import builder as _builder 6 | _sym_db = _symbol_database.Default() 7 | from ..proto import algebra_pb2 as proto_dot_algebra__pb2 8 | from ..proto.extensions import extensions_pb2 as proto_dot_extensions_dot_extensions__pb2 9 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10proto/plan.proto\x12\x05proto\x1a\x13proto/algebra.proto\x1a!proto/extensions/extensions.proto"[\n\x07PlanRel\x12\x1e\n\x03rel\x18\x01 \x01(\x0b2\n.proto.RelH\x00R\x03rel\x12$\n\x04root\x18\x02 
\x01(\x0b2\x0e.proto.RelRootH\x00R\x04rootB\n\n\x08rel_type"\xcc\x03\n\x04Plan\x12(\n\x07version\x18\x06 \x01(\x0b2\x0e.proto.VersionR\x07version\x12K\n\x0eextension_uris\x18\x01 \x03(\x0b2$.proto.extensions.SimpleExtensionURIR\rextensionUris\x12L\n\nextensions\x18\x02 \x03(\x0b2,.proto.extensions.SimpleExtensionDeclarationR\nextensions\x12,\n\trelations\x18\x03 \x03(\x0b2\x0e.proto.PlanRelR\trelations\x12T\n\x13advanced_extensions\x18\x04 \x01(\x0b2#.proto.extensions.AdvancedExtensionR\x12advancedExtensions\x12,\n\x12expected_type_urls\x18\x05 \x03(\tR\x10expectedTypeUrls\x12M\n\x12parameter_bindings\x18\x07 \x03(\x0b2\x1e.proto.DynamicParameterBindingR\x11parameterBindings"7\n\x0bPlanVersion\x12(\n\x07version\x18\x06 \x01(\x0b2\x0e.proto.VersionR\x07version"\xa9\x01\n\x07Version\x12!\n\x0cmajor_number\x18\x01 \x01(\rR\x0bmajorNumber\x12!\n\x0cminor_number\x18\x02 \x01(\rR\x0bminorNumber\x12!\n\x0cpatch_number\x18\x03 \x01(\rR\x0bpatchNumber\x12\x19\n\x08git_hash\x18\x04 \x01(\tR\x07gitHash\x12\x1a\n\x08producer\x18\x05 \x01(\tR\x08producer"u\n\x17DynamicParameterBinding\x12)\n\x10parameter_anchor\x18\x01 \x01(\rR\x0fparameterAnchor\x12/\n\x05value\x18\x02 \x01(\x0b2\x19.proto.Expression.LiteralR\x05valueB#\n\x0eio.proto.protoP\x01\xaa\x02\x0eProto.Protobufb\x06proto3') 10 | _globals = globals() 11 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) 12 | _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'proto.plan_pb2', _globals) 13 | if _descriptor._USE_C_DESCRIPTORS == False: 14 | _globals['DESCRIPTOR']._options = None 15 | _globals['DESCRIPTOR']._serialized_options = b'\n\x0eio.proto.protoP\x01\xaa\x02\x0eProto.Protobuf' 16 | _globals['_PLANREL']._serialized_start = 83 17 | _globals['_PLANREL']._serialized_end = 174 18 | _globals['_PLAN']._serialized_start = 177 19 | _globals['_PLAN']._serialized_end = 637 20 | _globals['_PLANVERSION']._serialized_start = 639 21 | _globals['_PLANVERSION']._serialized_end = 694 22 | 
def load_json(filename):
    """Load a Substrait Plan from a JSON file.

    The file is read as UTF-8 text and its content is passed to
    ``parse_json``.
    """
    with open(filename, encoding="utf-8") as f:
        return parse_json(f.read())


def parse_json(text):
    """Build a Substrait Plan from its JSON text definition."""
    return json_format.Parse(text=text, message=Plan())


def write_json(plan, filename):
    """Write a Substrait Plan to a JSON file.

    The file is created (or truncated) and written as UTF-8, the same
    encoding ``load_json`` reads, so a written file can be loaded back
    regardless of the platform's default text encoding.
    """
    # Plain "w" is sufficient: the handle is only ever written to.
    with open(filename, "w", encoding="utf-8") as f:
        f.write(dump_json(plan))


def dump_json(plan):
    """Serialize a Substrait Plan to a JSON-formatted string."""
    return json_format.MessageToJson(plan)
def build_arg(d: dict) -> Union[se.ValueArg, se.TypeArg, se.EnumerationArg]:
    """Build the argument model matching the keys present in ``d``.

    The kind is chosen by the first of ``"value"``, ``"type"`` or
    ``"options"`` found in ``d`` (checked in that order). ``"name"`` and
    ``"description"`` are optional for every kind; ``"constant"`` is only
    read for value arguments. When none of the three keys is present the
    function returns ``None``.
    """
    # Optional fields shared by all three argument kinds.
    shared = {"name": d.get("name"), "description": d.get("description")}

    if "value" in d:
        return se.ValueArg(value=d["value"], constant=d.get("constant"), **shared)
    if "type" in d:
        return se.TypeArg(type=d["type"], **shared)
    if "options" in d:
        return se.EnumerationArg(options=d["options"], **shared)
    return None
def build_options(d: dict) -> se.Options:
    """Convert a raw options mapping into ``se.Options1`` entries.

    Each value of ``d`` must provide ``"values"``; ``"description"`` is
    optional. The result is keyed by the same option names as ``d``.
    """
    options = {}
    for option_name, raw in d.items():
        options[option_name] = se.Options1(
            values=raw["values"], description=raw.get("description")
        )
    return options


def build_scalar_function(d: dict) -> se.ScalarFunction:
    """Build a ``se.ScalarFunction`` from its dictionary form.

    ``d`` must contain ``"name"`` and ``"impls"``; every implementation
    must contain ``"return"``. All other implementation keys are optional
    and become ``None`` when absent.
    """
    name = d["name"]
    impls = []
    for raw in d["impls"]:
        # Keys are read in the order of the model's keyword arguments.
        fields = {"return_": raw["return"]}
        fields["args"] = (
            [build_arg(arg) for arg in raw["args"]] if "args" in raw else None
        )
        fields["options"] = (
            build_options(raw["options"]) if "options" in raw else None
        )
        fields["variadic"] = (
            build_variadic_behavior(raw["variadic"]) if "variadic" in raw else None
        )
        fields["sessionDependent"] = raw.get("sessionDependent")
        fields["deterministic"] = raw.get("deterministic")
        fields["nullability"] = (
            se.NullabilityHandling(raw["nullability"])
            if "nullability" in raw
            else None
        )
        fields["implementation"] = raw.get("implementation")
        impls.append(se.Impl(**fields))

    return se.ScalarFunction(
        name=name, impls=impls, description=d.get("description")
    )


def build_aggregate_function(d: dict) -> se.AggregateFunction:
    """Build a ``se.AggregateFunction`` from its dictionary form.

    Accepts the same implementation keys as scalar functions plus the
    optional ``"intermediate"``, ``"ordered"``, ``"maxset"`` and
    ``"decomposable"`` entries.
    """
    name = d["name"]
    impls = []
    for raw in d["impls"]:
        # Keys are read in the order of the model's keyword arguments.
        fields = {"return_": raw["return"]}
        fields["args"] = (
            [build_arg(arg) for arg in raw["args"]] if "args" in raw else None
        )
        fields["options"] = (
            build_options(raw["options"]) if "options" in raw else None
        )
        fields["variadic"] = (
            build_variadic_behavior(raw["variadic"]) if "variadic" in raw else None
        )
        fields["sessionDependent"] = raw.get("sessionDependent")
        fields["deterministic"] = raw.get("deterministic")
        fields["nullability"] = (
            se.NullabilityHandling(raw["nullability"])
            if "nullability" in raw
            else None
        )
        fields["implementation"] = raw.get("implementation")
        fields["intermediate"] = raw.get("intermediate")
        fields["ordered"] = raw.get("ordered")
        fields["maxset"] = raw.get("maxset")
        fields["decomposable"] = (
            se.Decomposable(raw["decomposable"]) if "decomposable" in raw else None
        )
        impls.append(se.Impl1(**fields))

    return se.AggregateFunction(
        name=name, impls=impls, description=d.get("description")
    )
def build_window_function(d: dict) -> se.WindowFunction:
    """Build a ``se.WindowFunction`` from its dictionary form.

    ``d`` must contain ``"name"`` and ``"impls"``; every implementation
    must contain ``"return"``. The remaining implementation keys,
    including ``"window_type"``, are optional and become ``None`` when
    absent.
    """
    name = d["name"]
    impls = []
    for raw in d["impls"]:
        # Keys are read in the order of the model's keyword arguments.
        fields = {"return_": raw["return"]}
        fields["args"] = (
            [build_arg(arg) for arg in raw["args"]] if "args" in raw else None
        )
        fields["options"] = (
            build_options(raw["options"]) if "options" in raw else None
        )
        fields["variadic"] = (
            build_variadic_behavior(raw["variadic"]) if "variadic" in raw else None
        )
        fields["sessionDependent"] = raw.get("sessionDependent")
        fields["deterministic"] = raw.get("deterministic")
        fields["nullability"] = (
            se.NullabilityHandling(raw["nullability"])
            if "nullability" in raw
            else None
        )
        fields["implementation"] = raw.get("implementation")
        fields["intermediate"] = raw.get("intermediate")
        fields["ordered"] = raw.get("ordered")
        fields["maxset"] = raw.get("maxset")
        fields["decomposable"] = (
            se.Decomposable(raw["decomposable"]) if "decomposable" in raw else None
        )
        fields["window_type"] = (
            se.WindowType(raw["window_type"]) if "window_type" in raw else None
        )
        impls.append(se.Impl2(**fields))

    return se.WindowFunction(
        name=name, impls=impls, description=d.get("description")
    )


def build_type_model(d: dict) -> se.TypeModel:
    """Build a ``se.TypeModel``; only ``"name"`` is required in ``d``."""
    name = d["name"]
    structure = d.get("structure")
    parameters = d.get("parameters")
    variadic = d.get("variadic")
    return se.TypeModel(
        name=name, structure=structure, parameters=parameters, variadic=variadic
    )


def build_type_variation(d: dict) -> se.TypeVariation:
    """Build a ``se.TypeVariation``; ``"parent"`` and ``"name"`` are required.

    ``"functions"`` is wrapped in ``se.Functions`` when present and left
    as ``None`` otherwise.
    """
    parent = d["parent"]
    name = d["name"]
    functions = se.Functions(d["functions"]) if "functions" in d else None
    return se.TypeVariation(
        parent=parent,
        name=name,
        description=d.get("description"),
        functions=functions,
    )
def type_num_names(typ: "stp.Type"):
    """Return the number of names associated with ``typ``.

    A struct counts one name for itself plus the names of all its field
    types; a list counts the names of its element type; a map counts the
    names of its key type plus those of its value type; every other kind
    counts as a single name.
    """
    kind = typ.WhichOneof("kind")
    if kind == "struct":
        lengths = [type_num_names(t) for t in typ.struct.types]
        return sum(lengths) + 1
    elif kind == "list":
        return type_num_names(typ.list.type)
    elif kind == "map":
        # NOTE(review): a map of two scalar types yields 2 here, while a
        # list of a scalar yields 1 -- confirm this is the intended count.
        return type_num_names(typ.map.key) + type_num_names(typ.map.value)
    else:
        return 1


def merge_extension_uris(*extension_uris: "Iterable[ste.SimpleExtensionURI]"):
    """Merges multiple sets of SimpleExtensionURI objects into a single set.
    The order of extensions is kept intact, while duplicates are discarded.
    Assumes that there are no collisions (different extensions having identical anchors).

    Two entries are duplicates when their ``uri`` strings are equal; the
    first occurrence is the one that is kept.
    """
    seen_uris = set()
    ret = []

    for uris in extension_uris:
        for uri in uris:
            if uri.uri not in seen_uris:
                seen_uris.add(uri.uri)
                ret.append(uri)

    return ret


def merge_extension_declarations(
    *extension_declarations: "Iterable[ste.SimpleExtensionDeclaration]",
):
    """Merges multiple sets of SimpleExtensionDeclaration objects into a single set.
    The order of extension declarations is kept intact, while duplicates are discarded.
    Assumes that there are no collisions (different extension declarations having identical anchors).

    Function, type and type-variation declarations are all supported. Two
    declarations are duplicates when they have the same mapping kind, the
    same ``extension_uri_reference`` and the same ``name``; the first
    occurrence is the one that is kept.

    Raises:
        ValueError: if a declaration has no ``mapping_type`` set or uses a
            mapping kind this function does not know how to identify.
    """
    supported_kinds = (
        "extension_function",
        "extension_type",
        "extension_type_variation",
    )
    seen_declarations = set()
    ret = []

    for declarations in extension_declarations:
        for declaration in declarations:
            kind = declaration.WhichOneof("mapping_type")
            if kind not in supported_kinds:
                raise ValueError(
                    f"Cannot merge extension declaration with mapping_type {kind!r}"
                )

            mapping = getattr(declaration, kind)
            # The kind is part of the key so that, e.g., a type and a
            # function sharing a name are not treated as duplicates.
            ident = (kind, mapping.extension_uri_reference, mapping.name)
            if ident not in seen_declarations:
                seen_declarations.add(ident)
                ret.append(declaration)

    return ret
def test_aggregate_count():
    """``count`` over an i8 literal builds the expected aggregate measure."""
    required = stt.Type.NULLABILITY_REQUIRED

    ten = literal(10, type=stt.Type(i8=stt.Type.I8(nullability=required)))
    builder = aggregate_function(
        "test_uri", "count", expressions=[ten], alias="count"
    )
    e = builder(named_struct, registry)

    # Expected pieces, assembled bottom-up.
    count_declaration = ste.SimpleExtensionDeclaration(
        extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
            extension_uri_reference=1, function_anchor=1, name="count"
        )
    )
    literal_argument = stalg.FunctionArgument(
        value=stalg.Expression(
            literal=stalg.Expression.Literal(i8=10, nullable=False)
        )
    )
    measure = stalg.AggregateFunction(
        function_reference=1,
        arguments=[literal_argument],
        output_type=stt.Type(i64=stt.Type.I64(nullability=required)),
    )
    expected = stee.ExtendedExpression(
        extension_uris=[
            ste.SimpleExtensionURI(extension_uri_anchor=1, uri="test_uri")
        ],
        extensions=[count_declaration],
        referred_expr=[
            stee.ExpressionReference(measure=measure, output_names=["count"])
        ],
        base_schema=named_struct,
    )

    assert e == expected
def test_cast():
    """Casting an i8 literal to i16 builds the expected cast expression."""
    builder = cast(input=literal(3, i8()), type=i16())
    e = builder(named_struct, registry)

    target_type = stt.Type(
        i16=stt.Type.I16(nullability=stt.Type.NULLABILITY_NULLABLE)
    )
    source = stalg.Expression(
        literal=stalg.Expression.Literal(i8=3, nullable=True)
    )
    cast_expression = stalg.Expression(
        cast=stalg.Expression.Cast(
            type=target_type,
            input=source,
            failure_behavior=stalg.Expression.Cast.FAILURE_BEHAVIOR_RETURN_NULL,
        )
    )
    expected = stee.ExtendedExpression(
        referred_expr=[
            stee.ExpressionReference(
                expression=cast_expression, output_names=["cast"]
            )
        ],
        base_schema=named_struct,
    )

    assert e == expected
def test_column_no_nesting():
    """A column of a flat schema resolves to its positional struct field."""
    actual = column("description")(named_struct, None)

    selection = stalg.Expression.FieldReference(
        root_reference=stalg.Expression.FieldReference.RootReference(),
        direct_reference=stalg.Expression.ReferenceSegment(
            struct_field=stalg.Expression.ReferenceSegment.StructField(field=1)
        ),
    )
    expected = stee.ExtendedExpression(
        referred_expr=[
            stee.ExpressionReference(
                expression=stalg.Expression(selection=selection),
                output_names=["description"],
            )
        ],
        base_schema=named_struct,
    )

    assert actual == expected


def test_column_nesting():
    """A top-level column after a nested struct uses the top-level index."""
    actual = column("order_total")(nested_named_struct, None)

    selection = stalg.Expression.FieldReference(
        root_reference=stalg.Expression.FieldReference.RootReference(),
        direct_reference=stalg.Expression.ReferenceSegment(
            struct_field=stalg.Expression.ReferenceSegment.StructField(field=2)
        ),
    )
    expected = stee.ExtendedExpression(
        referred_expr=[
            stee.ExpressionReference(
                expression=stalg.Expression(selection=selection),
                output_names=["order_total"],
            )
        ],
        base_schema=nested_named_struct,
    )

    assert actual == expected
90 | expression=stalg.Expression( 91 | selection=stalg.Expression.FieldReference( 92 | root_reference=stalg.Expression.FieldReference.RootReference(), 93 | direct_reference=stalg.Expression.ReferenceSegment( 94 | struct_field=stalg.Expression.ReferenceSegment.StructField( 95 | field=1 96 | ) 97 | ), 98 | ) 99 | ), 100 | output_names=["shop_details", "shop_id", "shop_total"], 101 | ) 102 | ], 103 | base_schema=nested_named_struct, 104 | ) 105 | -------------------------------------------------------------------------------- /tests/builders/extended_expression/test_if_then.py: -------------------------------------------------------------------------------- 1 | import substrait.gen.proto.algebra_pb2 as stalg 2 | import substrait.gen.proto.type_pb2 as stt 3 | import substrait.gen.proto.extended_expression_pb2 as stee 4 | from substrait.builders.extended_expression import if_then, literal 5 | 6 | 7 | struct = stt.Type.Struct( 8 | types=[ 9 | stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)), 10 | stt.Type(string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE)), 11 | stt.Type(fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE)), 12 | ] 13 | ) 14 | 15 | named_struct = stt.NamedStruct( 16 | names=["order_id", "description", "order_total"], struct=struct 17 | ) 18 | 19 | 20 | def test_if_else(): 21 | actual = if_then( 22 | ifs=[ 23 | ( 24 | literal( 25 | True, 26 | type=stt.Type( 27 | bool=stt.Type.Boolean(nullability=stt.Type.NULLABILITY_REQUIRED) 28 | ), 29 | ), 30 | literal( 31 | 10, 32 | type=stt.Type( 33 | i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED) 34 | ), 35 | ), 36 | ) 37 | ], 38 | _else=literal( 39 | 20, type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED)) 40 | ), 41 | )(named_struct, None) 42 | 43 | expected = stee.ExtendedExpression( 44 | referred_expr=[ 45 | stee.ExpressionReference( 46 | expression=stalg.Expression( 47 | if_then=stalg.Expression.IfThen( 48 | **{ 49 | "ifs": [ 50 | 
def extract_literal(builder):
    """Run ``builder`` without schema or registry and return its literal."""
    extended_expression = builder(None, None)
    return extended_expression.referred_expr[0].expression.literal


def test_boolean():
    """Both boolean values map onto the ``boolean`` literal field."""
    for flag in (True, False):
        actual = extract_literal(literal(flag, sttb.boolean()))
        assert actual == stalg.Expression.Literal(boolean=flag, nullable=True)


def test_integer():
    """An i16-typed integer maps onto the ``i16`` literal field."""
    actual = extract_literal(literal(100, sttb.i16()))
    expected = stalg.Expression.Literal(i16=100, nullable=True)
    assert actual == expected


def test_string():
    """A string maps onto the ``string`` literal field."""
    actual = extract_literal(literal("Hello", sttb.string()))
    expected = stalg.Expression.Literal(string="Hello", nullable=True)
    assert actual == expected


def test_binary():
    """Bytes map onto the ``binary`` literal field."""
    actual = extract_literal(literal(b"Hello", sttb.binary()))
    expected = stalg.Expression.Literal(binary=b"Hello", nullable=True)
    assert actual == expected


def test_date():
    """Dates accept a raw integer or a ``datetime.date``."""
    from_int = extract_literal(literal(1000, sttb.date()))
    assert from_int == stalg.Expression.Literal(date=1000, nullable=True)

    # 1970-01-11 is expected to be encoded as the integer 10.
    from_date = extract_literal(literal(date(1970, 1, 11), sttb.date()))
    assert from_date == stalg.Expression.Literal(date=10, nullable=True)


def test_fixed_char():
    """A fixed-length string maps onto the ``fixed_char`` literal field."""
    actual = extract_literal(literal("Hello", sttb.fixed_char(length=5)))
    expected = stalg.Expression.Literal(fixed_char="Hello", nullable=True)
    assert actual == expected


def test_var_char():
    """A varchar literal carries both its value and its declared length."""
    actual = extract_literal(literal("Hello", sttb.var_char(length=5)))
    var_char = stalg.Expression.Literal.VarChar(value="Hello", length=5)
    expected = stalg.Expression.Literal(var_char=var_char, nullable=True)
    assert actual == expected


def test_fixed_binary():
    """Fixed-length bytes map onto the ``fixed_binary`` literal field."""
    actual = extract_literal(literal(b"Hello", sttb.fixed_binary(length=5)))
    expected = stalg.Expression.Literal(fixed_binary=b"Hello", nullable=True)
    assert actual == expected
i8()), literal(2, i8())], 28 | [literal(3, i8()), literal(4, i8())], 29 | ], 30 | )(named_struct, registry) 31 | 32 | expected = stee.ExtendedExpression( 33 | referred_expr=[ 34 | stee.ExpressionReference( 35 | expression=stalg.Expression( 36 | multi_or_list=stalg.Expression.MultiOrList( 37 | value=[ 38 | stalg.Expression( 39 | literal=stalg.Expression.Literal(i8=1, nullable=True) 40 | ), 41 | stalg.Expression( 42 | literal=stalg.Expression.Literal(i8=2, nullable=True) 43 | ), 44 | ], 45 | options=[ 46 | stalg.Expression.MultiOrList.Record( 47 | fields=[ 48 | stalg.Expression( 49 | literal=stalg.Expression.Literal( 50 | i8=1, nullable=True 51 | ) 52 | ), 53 | stalg.Expression( 54 | literal=stalg.Expression.Literal( 55 | i8=2, nullable=True 56 | ) 57 | ), 58 | ] 59 | ), 60 | stalg.Expression.MultiOrList.Record( 61 | fields=[ 62 | stalg.Expression( 63 | literal=stalg.Expression.Literal( 64 | i8=3, nullable=True 65 | ) 66 | ), 67 | stalg.Expression( 68 | literal=stalg.Expression.Literal( 69 | i8=4, nullable=True 70 | ) 71 | ), 72 | ] 73 | ), 74 | ], 75 | ) 76 | ), 77 | output_names=["multi_or_list"], 78 | ) 79 | ], 80 | base_schema=named_struct, 81 | ) 82 | 83 | assert e == expected 84 | -------------------------------------------------------------------------------- /tests/builders/extended_expression/test_scalar_function.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | 3 | import substrait.gen.proto.algebra_pb2 as stalg 4 | import substrait.gen.proto.type_pb2 as stt 5 | import substrait.gen.proto.extended_expression_pb2 as stee 6 | import substrait.gen.proto.extensions.extensions_pb2 as ste 7 | from substrait.builders.extended_expression import scalar_function, literal 8 | from substrait.extension_registry import ExtensionRegistry 9 | 10 | struct = stt.Type.Struct( 11 | types=[ 12 | stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)), 13 | 
def test_sclar_add():
    """A variadic scalar call over two i8 literals builds the expected plan."""

    def required_i8():
        # A fresh message per use, so no proto object is shared.
        return stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED))

    def i8_argument(value):
        return stalg.FunctionArgument(
            value=stalg.Expression(
                literal=stalg.Expression.Literal(i8=value, nullable=False)
            )
        )

    builder = scalar_function(
        "test_uri",
        "test_func",
        expressions=[
            literal(10, type=required_i8()),
            literal(20, type=required_i8()),
        ],
    )
    e = builder(named_struct, registry)

    call = stalg.Expression.ScalarFunction(
        function_reference=1,
        arguments=[i8_argument(10), i8_argument(20)],
        output_type=required_i8(),
    )
    expected = stee.ExtendedExpression(
        extension_uris=[
            ste.SimpleExtensionURI(extension_uri_anchor=1, uri="test_uri")
        ],
        extensions=[
            ste.SimpleExtensionDeclaration(
                extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
                    extension_uri_reference=1, function_anchor=1, name="test_func"
                )
            )
        ],
        referred_expr=[
            stee.ExpressionReference(
                expression=stalg.Expression(scalar_function=call),
                output_names=["test_func(Literal(10),Literal(20))"],
            )
        ],
        base_schema=named_struct,
    )

    assert e == expected


def test_nested_scalar_calls():
    """A scalar call nested inside another registers both functions."""

    def required_i8():
        # A fresh message per use, so no proto object is shared.
        return stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED))

    def i8_argument(value):
        return stalg.FunctionArgument(
            value=stalg.Expression(
                literal=stalg.Expression.Literal(i8=value, nullable=False)
            )
        )

    def declaration(anchor, name):
        return ste.SimpleExtensionDeclaration(
            extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
                extension_uri_reference=1, function_anchor=anchor, name=name
            )
        )

    inner_builder = scalar_function(
        "test_uri",
        "test_func",
        expressions=[
            literal(10, type=required_i8()),
            literal(20, type=required_i8()),
        ],
    )
    outer_builder = scalar_function(
        "test_uri",
        "is_positive",
        expressions=[inner_builder],
        alias="positive",
    )
    e = outer_builder(named_struct, registry)

    inner_call = stalg.Expression.ScalarFunction(
        function_reference=1,
        arguments=[i8_argument(10), i8_argument(20)],
        output_type=required_i8(),
    )
    outer_call = stalg.Expression.ScalarFunction(
        function_reference=2,
        arguments=[
            stalg.FunctionArgument(
                value=stalg.Expression(scalar_function=inner_call)
            )
        ],
        output_type=stt.Type(
            bool=stt.Type.Boolean(nullability=stt.Type.NULLABILITY_REQUIRED)
        ),
    )
    expected = stee.ExtendedExpression(
        extension_uris=[
            ste.SimpleExtensionURI(extension_uri_anchor=1, uri="test_uri")
        ],
        # The outer function's declaration is listed before the inner one.
        extensions=[declaration(2, "is_positive"), declaration(1, "test_func")],
        referred_expr=[
            stee.ExpressionReference(
                expression=stalg.Expression(scalar_function=outer_call),
                output_names=["positive"],
            )
        ],
        base_schema=named_struct,
    )

    assert e == expected
| stalg.Expression( 38 | literal=stalg.Expression.Literal(i8=1, nullable=True) 39 | ), 40 | stalg.Expression( 41 | literal=stalg.Expression.Literal(i8=2, nullable=True) 42 | ), 43 | ], 44 | ) 45 | ), 46 | output_names=["singular_or_list"], 47 | ) 48 | ], 49 | base_schema=named_struct, 50 | ) 51 | 52 | assert e == expected 53 | -------------------------------------------------------------------------------- /tests/builders/extended_expression/test_switch.py: -------------------------------------------------------------------------------- 1 | import substrait.gen.proto.algebra_pb2 as stalg 2 | import substrait.gen.proto.type_pb2 as stt 3 | import substrait.gen.proto.extended_expression_pb2 as stee 4 | from substrait.builders.extended_expression import switch, literal 5 | from substrait.builders.type import i8 6 | from substrait.extension_registry import ExtensionRegistry 7 | 8 | struct = stt.Type.Struct( 9 | types=[ 10 | stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)), 11 | stt.Type(string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE)), 12 | stt.Type(fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE)), 13 | ] 14 | ) 15 | 16 | named_struct = stt.NamedStruct( 17 | names=["order_id", "description", "order_total"], struct=struct 18 | ) 19 | 20 | registry = ExtensionRegistry(load_default_extensions=False) 21 | 22 | 23 | def test_switch(): 24 | e = switch( 25 | match=literal(3, i8()), 26 | ifs=[ 27 | (literal(1, i8()), literal(1, i8())), 28 | (literal(2, i8()), literal(4, i8())), 29 | ], 30 | _else=literal(9, i8()), 31 | )(named_struct, registry) 32 | 33 | expected = stee.ExtendedExpression( 34 | referred_expr=[ 35 | stee.ExpressionReference( 36 | expression=stalg.Expression( 37 | switch_expression=stalg.Expression.SwitchExpression( 38 | match=stalg.Expression( 39 | literal=stalg.Expression.Literal(i8=3, nullable=True) 40 | ), 41 | ifs=[ 42 | stalg.Expression.SwitchExpression.IfValue( 43 | **{ 44 | "if": 
stalg.Expression.Literal(i8=1, nullable=True), 45 | "then": stalg.Expression( 46 | literal=stalg.Expression.Literal( 47 | i8=1, nullable=True 48 | ) 49 | ), 50 | } 51 | ), 52 | stalg.Expression.SwitchExpression.IfValue( 53 | **{ 54 | "if": stalg.Expression.Literal(i8=2, nullable=True), 55 | "then": stalg.Expression( 56 | literal=stalg.Expression.Literal( 57 | i8=4, nullable=True 58 | ) 59 | ), 60 | } 61 | ), 62 | ], 63 | **{ 64 | "else": stalg.Expression( 65 | literal=stalg.Expression.Literal(i8=9, nullable=True) 66 | ) 67 | }, 68 | ) 69 | ), 70 | output_names=["switch"], 71 | ) 72 | ], 73 | base_schema=named_struct, 74 | ) 75 | 76 | assert e == expected 77 | -------------------------------------------------------------------------------- /tests/builders/extended_expression/test_window_function.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | 3 | import substrait.gen.proto.algebra_pb2 as stalg 4 | import substrait.gen.proto.type_pb2 as stt 5 | import substrait.gen.proto.extended_expression_pb2 as stee 6 | import substrait.gen.proto.extensions.extensions_pb2 as ste 7 | from substrait.builders.extended_expression import window_function 8 | from substrait.extension_registry import ExtensionRegistry 9 | 10 | struct = stt.Type.Struct( 11 | types=[ 12 | stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)), 13 | stt.Type(string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE)), 14 | stt.Type(fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE)), 15 | ] 16 | ) 17 | 18 | named_struct = stt.NamedStruct( 19 | names=["order_id", "description", "order_total"], struct=struct 20 | ) 21 | 22 | content = """%YAML 1.2 23 | --- 24 | window_functions: 25 | - name: "row_number" 26 | description: "the number of the current row within its partition, starting at 1" 27 | impls: 28 | - args: [] 29 | nullability: DECLARED_OUTPUT 30 | decomposable: NONE 31 | return: i64? 
32 | window_type: PARTITION 33 | - name: "rank" 34 | description: "the rank of the current row, with gaps." 35 | impls: 36 | - args: [] 37 | nullability: DECLARED_OUTPUT 38 | decomposable: NONE 39 | return: i64? 40 | window_type: PARTITION 41 | """ 42 | 43 | 44 | registry = ExtensionRegistry(load_default_extensions=False) 45 | registry.register_extension_dict(yaml.safe_load(content), uri="test_uri") 46 | 47 | 48 | def test_row_number(): 49 | e = window_function("test_uri", "row_number", expressions=[], alias="rn")( 50 | named_struct, registry 51 | ) 52 | 53 | expected = stee.ExtendedExpression( 54 | extension_uris=[ste.SimpleExtensionURI(extension_uri_anchor=1, uri="test_uri")], 55 | extensions=[ 56 | ste.SimpleExtensionDeclaration( 57 | extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction( 58 | extension_uri_reference=1, function_anchor=1, name="row_number" 59 | ) 60 | ) 61 | ], 62 | referred_expr=[ 63 | stee.ExpressionReference( 64 | expression=stalg.Expression( 65 | window_function=stalg.Expression.WindowFunction( 66 | function_reference=1, 67 | output_type=stt.Type( 68 | i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_NULLABLE) 69 | ), 70 | ) 71 | ), 72 | output_names=["rn"], 73 | ) 74 | ], 75 | base_schema=named_struct, 76 | ) 77 | 78 | assert e == expected 79 | -------------------------------------------------------------------------------- /tests/builders/plan/test_aggregate.py: -------------------------------------------------------------------------------- 1 | import substrait.gen.proto.type_pb2 as stt 2 | import substrait.gen.proto.plan_pb2 as stp 3 | import substrait.gen.proto.algebra_pb2 as stalg 4 | import substrait.gen.proto.extensions.extensions_pb2 as ste 5 | from substrait.builders.type import boolean, i64 6 | from substrait.builders.plan import read_named_table, aggregate 7 | from substrait.builders.extended_expression import column, aggregate_function 8 | from substrait.extension_registry import ExtensionRegistry 9 | from 
substrait.type_inference import infer_plan_schema 10 | import yaml 11 | 12 | content = """%YAML 1.2 13 | --- 14 | aggregate_functions: 15 | - name: "count" 16 | description: Count a set of values 17 | impls: 18 | - args: 19 | - name: x 20 | value: any 21 | nullability: DECLARED_OUTPUT 22 | decomposable: MANY 23 | intermediate: i64 24 | return: i64 25 | """ 26 | 27 | 28 | registry = ExtensionRegistry(load_default_extensions=False) 29 | registry.register_extension_dict(yaml.safe_load(content), uri="test_uri") 30 | 31 | struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) 32 | 33 | named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) 34 | 35 | 36 | def test_aggregate(): 37 | table = read_named_table("table", named_struct) 38 | 39 | group_expr = column("id") 40 | measure_expr = aggregate_function( 41 | "test_uri", "count", expressions=[column("is_applicable")], alias=["count"] 42 | ) 43 | 44 | actual = aggregate( 45 | table, grouping_expressions=[group_expr], measures=[measure_expr] 46 | )(registry) 47 | 48 | ns = infer_plan_schema(table(None)) 49 | 50 | expected = stp.Plan( 51 | extension_uris=[ste.SimpleExtensionURI(extension_uri_anchor=1, uri="test_uri")], 52 | extensions=[ 53 | ste.SimpleExtensionDeclaration( 54 | extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction( 55 | extension_uri_reference=1, function_anchor=1, name="count" 56 | ) 57 | ) 58 | ], 59 | relations=[ 60 | stp.PlanRel( 61 | root=stalg.RelRoot( 62 | input=stalg.Rel( 63 | aggregate=stalg.AggregateRel( 64 | input=table(None).relations[-1].root.input, 65 | grouping_expressions=[ 66 | group_expr(ns, registry).referred_expr[0].expression 67 | ], 68 | groupings=[ 69 | stalg.AggregateRel.Grouping( 70 | grouping_expressions=[ 71 | group_expr(ns, registry) 72 | .referred_expr[0] 73 | .expression 74 | ], 75 | expression_references=[0], 76 | ) 77 | ], 78 | measures=[ 79 | stalg.AggregateRel.Measure( 80 | measure=measure_expr(ns, registry) 81 | 
.referred_expr[0] 82 | .measure 83 | ) 84 | ], 85 | ) 86 | ), 87 | names=["id", "count"], 88 | ) 89 | ) 90 | ], 91 | ) 92 | 93 | assert actual == expected 94 | -------------------------------------------------------------------------------- /tests/builders/plan/test_cross.py: -------------------------------------------------------------------------------- 1 | import substrait.gen.proto.type_pb2 as stt 2 | import substrait.gen.proto.plan_pb2 as stp 3 | import substrait.gen.proto.algebra_pb2 as stalg 4 | from substrait.builders.type import boolean, i64, string 5 | from substrait.builders.plan import read_named_table, cross 6 | from substrait.extension_registry import ExtensionRegistry 7 | 8 | registry = ExtensionRegistry(load_default_extensions=False) 9 | 10 | struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) 11 | 12 | named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) 13 | 14 | named_struct_2 = stt.NamedStruct( 15 | names=["fk_id", "name"], 16 | struct=stt.Type.Struct(types=[i64(nullable=False), string()]), 17 | ) 18 | 19 | 20 | def test_cross_join(): 21 | table = read_named_table("table", named_struct) 22 | table2 = read_named_table("table2", named_struct_2) 23 | 24 | actual = cross(table, table2)(registry) 25 | 26 | expected = stp.Plan( 27 | relations=[ 28 | stp.PlanRel( 29 | root=stalg.RelRoot( 30 | input=stalg.Rel( 31 | cross=stalg.CrossRel( 32 | left=table(None).relations[-1].root.input, 33 | right=table2(None).relations[-1].root.input, 34 | ) 35 | ), 36 | names=["id", "is_applicable", "fk_id", "name"], 37 | ) 38 | ) 39 | ] 40 | ) 41 | 42 | assert actual == expected 43 | -------------------------------------------------------------------------------- /tests/builders/plan/test_fetch.py: -------------------------------------------------------------------------------- 1 | import substrait.gen.proto.type_pb2 as stt 2 | import substrait.gen.proto.plan_pb2 as stp 3 | import substrait.gen.proto.algebra_pb2 as stalg 4 | from 
substrait.builders.type import boolean, i64 5 | from substrait.builders.plan import read_named_table, fetch 6 | from substrait.builders.extended_expression import literal 7 | from substrait.extension_registry import ExtensionRegistry 8 | 9 | registry = ExtensionRegistry(load_default_extensions=False) 10 | 11 | struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) 12 | 13 | named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) 14 | 15 | 16 | def test_fetch(): 17 | table = read_named_table("table", named_struct) 18 | 19 | offset = literal(10, i64()) 20 | count = literal(5, i64()) 21 | 22 | actual = fetch(table, offset=offset, count=count)(registry) 23 | 24 | expected = stp.Plan( 25 | relations=[ 26 | stp.PlanRel( 27 | root=stalg.RelRoot( 28 | input=stalg.Rel( 29 | fetch=stalg.FetchRel( 30 | input=table(None).relations[-1].root.input, 31 | offset_expr=offset(None, None).referred_expr[0].expression, 32 | count_expr=count(None, None).referred_expr[0].expression, 33 | ) 34 | ), 35 | names=["id", "is_applicable"], 36 | ) 37 | ) 38 | ] 39 | ) 40 | 41 | assert actual == expected 42 | -------------------------------------------------------------------------------- /tests/builders/plan/test_filter.py: -------------------------------------------------------------------------------- 1 | import substrait.gen.proto.type_pb2 as stt 2 | import substrait.gen.proto.plan_pb2 as stp 3 | import substrait.gen.proto.algebra_pb2 as stalg 4 | from substrait.builders.type import boolean, i64 5 | from substrait.builders.plan import read_named_table, filter 6 | from substrait.builders.extended_expression import literal 7 | from substrait.extension_registry import ExtensionRegistry 8 | 9 | registry = ExtensionRegistry(load_default_extensions=False) 10 | 11 | struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) 12 | 13 | named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) 14 | 15 | 16 | def test_filter(): 17 | table = 
read_named_table("table", named_struct) 18 | 19 | actual = filter(table, literal(True, boolean()))(registry) 20 | 21 | expected = stp.Plan( 22 | relations=[ 23 | stp.PlanRel( 24 | root=stalg.RelRoot( 25 | input=stalg.Rel( 26 | filter=stalg.FilterRel( 27 | input=table(None).relations[-1].root.input, 28 | condition=stalg.Expression( 29 | literal=stalg.Expression.Literal( 30 | boolean=True, nullable=True 31 | ) 32 | ), 33 | ) 34 | ), 35 | names=["id", "is_applicable"], 36 | ) 37 | ) 38 | ] 39 | ) 40 | 41 | assert actual == expected 42 | -------------------------------------------------------------------------------- /tests/builders/plan/test_join.py: -------------------------------------------------------------------------------- 1 | import substrait.gen.proto.type_pb2 as stt 2 | import substrait.gen.proto.plan_pb2 as stp 3 | import substrait.gen.proto.algebra_pb2 as stalg 4 | from substrait.builders.type import boolean, i64, string 5 | from substrait.builders.plan import read_named_table, join 6 | from substrait.builders.extended_expression import literal 7 | from substrait.extension_registry import ExtensionRegistry 8 | 9 | registry = ExtensionRegistry(load_default_extensions=False) 10 | 11 | struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) 12 | 13 | named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) 14 | 15 | named_struct_2 = stt.NamedStruct( 16 | names=["fk_id", "name"], 17 | struct=stt.Type.Struct(types=[i64(nullable=False), string()]), 18 | ) 19 | 20 | 21 | def test_join(): 22 | table = read_named_table("table", named_struct) 23 | table2 = read_named_table("table2", named_struct_2) 24 | 25 | actual = join( 26 | table, table2, literal(True, boolean()), stalg.JoinRel.JOIN_TYPE_INNER 27 | )(registry) 28 | 29 | expected = stp.Plan( 30 | relations=[ 31 | stp.PlanRel( 32 | root=stalg.RelRoot( 33 | input=stalg.Rel( 34 | join=stalg.JoinRel( 35 | left=table(None).relations[-1].root.input, 36 | 
right=table2(None).relations[-1].root.input, 37 | expression=literal(True, boolean())(None, None) 38 | .referred_expr[0] 39 | .expression, 40 | type=stalg.JoinRel.JOIN_TYPE_INNER, 41 | ) 42 | ), 43 | names=["id", "is_applicable", "fk_id", "name"], 44 | ) 45 | ) 46 | ] 47 | ) 48 | 49 | assert actual == expected 50 | -------------------------------------------------------------------------------- /tests/builders/plan/test_project.py: -------------------------------------------------------------------------------- 1 | import substrait.gen.proto.type_pb2 as stt 2 | import substrait.gen.proto.plan_pb2 as stp 3 | import substrait.gen.proto.algebra_pb2 as stalg 4 | from substrait.builders.type import boolean, i64 5 | from substrait.builders.plan import read_named_table, project 6 | from substrait.builders.extended_expression import column 7 | from substrait.extension_registry import ExtensionRegistry 8 | 9 | registry = ExtensionRegistry(load_default_extensions=False) 10 | 11 | struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) 12 | 13 | named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) 14 | 15 | 16 | def test_project(): 17 | table = read_named_table("table", named_struct) 18 | 19 | actual = project(table, [column("id")])(registry) 20 | 21 | expected = stp.Plan( 22 | relations=[ 23 | stp.PlanRel( 24 | root=stalg.RelRoot( 25 | input=stalg.Rel( 26 | project=stalg.ProjectRel( 27 | common=stalg.RelCommon( 28 | emit=stalg.RelCommon.Emit(output_mapping=[2]) 29 | ), 30 | input=table(None).relations[-1].root.input, 31 | expressions=[ 32 | stalg.Expression( 33 | selection=stalg.Expression.FieldReference( 34 | direct_reference=stalg.Expression.ReferenceSegment( 35 | struct_field=stalg.Expression.ReferenceSegment.StructField( 36 | field=0 37 | ) 38 | ), 39 | root_reference=stalg.Expression.FieldReference.RootReference(), 40 | ) 41 | ) 42 | ], 43 | ) 44 | ), 45 | names=["id"], 46 | ) 47 | ) 48 | ] 49 | ) 50 | 51 | assert actual == expected 52 | 
-------------------------------------------------------------------------------- /tests/builders/plan/test_read.py: -------------------------------------------------------------------------------- 1 | import substrait.gen.proto.type_pb2 as stt 2 | import substrait.gen.proto.plan_pb2 as stp 3 | import substrait.gen.proto.algebra_pb2 as stalg 4 | from substrait.builders.type import boolean, i64 5 | from substrait.builders.plan import read_named_table 6 | 7 | struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) 8 | 9 | named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) 10 | 11 | 12 | def test_read_rel(): 13 | actual = read_named_table("example_table", named_struct)(None) 14 | 15 | expected = stp.Plan( 16 | relations=[ 17 | stp.PlanRel( 18 | root=stalg.RelRoot( 19 | input=stalg.Rel( 20 | read=stalg.ReadRel( 21 | common=stalg.RelCommon(direct=stalg.RelCommon.Direct()), 22 | base_schema=named_struct, 23 | named_table=stalg.ReadRel.NamedTable( 24 | names=["example_table"] 25 | ), 26 | ) 27 | ), 28 | names=["id", "is_applicable"], 29 | ) 30 | ) 31 | ] 32 | ) 33 | 34 | assert actual == expected 35 | 36 | 37 | def test_read_rel_db(): 38 | actual = read_named_table(["example_db", "example_table"], named_struct)(None) 39 | 40 | expected = stp.Plan( 41 | relations=[ 42 | stp.PlanRel( 43 | root=stalg.RelRoot( 44 | input=stalg.Rel( 45 | read=stalg.ReadRel( 46 | common=stalg.RelCommon(direct=stalg.RelCommon.Direct()), 47 | base_schema=named_struct, 48 | named_table=stalg.ReadRel.NamedTable( 49 | names=["example_db", "example_table"] 50 | ), 51 | ) 52 | ), 53 | names=["id", "is_applicable"], 54 | ) 55 | ) 56 | ] 57 | ) 58 | 59 | assert actual == expected 60 | -------------------------------------------------------------------------------- /tests/builders/plan/test_set.py: -------------------------------------------------------------------------------- 1 | import substrait.gen.proto.type_pb2 as stt 2 | import substrait.gen.proto.plan_pb2 as stp 
3 | import substrait.gen.proto.algebra_pb2 as stalg 4 | from substrait.builders.type import boolean, i64 5 | from substrait.builders.plan import read_named_table, set 6 | from substrait.extension_registry import ExtensionRegistry 7 | 8 | registry = ExtensionRegistry(load_default_extensions=False) 9 | 10 | struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) 11 | 12 | named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) 13 | 14 | 15 | def test_set(): 16 | table = read_named_table("table", named_struct) 17 | table2 = read_named_table("table2", named_struct) 18 | 19 | actual = set([table, table2], stalg.SetRel.SET_OP_UNION_ALL)(None) 20 | 21 | expected = stp.Plan( 22 | relations=[ 23 | stp.PlanRel( 24 | root=stalg.RelRoot( 25 | input=stalg.Rel( 26 | set=stalg.SetRel( 27 | inputs=[ 28 | table(None).relations[-1].root.input, 29 | table2(None).relations[-1].root.input, 30 | ], 31 | op=stalg.SetRel.SET_OP_UNION_ALL, 32 | ) 33 | ), 34 | names=["id", "is_applicable"], 35 | ) 36 | ) 37 | ] 38 | ) 39 | 40 | assert actual == expected 41 | -------------------------------------------------------------------------------- /tests/builders/plan/test_sort.py: -------------------------------------------------------------------------------- 1 | import substrait.gen.proto.type_pb2 as stt 2 | import substrait.gen.proto.plan_pb2 as stp 3 | import substrait.gen.proto.algebra_pb2 as stalg 4 | from substrait.builders.type import boolean, i64 5 | from substrait.builders.plan import read_named_table, sort 6 | from substrait.builders.extended_expression import column 7 | from substrait.type_inference import infer_plan_schema 8 | from substrait.extension_registry import ExtensionRegistry 9 | 10 | registry = ExtensionRegistry(load_default_extensions=False) 11 | 12 | struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) 13 | 14 | named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) 15 | 16 | 17 | def test_sort_no_direction(): 
18 | table = read_named_table("table", named_struct) 19 | 20 | col = column("id") 21 | 22 | actual = sort(table, expressions=[col])(registry) 23 | 24 | expected = stp.Plan( 25 | relations=[ 26 | stp.PlanRel( 27 | root=stalg.RelRoot( 28 | input=stalg.Rel( 29 | sort=stalg.SortRel( 30 | input=table(None).relations[-1].root.input, 31 | sorts=[ 32 | stalg.SortField( 33 | direction=stalg.SortField.SORT_DIRECTION_ASC_NULLS_LAST, 34 | expr=col(infer_plan_schema(table(None)), registry) 35 | .referred_expr[0] 36 | .expression, 37 | ) 38 | ], 39 | ) 40 | ), 41 | names=["id", "is_applicable"], 42 | ) 43 | ) 44 | ] 45 | ) 46 | 47 | assert actual == expected 48 | 49 | 50 | def test_sort_direction(): 51 | table = read_named_table("table", named_struct) 52 | 53 | col = column("id") 54 | 55 | actual = sort( 56 | table, expressions=[(col, stalg.SortField.SORT_DIRECTION_DESC_NULLS_FIRST)] 57 | )(registry) 58 | 59 | expected = stp.Plan( 60 | relations=[ 61 | stp.PlanRel( 62 | root=stalg.RelRoot( 63 | input=stalg.Rel( 64 | sort=stalg.SortRel( 65 | input=table(None).relations[-1].root.input, 66 | sorts=[ 67 | stalg.SortField( 68 | direction=stalg.SortField.SORT_DIRECTION_DESC_NULLS_FIRST, 69 | expr=col(infer_plan_schema(table(None)), registry) 70 | .referred_expr[0] 71 | .expression, 72 | ) 73 | ], 74 | ) 75 | ), 76 | names=["id", "is_applicable"], 77 | ) 78 | ) 79 | ] 80 | ) 81 | 82 | assert actual == expected 83 | -------------------------------------------------------------------------------- /tests/test_derivation_expression.py: -------------------------------------------------------------------------------- 1 | from substrait.gen.proto.type_pb2 import Type 2 | from substrait.derivation_expression import evaluate 3 | 4 | 5 | def test_simple_arithmetic(): 6 | assert evaluate("1 + 1") == 2 7 | 8 | 9 | def test_simple_arithmetic_with_variables(): 10 | assert evaluate("1 + var", {"var": 2}) == 3 11 | 12 | 13 | def test_simple_arithmetic_parenthesis(): 14 | assert evaluate("(1 + var) * 
3", {"var": 2}) == 9 15 | 16 | 17 | def test_min_max(): 18 | assert evaluate("min(var, 7) + max(var, 7)", {"var": 5}) == 12 19 | 20 | 21 | def test_ternary(): 22 | assert evaluate("var > 3 ? 1 : 0", {"var": 5}) == 1 23 | assert evaluate("var > 3 ? 1 : 0", {"var": 2}) == 0 24 | 25 | 26 | def test_multiline(): 27 | assert evaluate( 28 | """temp = min(var, 7) + max(var, 7) 29 | decimal""", 30 | {"var": 5}, 31 | ) == Type( 32 | decimal=Type.Decimal( 33 | precision=13, scale=11, nullability=Type.NULLABILITY_REQUIRED 34 | ) 35 | ) 36 | 37 | 38 | def test_simple_data_types(): 39 | assert evaluate("i8") == Type(i8=Type.I8(nullability=Type.NULLABILITY_REQUIRED)) 40 | assert evaluate("i16") == Type(i16=Type.I16(nullability=Type.NULLABILITY_REQUIRED)) 41 | assert evaluate("i32") == Type(i32=Type.I32(nullability=Type.NULLABILITY_REQUIRED)) 42 | assert evaluate("i64") == Type(i64=Type.I64(nullability=Type.NULLABILITY_REQUIRED)) 43 | assert evaluate("fp32") == Type( 44 | fp32=Type.FP32(nullability=Type.NULLABILITY_REQUIRED) 45 | ) 46 | assert evaluate("fp64") == Type( 47 | fp64=Type.FP64(nullability=Type.NULLABILITY_REQUIRED) 48 | ) 49 | assert evaluate("boolean") == Type( 50 | bool=Type.Boolean(nullability=Type.NULLABILITY_REQUIRED) 51 | ) 52 | assert evaluate("i8?") == Type(i8=Type.I8(nullability=Type.NULLABILITY_NULLABLE)) 53 | assert evaluate("i16?") == Type(i16=Type.I16(nullability=Type.NULLABILITY_NULLABLE)) 54 | assert evaluate("i32?") == Type(i32=Type.I32(nullability=Type.NULLABILITY_NULLABLE)) 55 | assert evaluate("i64?") == Type(i64=Type.I64(nullability=Type.NULLABILITY_NULLABLE)) 56 | assert evaluate("fp32?") == Type( 57 | fp32=Type.FP32(nullability=Type.NULLABILITY_NULLABLE) 58 | ) 59 | assert evaluate("fp64?") == Type( 60 | fp64=Type.FP64(nullability=Type.NULLABILITY_NULLABLE) 61 | ) 62 | assert evaluate("boolean?") == Type( 63 | bool=Type.Boolean(nullability=Type.NULLABILITY_NULLABLE) 64 | ) 65 | 66 | 67 | def test_data_type(): 68 | assert evaluate("decimal
<P + 1, S + 1>
", {"S": 10, "P": 20}) == Type( 69 | decimal=Type.Decimal( 70 | precision=21, scale=11, nullability=Type.NULLABILITY_REQUIRED 71 | ) 72 | ) 73 | 74 | 75 | def test_data_type_nullable(): 76 | assert evaluate("decimal?
<P + 1, S + 1>
", {"S": 10, "P": 20}) == Type( 77 | decimal=Type.Decimal( 78 | precision=21, scale=11, nullability=Type.NULLABILITY_NULLABLE 79 | ) 80 | ) 81 | 82 | 83 | def test_decimal_example(): 84 | def func(P1, S1, P2, S2): 85 | init_scale = max(S1, S2) 86 | init_prec = init_scale + max(P1 - S1, P2 - S2) + 1 87 | min_scale = min(init_scale, 6) 88 | delta = init_prec - 38 89 | prec = min(init_prec, 38) 90 | scale_after_borrow = max(init_scale - delta, min_scale) 91 | scale = scale_after_borrow if init_prec > 38 else init_scale 92 | return Type( 93 | decimal=Type.Decimal( 94 | precision=prec, scale=scale, nullability=Type.NULLABILITY_REQUIRED 95 | ) 96 | ) 97 | 98 | args = {"P1": 10, "S1": 8, "P2": 14, "S2": 2} 99 | 100 | func_eval = func(**args) 101 | 102 | assert ( 103 | evaluate( 104 | """init_scale = max(S1,S2) 105 | init_prec = init_scale + max(P1 - S1, P2 - S2) + 1 106 | min_scale = min(init_scale, 6) 107 | delta = init_prec - 38 108 | prec = min(init_prec, 38) 109 | scale_after_borrow = max(init_scale - delta, min_scale) 110 | scale = init_prec > 38 ? scale_after_borrow : init_scale 111 | DECIMAL""", 112 | args, 113 | ) 114 | == func_eval 115 | ) 116 | -------------------------------------------------------------------------------- /tests/test_extension_registry.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | 3 | from substrait.gen.proto.type_pb2 import Type 4 | from substrait.extension_registry import ExtensionRegistry, covers 5 | from substrait.derivation_expression import _parse 6 | 7 | content = """%YAML 1.2 8 | --- 9 | scalar_functions: 10 | - name: "test_fn" 11 | description: "" 12 | impls: 13 | - args: 14 | - value: i8 15 | variadic: 16 | min: 2 17 | return: i8 18 | - name: "test_fn_variadic_any" 19 | description: "" 20 | impls: 21 | - args: 22 | - value: any1 23 | variadic: 24 | min: 2 25 | return: any1 26 | - name: "add" 27 | description: "Add two values." 
28 | impls: 29 | - args: 30 | - name: x 31 | value: i8 32 | - name: y 33 | value: i8 34 | options: 35 | overflow: 36 | values: [ SILENT, SATURATE, ERROR ] 37 | return: i8 38 | - args: 39 | - name: x 40 | value: i8 41 | - name: y 42 | value: i8 43 | - name: z 44 | value: any 45 | options: 46 | overflow: 47 | values: [ SILENT, SATURATE, ERROR ] 48 | return: i16 49 | - args: 50 | - name: x 51 | value: any1 52 | - name: y 53 | value: any1 54 | - name: z 55 | value: any2 56 | options: 57 | overflow: 58 | values: [ SILENT, SATURATE, ERROR ] 59 | return: any2 60 | - name: "test_decimal" 61 | impls: 62 | - args: 63 | - name: x 64 | value: decimal 65 | - name: y 66 | value: decimal 67 | return: decimal 68 | - name: "test_enum" 69 | impls: 70 | - args: 71 | - name: op 72 | options: [ INTACT, FLIP ] 73 | - name: x 74 | value: i8 75 | return: i8 76 | - name: "add_declared" 77 | description: "Add two values." 78 | impls: 79 | - args: 80 | - name: x 81 | value: i8 82 | - name: y 83 | value: i8 84 | nullability: DECLARED_OUTPUT 85 | return: i8? 86 | - name: "add_discrete" 87 | description: "Add two values." 88 | impls: 89 | - args: 90 | - name: x 91 | value: i8? 92 | - name: y 93 | value: i8 94 | nullability: DISCRETE 95 | return: i8? 96 | - name: "test_decimal_discrete" 97 | impls: 98 | - args: 99 | - name: x 100 | value: decimal? 101 | - name: y 102 | value: decimal 103 | nullability: DISCRETE 104 | return: decimal? 
105 | """ 106 | 107 | 108 | registry = ExtensionRegistry() 109 | 110 | registry.register_extension_dict(yaml.safe_load(content), uri="test") 111 | 112 | 113 | def i8(nullable=False): 114 | return Type( 115 | i8=Type.I8( 116 | nullability=Type.NULLABILITY_REQUIRED 117 | if not nullable 118 | else Type.NULLABILITY_NULLABLE 119 | ) 120 | ) 121 | 122 | 123 | def i16(nullable=False): 124 | return Type( 125 | i16=Type.I16( 126 | nullability=Type.NULLABILITY_REQUIRED 127 | if not nullable 128 | else Type.NULLABILITY_NULLABLE 129 | ) 130 | ) 131 | 132 | 133 | def bool(nullable=False): 134 | return Type( 135 | bool=Type.Boolean( 136 | nullability=Type.NULLABILITY_REQUIRED 137 | if not nullable 138 | else Type.NULLABILITY_NULLABLE 139 | ) 140 | ) 141 | 142 | 143 | def decimal(precision, scale, nullable=False): 144 | return Type( 145 | decimal=Type.Decimal( 146 | scale=scale, 147 | precision=precision, 148 | nullability=Type.NULLABILITY_REQUIRED 149 | if not nullable 150 | else Type.NULLABILITY_NULLABLE, 151 | ) 152 | ) 153 | 154 | 155 | def test_non_existing_uri(): 156 | assert ( 157 | registry.lookup_function( 158 | uri="non_existent", function_name="add", signature=[i8(), i8()] 159 | ) 160 | is None 161 | ) 162 | 163 | 164 | def test_non_existing_function(): 165 | assert ( 166 | registry.lookup_function( 167 | uri="test", function_name="sub", signature=[i8(), i8()] 168 | ) 169 | is None 170 | ) 171 | 172 | 173 | def test_non_existing_function_signature(): 174 | assert ( 175 | registry.lookup_function(uri="test", function_name="add", signature=[i8()]) 176 | is None 177 | ) 178 | 179 | 180 | def test_exact_match(): 181 | assert registry.lookup_function( 182 | uri="test", function_name="add", signature=[i8(), i8()] 183 | )[1] == Type(i8=Type.I8(nullability=Type.NULLABILITY_REQUIRED)) 184 | 185 | 186 | def test_wildcard_match(): 187 | assert registry.lookup_function( 188 | uri="test", function_name="add", signature=[i8(), i8(), bool()] 189 | )[1] == 
Type(i16=Type.I16(nullability=Type.NULLABILITY_REQUIRED)) 190 | 191 | 192 | def test_wildcard_match_fails_with_constraits(): 193 | assert ( 194 | registry.lookup_function( 195 | uri="test", function_name="add", signature=[i8(), i16(), i16()] 196 | ) 197 | is None 198 | ) 199 | 200 | 201 | def test_wildcard_match_with_constraits(): 202 | assert ( 203 | registry.lookup_function( 204 | uri="test", function_name="add", signature=[i16(), i16(), i8()] 205 | )[1] 206 | == i8() 207 | ) 208 | 209 | 210 | def test_variadic(): 211 | assert ( 212 | registry.lookup_function( 213 | uri="test", function_name="test_fn", signature=[i8(), i8(), i8()] 214 | )[1] 215 | == i8() 216 | ) 217 | 218 | 219 | def test_variadic_any(): 220 | assert ( 221 | registry.lookup_function( 222 | uri="test", 223 | function_name="test_fn_variadic_any", 224 | signature=[i16(), i16(), i16()], 225 | )[1] 226 | == i16() 227 | ) 228 | 229 | 230 | def test_variadic_fails_min_constraint(): 231 | assert ( 232 | registry.lookup_function(uri="test", function_name="test_fn", signature=[i8()]) 233 | is None 234 | ) 235 | 236 | 237 | def test_decimal_happy_path(): 238 | assert registry.lookup_function( 239 | uri="test", 240 | function_name="test_decimal", 241 | signature=[decimal(10, 8), decimal(8, 6)], 242 | )[1] == decimal(11, 7) 243 | 244 | 245 | def test_decimal_violates_constraint(): 246 | assert ( 247 | registry.lookup_function( 248 | uri="test", 249 | function_name="test_decimal", 250 | signature=[decimal(10, 8), decimal(12, 10)], 251 | ) 252 | is None 253 | ) 254 | 255 | 256 | def test_decimal_happy_path_discrete(): 257 | assert registry.lookup_function( 258 | uri="test", 259 | function_name="test_decimal_discrete", 260 | signature=[decimal(10, 8, nullable=True), decimal(8, 6)], 261 | )[1] == decimal(11, 7, nullable=True) 262 | 263 | 264 | def test_enum_with_valid_option(): 265 | assert ( 266 | registry.lookup_function( 267 | uri="test", 268 | function_name="test_enum", 269 | signature=["FLIP", i8()], 270 | 
def test_covers_decimal_happy_path():
    """Check that ``covers`` accepts a matching decimal and records ``A``.

    ``decimal(10, 8)`` is checked against the parsed type ``decimal<10, A>``;
    the call must be truthy and the parameter dict passed in must end up
    as ``{"A": 8}``.
    """
    bound_params = {}
    is_covered = covers(decimal(10, 8), _parse("decimal<10, A>"), bound_params)
    assert is_covered
    # The dict handed to ``covers`` is expected to be filled in by the call.
    assert bound_params == {"A": 8}
# Directory with the JSON plan fixtures, addressed relative to this test
# module: ../third_party/substrait-cpp/src/substrait/textplan/data
JSON_FIXTURES = pathlib.Path(os.path.dirname(__file__)).joinpath(
    "..",
    "third_party",
    "substrait-cpp",
    "src",
    "substrait",
    "textplan",
    "data",
)
# Every ``*.json`` fixture in that directory, sorted so the order is stable.
JSON_TEST_FILE = sorted(JSON_FIXTURES.glob("*.json"))
# Bare file names of the fixtures, used as the parametrized test ids.
JSON_TEST_FILENAMES = [fixture.name for fixture in JSON_TEST_FILE]
def _strip_json_comments(jsonfile):
    """Return the text of *jsonfile* without its ``#`` comment lines.

    The JSON files in the cpp testsuite are prefixed with a comment
    containing the SQL that matches the json plan. As the Python JSON
    parser doesn't support comments, they have to be stripped to make
    the content readable.

    Args:
        jsonfile: An open text file (any iterable of lines works).

    Returns:
        The remaining lines concatenated into one string. A line is
        dropped only when ``#`` is its very first character.
    """
    # Lines read from a file keep their own trailing newline, so they are
    # joined with "" -- joining with "\n" would double every line break.
    return "".join(line for line in jsonfile if not line.startswith("#"))
def test_type_num_names_flat_struct():
    """``type_num_names`` is expected to return 4 for a three-field struct.

    The struct holds an ``i64``, a ``string`` and an ``fp32`` field.
    """
    # NOTE(review): 4 is presumably one name for the struct itself plus one
    # per field -- confirm against ``substrait.utils.type_num_names``.
    flat_struct = stt.Type(
        struct=stt.Type.Struct(
            types=[
                stt.Type(i64=stt.Type.I64()),
                stt.Type(string=stt.Type.String()),
                stt.Type(fp32=stt.Type.FP32()),
            ]
        )
    )
    assert type_num_names(flat_struct) == 4
#!/bin/bash
# Move the third_party/substrait submodule to the latest upstream release,
# record that release's version and short commit hash in
# src/substrait/__init__.py, then regenerate the bindings with gen_proto.sh.

# curl is needed for the GitHub API query below; stop here rather than
# carrying on with an empty version.
if ! command -v curl > /dev/null 2>&1; then
  echo "curl is required to grab the latest version tag" >&2
  exit 1
fi

# Tag name of the latest release as reported by the GitHub releases API.
VERSION=$(curl -s https://api.github.com/repos/substrait-io/substrait/releases/latest | grep 'tag_name' | cut -d '"' -f 4)

# A failed or unexpected API response leaves VERSION empty; checking out an
# empty ref and writing an empty version string would only hide the problem.
if [ -z "$VERSION" ]; then
  echo "could not determine the latest substrait release tag" >&2
  exit 1
fi

echo "Updating substrait submodule..."
git submodule update --remote third_party/substrait

# Resolve the directory containing this script so the submodule is found
# regardless of how the script path was spelled.
DIR=$(cd "$(dirname "$0")" && pwd)
pushd "${DIR}"/third_party/substrait/ || exit
git checkout "$VERSION"
SUBSTRAIT_HASH=$(git rev-parse --short HEAD)
popd || exit

# Strip only the leading "v" of the tag so that a "v" occurring anywhere
# else in the tag name is preserved.
VERSION=${VERSION#v}

sed -i "s#__substrait_hash__.*#__substrait_hash__ = \"$SUBSTRAIT_HASH\"#g" src/substrait/__init__.py
sed -i "s#__substrait_version__.*#__substrait_version__ = \"$VERSION\"#g" src/substrait/__init__.py

./gen_proto.sh