├── .github ├── pull_request_template.md └── workflows │ ├── build.yml │ └── release.yml ├── .gitignore ├── .pre-commit-config.yaml ├── Cargo.lock ├── Cargo.toml ├── LICENSE.txt ├── README.md ├── crates └── connect │ ├── Cargo.toml │ ├── build.rs │ ├── protobuf │ └── spark-3.5 │ │ ├── buf.yaml │ │ └── spark │ │ └── connect │ │ ├── base.proto │ │ ├── catalog.proto │ │ ├── commands.proto │ │ ├── common.proto │ │ ├── expressions.proto │ │ ├── relations.proto │ │ └── types.proto │ └── src │ ├── catalog.rs │ ├── client │ ├── builder.rs │ ├── config.rs │ ├── middleware.rs │ └── mod.rs │ ├── column.rs │ ├── conf.rs │ ├── dataframe.rs │ ├── errors.rs │ ├── expressions.rs │ ├── functions │ └── mod.rs │ ├── group.rs │ ├── lib.rs │ ├── plan.rs │ ├── readwriter.rs │ ├── session.rs │ ├── storage.rs │ ├── streaming │ └── mod.rs │ ├── types.rs │ └── window.rs ├── datasets ├── dir1 │ ├── dir2 │ │ └── file2.parquet │ ├── file1.parquet │ └── file3.json ├── employees.json ├── full_user.avsc ├── kv1.txt ├── people.csv ├── people.json ├── people.txt ├── user.avsc ├── users.avro ├── users.orc └── users.parquet ├── docker-compose.yml ├── examples ├── Cargo.toml ├── README.md └── src │ ├── databricks.rs │ ├── deltalake.rs │ ├── reader.rs │ ├── readstream.rs │ ├── sql.rs │ └── writer.rs └── pre-commit.sh /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | # Description 2 | The description of the main changes of your pull request 3 | 4 | # Related Issue(s) 5 | 10 | 11 | # Documentation 12 | 13 | 16 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: 4 | push: 5 | branches: [main, "v*"] 6 | pull_request: 7 | branches: [main, "v*"] 8 | 9 | jobs: 10 | format: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | with: 15 | submodules: "true" 16 | 17 | - name: install protoc 18 | uses: arduino/setup-protoc@v2 19 | with: 20 | version: 23.x 21 | 22 | - name: Install minimal stable with clippy and rustfmt 23 | uses: actions-rs/toolchain@v1 24 | with: 25 | profile: default 26 | toolchain: stable 27 | override: true 28 | 29 | - name: Format 30 | run: cargo fmt -- --check 31 | 32 | build: 33 | runs-on: ubuntu-latest 34 | 35 | steps: 36 | - uses: actions/checkout@v4 37 | with: 38 | submodules: "true" 39 | 40 | - name: install protoc 41 | uses: arduino/setup-protoc@v2 42 | with: 43 | version: 23.x 44 | 45 | - name: install minimal stable with clippy and rustfmt 46 | uses: actions-rs/toolchain@v1 47 | with: 48 | profile: default 49 | toolchain: stable 50 | override: true 51 | 52 | - uses: Swatinem/rust-cache@v2 53 | 54 | - name: build and lint with clippy 55 | run: cargo clippy 56 | 57 | - name: Check docs 58 | run: cargo doc 59 | 60 | - name: Check no default features (except rustls) 61 | run: cargo check 62 | 63 | integration_test: 64 | name: integration tests 65 | runs-on: ubuntu-latest 66 | env: 67 | CARGO_INCREMENTAL: 0 68 | # Disable full debug symbol generation to speed up CI build and keep memory down 69 | # 70 | RUSTFLAGS: "-C debuginfo=line-tables-only" 71 | # https://github.com/rust-lang/cargo/issues/10280 72 | CARGO_NET_GIT_FETCH_WITH_CLI: "true" 73 | RUST_BACKTRACE: "1" 74 | 75 | steps: 76 | - uses: actions/checkout@v4 77 | with: 78 | submodules: "true" 79 | 80 | - name: install protoc 81 | uses: arduino/setup-protoc@v2 82 | with: 83 | version: 23.x 84 | 85 | - name: install minimal stable with clippy and rustfmt 86 | uses: actions-rs/toolchain@v1 87 | with: 88 | profile: default 89 | toolchain: stable 90 | override: true 91 | 92 | - uses: Swatinem/rust-cache@v2 93 | 94 | - name: Start emulated services 95 | run: docker compose up -d 96 | 97 | - name: Run tests 98 | run: cargo test -p spark-connect-rs --features polars,datafusion 99 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: release to cargo 2 | 3 | on: 4 | push: 5 | tags: ["v*"] 6 | 7 | jobs: 8 | validate-release-tag: 9 | name: Validate git tag 10 | runs-on: ubuntu-20.04 11 | steps: 12 | - uses: actions/checkout@v4 13 | - name: compare git tag with cargo metadata 14 | run: | 15 | PUSHED_TAG=${GITHUB_REF##*/} 16 | CURR_VER=$( grep version Cargo.toml | head -n 1 | awk '{print $3}' | tr -d '"' ) 17 | if [[ "${PUSHED_TAG}" != "v${CURR_VER}" ]]; then 18 | echo "Cargo metadata has version set to ${CURR_VER}, but got pushed tag ${PUSHED_TAG}." 19 | exit 1 20 | fi 21 | working-directory: ./crates 22 | 23 | release-crate: 24 | needs: validate-release-tag 25 | name: Release crate 26 | runs-on: ubuntu-20.04 27 | steps: 28 | - uses: actions/checkout@v4 29 | 30 | - uses: actions-rs/toolchain@v1 31 | with: 32 | profile: minimal 33 | toolchain: stable 34 | override: true 35 | 36 | - name: install protoc 37 | uses: arduino/setup-protoc@v2 38 | with: 39 | version: 23.x 40 | 41 | - name: cargo publish rust 42 | uses: actions-rs/cargo@v1 43 | env: 44 | CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }} 45 | with: 46 | command: publish 47 | args: --token "${CARGO_REGISTRY_TOKEN}" --package spark-connect-rs --manifest-path ./Cargo.toml 48 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | 3 | .vscode 4 | *.ipynb 5 | 6 | /spark-warehouse 7 | /artifacts 8 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.4.0 4 | hooks: 5 | - id: check-byte-order-marker 6 | - id: check-case-conflict 7 | - id: check-merge-conflict 8 | - id: check-symlinks 9 | - id: check-yaml 10 | - id: end-of-file-fixer 11 | - id: mixed-line-ending 12 | - id: trailing-whitespace 13 | - repo: https://github.com/doublify/pre-commit-rust 14 | rev: v1.0 15 | hooks: 16 | - id: fmt 17 | - id: cargo-check 18 | - id: clippy 19 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | members = ["crates/*", "examples"] 3 | resolver = "2" 4 | 5 | [workspace.package] 6 | authors = ["Steve Russo <64294847+sjrusso8@users.noreply.github.com>"] 7 | keywords = ["spark", "spark_connect"] 8 | readme = "README.md" 9 | edition = "2021" 10 | homepage = "https://github.com/sjrusso8/spark-connect-rs" 11 | description = "Apache Spark Connect Client for Rust" 12 | license = "Apache-2.0" 13 | documentation = "https://docs.rs/spark-connect-rs" 14 | repository = "https://github.com/sjrusso8/spark-connect-rs" 15 | rust-version = "1.81" 16 | 17 | [workspace.dependencies] 18 | tonic = { version ="0.11", default-features = false } 19 | 20 | tokio = { version = "1.44", default-features = false, features = ["macros"] } 21 | tower = { version = "0.5" } 22 | 23 | futures-util = { version = "0.3" } 24 | thiserror = { version = "2.0" } 25 | 26 | http-body = { version = "0.4.6" } 27 | 28 | arrow = { version = "55", features = ["prettyprint"] } 29 | arrow-ipc = { version = "55" } 30 | 31 | serde_json = { version = "1" } 32 | 33 | prost = { version = "0.12" } 34 | prost-types = { version = "0.12" } 35 | 36 | rand = { version = "0.9" } 37 | uuid = { version = "1.16", features = ["v4"] } 38 | url = { version = "2.5" } 39 | regex = { version = "1" } 40 | 41 | chrono = { version = "0.4" } 42 | 43 | datafusion = { version = "47.0", default-features = false } 44 | polars = { version = "0.43", default-features = false } 45 | polars-arrow = { version = "0.43", default-features = false, features = ["arrow_rs"] } 46 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /crates/connect/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "spark-connect-rs" 3 | version = "0.0.2" 4 | authors.workspace = true 5 | keywords.workspace = true 6 | readme.workspace = true 7 | edition.workspace = true 8 | homepage.workspace = true 9 | description.workspace = true 10 | license.workspace = true 11 | documentation.workspace = true 12 | repository.workspace = true 13 | rust-version.workspace = true 14 | include = [ 15 | "build.rs", 16 | "src/**/*", 17 | "protobuf/**/*", 18 | ] 19 | 20 | [dependencies] 21 | tonic = { workspace = true, default-features = false, optional = true } 22 | 23 | tower = { workspace = true } 24 | tokio = { workspace = true, optional = true } 25 | 26 | futures-util = { workspace = true } 27 | thiserror = { workspace = true } 28 | 29 | http-body = { workspace = true } 30 | 31 | arrow = { workspace = true } 32 | arrow-ipc = { workspace = true } 33 | 34 | serde_json = { workspace = true } 35 | 36 | prost = { workspace = true } 37 | prost-types = { workspace = true } 38 | 39 | rand = { workspace = true } 40 | uuid = { workspace = true } 41 | url = { workspace = true } 42 | regex = { workspace = true } 43 | 44 | chrono = { workspace = true } 45 | 46 | datafusion = { workspace = true, optional = true } 47 | 48 | polars = { workspace = true, optional = true } 49 | polars-arrow = { workspace = true, optional = true } 50 | 51 | [dev-dependencies] 52 | futures = "0.3" 53 | tokio = { workspace = true, features = ["rt-multi-thread"] } 54 | 55 | [build-dependencies] 56 | tonic-build = "0.11" 57 | 58 | [lib] 59 | doctest = false 60 | 61 | [features] 62 | default = [ 63 | "tokio", 64 | "tonic/codegen", 65 | "tonic/prost", 66 | "tonic/transport", 67 | ] 68 | 69 | tls = [ 70 | "tonic/tls", 71 | "tonic/tls-roots" 72 | ] 73 | 74 | datafusion = [ 75 | "dep:datafusion" 76 | ] 77 | 78 | polars = [ 79 | "dep:polars", 80 | "dep:polars-arrow" 81 | ] 82 | -------------------------------------------------------------------------------- /crates/connect/build.rs: -------------------------------------------------------------------------------- 1 | use std::fs; 2 | 3 | fn main() -> Result<(), Box> { 4 | let files = fs::read_dir("./protobuf/spark-3.5/spark/connect/")?; 5 | 6 | let mut file_paths: Vec = vec![]; 7 | 8 | for file in files { 9 | let entry = file?.path(); 10 | file_paths.push(entry.to_str().unwrap().to_string()); 11 | } 12 | 13 | tonic_build::configure() 14 | .protoc_arg("--experimental_allow_proto3_optional") 15 | .build_server(false) 16 | .build_client(true) 17 | .build_transport(true) 18 | .compile(file_paths.as_ref(), &["./protobuf/spark-3.5/"])?; 19 | 20 | Ok(()) 21 | } 22 | -------------------------------------------------------------------------------- /crates/connect/protobuf/spark-3.5/buf.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | version: v1 18 | breaking: 19 | use: 20 | - FILE 21 | except: 22 | - FILE_SAME_GO_PACKAGE 23 | lint: 24 | use: 25 | - DEFAULT 26 | -------------------------------------------------------------------------------- /crates/connect/protobuf/spark-3.5/spark/connect/catalog.proto: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | syntax = 'proto3'; 19 | 20 | package spark.connect; 21 | 22 | import "spark/connect/common.proto"; 23 | import "spark/connect/types.proto"; 24 | 25 | option java_multiple_files = true; 26 | option java_package = "org.apache.spark.connect.proto"; 27 | option go_package = "internal/generated"; 28 | 29 | // Catalog messages are marked as unstable. 30 | message Catalog { 31 | oneof cat_type { 32 | CurrentDatabase current_database = 1; 33 | SetCurrentDatabase set_current_database = 2; 34 | ListDatabases list_databases = 3; 35 | ListTables list_tables = 4; 36 | ListFunctions list_functions = 5; 37 | ListColumns list_columns = 6; 38 | GetDatabase get_database = 7; 39 | GetTable get_table = 8; 40 | GetFunction get_function = 9; 41 | DatabaseExists database_exists = 10; 42 | TableExists table_exists = 11; 43 | FunctionExists function_exists = 12; 44 | CreateExternalTable create_external_table = 13; 45 | CreateTable create_table = 14; 46 | DropTempView drop_temp_view = 15; 47 | DropGlobalTempView drop_global_temp_view = 16; 48 | RecoverPartitions recover_partitions = 17; 49 | IsCached is_cached = 18; 50 | CacheTable cache_table = 19; 51 | UncacheTable uncache_table = 20; 52 | ClearCache clear_cache = 21; 53 | RefreshTable refresh_table = 22; 54 | RefreshByPath refresh_by_path = 23; 55 | CurrentCatalog current_catalog = 24; 56 | SetCurrentCatalog set_current_catalog = 25; 57 | ListCatalogs list_catalogs = 26; 58 | } 59 | } 60 | 61 | // See `spark.catalog.currentDatabase` 62 | message CurrentDatabase { } 63 | 64 | // See `spark.catalog.setCurrentDatabase` 65 | message SetCurrentDatabase { 66 | // (Required) 67 | string db_name = 1; 68 | } 69 | 70 | // See `spark.catalog.listDatabases` 71 | message ListDatabases { 72 | // (Optional) The pattern that the database name needs to match 73 | optional string pattern = 1; 74 | } 75 | 76 | // See `spark.catalog.listTables` 77 | message ListTables { 78 | // (Optional) 79 | optional string db_name = 1; 80 | // (Optional) The pattern that the table name needs to match 81 | optional string pattern = 2; 82 | } 83 | 84 | // See `spark.catalog.listFunctions` 85 | message ListFunctions { 86 | // (Optional) 87 | optional string db_name = 1; 88 | // (Optional) The pattern that the function name needs to match 89 | optional string pattern = 2; 90 | } 91 | 92 | // See `spark.catalog.listColumns` 93 | message ListColumns { 94 | // (Required) 95 | string table_name = 1; 96 | // (Optional) 97 | optional string db_name = 2; 98 | } 99 | 100 | // See `spark.catalog.getDatabase` 101 | message GetDatabase { 102 | // (Required) 103 | string db_name = 1; 104 | } 105 | 106 | // See `spark.catalog.getTable` 107 | message GetTable { 108 | // (Required) 109 | string table_name = 1; 110 | // (Optional) 111 | optional string db_name = 2; 112 | } 113 | 114 | // See `spark.catalog.getFunction` 115 | message GetFunction { 116 | // (Required) 117 | string function_name = 1; 118 | // (Optional) 119 | optional string db_name = 2; 120 | } 121 | 122 | // See `spark.catalog.databaseExists` 123 | message DatabaseExists { 124 | // (Required) 125 | string db_name = 1; 126 | } 127 | 128 | // See `spark.catalog.tableExists` 129 | message TableExists { 130 | // (Required) 131 | string table_name = 1; 132 | // (Optional) 133 | optional string db_name = 2; 134 | } 135 | 136 | // See `spark.catalog.functionExists` 137 | message FunctionExists { 138 | // (Required) 139 | string function_name = 1; 140 | // (Optional) 141 | optional string db_name = 2; 142 | } 143 | 144 | // See `spark.catalog.createExternalTable` 145 | message CreateExternalTable { 146 | // (Required) 147 | string table_name = 1; 148 | // (Optional) 149 | optional string path = 2; 150 | // (Optional) 151 | optional string source = 3; 152 | // (Optional) 153 | optional DataType schema = 4; 154 | // Options could be empty for valid data source format. 155 | // The map key is case insensitive. 156 | map options = 5; 157 | } 158 | 159 | // See `spark.catalog.createTable` 160 | message CreateTable { 161 | // (Required) 162 | string table_name = 1; 163 | // (Optional) 164 | optional string path = 2; 165 | // (Optional) 166 | optional string source = 3; 167 | // (Optional) 168 | optional string description = 4; 169 | // (Optional) 170 | optional DataType schema = 5; 171 | // Options could be empty for valid data source format. 172 | // The map key is case insensitive. 173 | map options = 6; 174 | } 175 | 176 | // See `spark.catalog.dropTempView` 177 | message DropTempView { 178 | // (Required) 179 | string view_name = 1; 180 | } 181 | 182 | // See `spark.catalog.dropGlobalTempView` 183 | message DropGlobalTempView { 184 | // (Required) 185 | string view_name = 1; 186 | } 187 | 188 | // See `spark.catalog.recoverPartitions` 189 | message RecoverPartitions { 190 | // (Required) 191 | string table_name = 1; 192 | } 193 | 194 | // See `spark.catalog.isCached` 195 | message IsCached { 196 | // (Required) 197 | string table_name = 1; 198 | } 199 | 200 | // See `spark.catalog.cacheTable` 201 | message CacheTable { 202 | // (Required) 203 | string table_name = 1; 204 | 205 | // (Optional) 206 | optional StorageLevel storage_level = 2; 207 | } 208 | 209 | // See `spark.catalog.uncacheTable` 210 | message UncacheTable { 211 | // (Required) 212 | string table_name = 1; 213 | } 214 | 215 | // See `spark.catalog.clearCache` 216 | message ClearCache { } 217 | 218 | // See `spark.catalog.refreshTable` 219 | message RefreshTable { 220 | // (Required) 221 | string table_name = 1; 222 | } 223 | 224 | // See `spark.catalog.refreshByPath` 225 | message RefreshByPath { 226 | // (Required) 227 | string path = 1; 228 | } 229 | 230 | // See `spark.catalog.currentCatalog` 231 | message CurrentCatalog { } 232 | 233 | // See `spark.catalog.setCurrentCatalog` 234 | message SetCurrentCatalog { 235 | // (Required) 236 | string catalog_name = 1; 237 | } 238 | 239 | // See `spark.catalog.listCatalogs` 240 | message ListCatalogs { 241 | // (Optional) The pattern that the catalog name needs to match 242 | optional string pattern = 1; 243 | } 244 | -------------------------------------------------------------------------------- /crates/connect/protobuf/spark-3.5/spark/connect/commands.proto: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | syntax = 'proto3'; 19 | 20 | import "google/protobuf/any.proto"; 21 | import "spark/connect/common.proto"; 22 | import "spark/connect/expressions.proto"; 23 | import "spark/connect/relations.proto"; 24 | 25 | package spark.connect; 26 | 27 | option java_multiple_files = true; 28 | option java_package = "org.apache.spark.connect.proto"; 29 | option go_package = "internal/generated"; 30 | 31 | // A [[Command]] is an operation that is executed by the server that does not directly consume or 32 | // produce a relational result. 33 | message Command { 34 | oneof command_type { 35 | CommonInlineUserDefinedFunction register_function = 1; 36 | WriteOperation write_operation = 2; 37 | CreateDataFrameViewCommand create_dataframe_view = 3; 38 | WriteOperationV2 write_operation_v2 = 4; 39 | SqlCommand sql_command = 5; 40 | WriteStreamOperationStart write_stream_operation_start = 6; 41 | StreamingQueryCommand streaming_query_command = 7; 42 | GetResourcesCommand get_resources_command = 8; 43 | StreamingQueryManagerCommand streaming_query_manager_command = 9; 44 | CommonInlineUserDefinedTableFunction register_table_function = 10; 45 | 46 | // This field is used to mark extensions to the protocol. When plugins generate arbitrary 47 | // Commands they can add them here. During the planning the correct resolution is done. 48 | google.protobuf.Any extension = 999; 49 | 50 | } 51 | } 52 | 53 | // A SQL Command is used to trigger the eager evaluation of SQL commands in Spark. 54 | // 55 | // When the SQL provide as part of the message is a command it will be immediately evaluated 56 | // and the result will be collected and returned as part of a LocalRelation. If the result is 57 | // not a command, the operation will simply return a SQL Relation. This allows the client to be 58 | // almost oblivious to the server-side behavior. 59 | message SqlCommand { 60 | // (Required) SQL Query. 61 | string sql = 1; 62 | 63 | // (Optional) A map of parameter names to literal expressions. 64 | map args = 2; 65 | 66 | // (Optional) A sequence of literal expressions for positional parameters in the SQL query text. 67 | repeated Expression.Literal pos_args = 3; 68 | } 69 | 70 | // A command that can create DataFrame global temp view or local temp view. 71 | message CreateDataFrameViewCommand { 72 | // (Required) The relation that this view will be built on. 73 | Relation input = 1; 74 | 75 | // (Required) View name. 76 | string name = 2; 77 | 78 | // (Required) Whether this is global temp view or local temp view. 79 | bool is_global = 3; 80 | 81 | // (Required) 82 | // 83 | // If true, and if the view already exists, updates it; if false, and if the view 84 | // already exists, throws exception. 85 | bool replace = 4; 86 | } 87 | 88 | // As writes are not directly handled during analysis and planning, they are modeled as commands. 89 | message WriteOperation { 90 | // (Required) The output of the `input` relation will be persisted according to the options. 91 | Relation input = 1; 92 | 93 | // (Optional) Format value according to the Spark documentation. Examples are: text, parquet, delta. 94 | optional string source = 2; 95 | 96 | // (Optional) 97 | // 98 | // The destination of the write operation can be either a path or a table. 99 | // If the destination is neither a path nor a table, such as jdbc and noop, 100 | // the `save_type` should not be set. 101 | oneof save_type { 102 | string path = 3; 103 | SaveTable table = 4; 104 | } 105 | 106 | // (Required) the save mode. 107 | SaveMode mode = 5; 108 | 109 | // (Optional) List of columns to sort the output by. 110 | repeated string sort_column_names = 6; 111 | 112 | // (Optional) List of columns for partitioning. 113 | repeated string partitioning_columns = 7; 114 | 115 | // (Optional) Bucketing specification. Bucketing must set the number of buckets and the columns 116 | // to bucket by. 117 | BucketBy bucket_by = 8; 118 | 119 | // (Optional) A list of configuration options. 120 | map options = 9; 121 | 122 | message SaveTable { 123 | // (Required) The table name. 124 | string table_name = 1; 125 | // (Required) The method to be called to write to the table. 126 | TableSaveMethod save_method = 2; 127 | 128 | enum TableSaveMethod { 129 | TABLE_SAVE_METHOD_UNSPECIFIED = 0; 130 | TABLE_SAVE_METHOD_SAVE_AS_TABLE = 1; 131 | TABLE_SAVE_METHOD_INSERT_INTO = 2; 132 | } 133 | } 134 | 135 | message BucketBy { 136 | repeated string bucket_column_names = 1; 137 | int32 num_buckets = 2; 138 | } 139 | 140 | enum SaveMode { 141 | SAVE_MODE_UNSPECIFIED = 0; 142 | SAVE_MODE_APPEND = 1; 143 | SAVE_MODE_OVERWRITE = 2; 144 | SAVE_MODE_ERROR_IF_EXISTS = 3; 145 | SAVE_MODE_IGNORE = 4; 146 | } 147 | } 148 | 149 | // As writes are not directly handled during analysis and planning, they are modeled as commands. 150 | message WriteOperationV2 { 151 | // (Required) The output of the `input` relation will be persisted according to the options. 152 | Relation input = 1; 153 | 154 | // (Required) The destination of the write operation must be either a path or a table. 155 | string table_name = 2; 156 | 157 | // (Optional) A provider for the underlying output data source. Spark's default catalog supports 158 | // "parquet", "json", etc. 159 | optional string provider = 3; 160 | 161 | // (Optional) List of columns for partitioning for output table created by `create`, 162 | // `createOrReplace`, or `replace` 163 | repeated Expression partitioning_columns = 4; 164 | 165 | // (Optional) A list of configuration options. 166 | map options = 5; 167 | 168 | // (Optional) A list of table properties. 169 | map table_properties = 6; 170 | 171 | // (Required) Write mode. 172 | Mode mode = 7; 173 | 174 | enum Mode { 175 | MODE_UNSPECIFIED = 0; 176 | MODE_CREATE = 1; 177 | MODE_OVERWRITE = 2; 178 | MODE_OVERWRITE_PARTITIONS = 3; 179 | MODE_APPEND = 4; 180 | MODE_REPLACE = 5; 181 | MODE_CREATE_OR_REPLACE = 6; 182 | } 183 | 184 | // (Optional) A condition for overwrite saving mode 185 | Expression overwrite_condition = 8; 186 | } 187 | 188 | // Starts write stream operation as streaming query. Query ID and Run ID of the streaming 189 | // query are returned. 190 | message WriteStreamOperationStart { 191 | 192 | // (Required) The output of the `input` streaming relation will be written. 193 | Relation input = 1; 194 | 195 | // The following fields directly map to API for DataStreamWriter(). 196 | // Consult API documentation unless explicitly documented here. 197 | 198 | string format = 2; 199 | map options = 3; 200 | repeated string partitioning_column_names = 4; 201 | 202 | oneof trigger { 203 | string processing_time_interval = 5; 204 | bool available_now = 6; 205 | bool once = 7; 206 | string continuous_checkpoint_interval = 8; 207 | } 208 | 209 | string output_mode = 9; 210 | string query_name = 10; 211 | 212 | // The destination is optional. When set, it can be a path or a table name. 213 | oneof sink_destination { 214 | string path = 11; 215 | string table_name = 12; 216 | } 217 | 218 | StreamingForeachFunction foreach_writer = 13; 219 | StreamingForeachFunction foreach_batch = 14; 220 | } 221 | 222 | message StreamingForeachFunction { 223 | oneof function { 224 | PythonUDF python_function = 1; 225 | ScalarScalaUDF scala_function = 2; 226 | } 227 | } 228 | 229 | message WriteStreamOperationStartResult { 230 | 231 | // (Required) Query instance. See `StreamingQueryInstanceId`. 232 | StreamingQueryInstanceId query_id = 1; 233 | 234 | // An optional query name. 235 | string name = 2; 236 | 237 | // TODO: How do we indicate errors? 238 | // TODO: Consider adding status, last progress etc here. 239 | } 240 | 241 | // A tuple that uniquely identifies an instance of streaming query run. It consists of `id` that 242 | // persists across the streaming runs and `run_id` that changes between each run of the 243 | // streaming query that resumes from the checkpoint. 244 | message StreamingQueryInstanceId { 245 | 246 | // (Required) The unique id of this query that persists across restarts from checkpoint data. 247 | // That is, this id is generated when a query is started for the first time, and 248 | // will be the same every time it is restarted from checkpoint data. 249 | string id = 1; 250 | 251 | // (Required) The unique id of this run of the query. That is, every start/restart of a query 252 | // will generate a unique run_id. Therefore, every time a query is restarted from 253 | // checkpoint, it will have the same `id` but different `run_id`s. 254 | string run_id = 2; 255 | } 256 | 257 | // Commands for a streaming query. 258 | message StreamingQueryCommand { 259 | 260 | // (Required) Query instance. See `StreamingQueryInstanceId`. 261 | StreamingQueryInstanceId query_id = 1; 262 | 263 | // See documentation for the corresponding API method in StreamingQuery. 264 | oneof command { 265 | // status() API. 266 | bool status = 2; 267 | // lastProgress() API. 268 | bool last_progress = 3; 269 | // recentProgress() API. 270 | bool recent_progress = 4; 271 | // stop() API. Stops the query. 272 | bool stop = 5; 273 | // processAllAvailable() API. Waits till all the available data is processed 274 | bool process_all_available = 6; 275 | // explain() API. Returns logical and physical plans. 276 | ExplainCommand explain = 7; 277 | // exception() API. Returns the exception in the query if any. 278 | bool exception = 8; 279 | // awaitTermination() API. Waits for the termination of the query. 280 | AwaitTerminationCommand await_termination = 9; 281 | } 282 | 283 | message ExplainCommand { 284 | // TODO: Consider reusing Explain from AnalyzePlanRequest message. 285 | // We can not do this right now since it base.proto imports this file. 286 | bool extended = 1; 287 | } 288 | 289 | message AwaitTerminationCommand { 290 | optional int64 timeout_ms = 2; 291 | } 292 | } 293 | 294 | // Response for commands on a streaming query. 295 | message StreamingQueryCommandResult { 296 | // (Required) Query instance id. See `StreamingQueryInstanceId`. 297 | StreamingQueryInstanceId query_id = 1; 298 | 299 | oneof result_type { 300 | StatusResult status = 2; 301 | RecentProgressResult recent_progress = 3; 302 | ExplainResult explain = 4; 303 | ExceptionResult exception = 5; 304 | AwaitTerminationResult await_termination = 6; 305 | } 306 | 307 | message StatusResult { 308 | // See documentation for these Scala 'StreamingQueryStatus' struct 309 | string status_message = 1; 310 | bool is_data_available = 2; 311 | bool is_trigger_active = 3; 312 | bool is_active = 4; 313 | } 314 | 315 | message RecentProgressResult { 316 | // Progress reports as an array of json strings. 317 | repeated string recent_progress_json = 5; 318 | } 319 | 320 | message ExplainResult { 321 | // Logical and physical plans as string 322 | string result = 1; 323 | } 324 | 325 | message ExceptionResult { 326 | // (Optional) Exception message as string, maps to the return value of original 327 | // StreamingQueryException's toString method 328 | optional string exception_message = 1; 329 | // (Optional) Exception error class as string 330 | optional string error_class = 2; 331 | // (Optional) Exception stack trace as string 332 | optional string stack_trace = 3; 333 | } 334 | 335 | message AwaitTerminationResult { 336 | bool terminated = 1; 337 | } 338 | } 339 | 340 | // Commands for the streaming query manager. 341 | message StreamingQueryManagerCommand { 342 | 343 | // See documentation for the corresponding API method in StreamingQueryManager. 344 | oneof command { 345 | // active() API, returns a list of active queries. 346 | bool active = 1; 347 | // get() API, returns the StreamingQuery identified by id. 348 | string get_query = 2; 349 | // awaitAnyTermination() API, wait until any query terminates or timeout. 350 | AwaitAnyTerminationCommand await_any_termination = 3; 351 | // resetTerminated() API. 352 | bool reset_terminated = 4; 353 | // addListener API. 354 | StreamingQueryListenerCommand add_listener = 5; 355 | // removeListener API. 356 | StreamingQueryListenerCommand remove_listener = 6; 357 | // listListeners() API, returns a list of streaming query listeners. 358 | bool list_listeners = 7; 359 | } 360 | 361 | message AwaitAnyTerminationCommand { 362 | // (Optional) The waiting time in milliseconds to wait for any query to terminate. 363 | optional int64 timeout_ms = 1; 364 | } 365 | 366 | message StreamingQueryListenerCommand { 367 | bytes listener_payload = 1; 368 | optional PythonUDF python_listener_payload = 2; 369 | string id = 3; 370 | } 371 | } 372 | 373 | // Response for commands on the streaming query manager. 374 | message StreamingQueryManagerCommandResult { 375 | oneof result_type { 376 | ActiveResult active = 1; 377 | StreamingQueryInstance query = 2; 378 | AwaitAnyTerminationResult await_any_termination = 3; 379 | bool reset_terminated = 4; 380 | bool add_listener = 5; 381 | bool remove_listener = 6; 382 | ListStreamingQueryListenerResult list_listeners = 7; 383 | } 384 | 385 | message ActiveResult { 386 | repeated StreamingQueryInstance active_queries = 1; 387 | } 388 | 389 | message StreamingQueryInstance { 390 | // (Required) The id and runId of this query. 391 | StreamingQueryInstanceId id = 1; 392 | // (Optional) The name of this query. 393 | optional string name = 2; 394 | } 395 | 396 | message AwaitAnyTerminationResult { 397 | bool terminated = 1; 398 | } 399 | 400 | message StreamingQueryListenerInstance { 401 | bytes listener_payload = 1; 402 | } 403 | 404 | message ListStreamingQueryListenerResult { 405 | // (Required) Reference IDs of listener instances. 406 | repeated string listener_ids = 1; 407 | } 408 | } 409 | 410 | // Command to get the output of 'SparkContext.resources' 411 | message GetResourcesCommand { } 412 | 413 | // Response for command 'GetResourcesCommand'. 414 | message GetResourcesCommandResult { 415 | map resources = 1; 416 | } 417 | -------------------------------------------------------------------------------- /crates/connect/protobuf/spark-3.5/spark/connect/common.proto: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | syntax = 'proto3'; 19 | 20 | package spark.connect; 21 | 22 | option java_multiple_files = true; 23 | option java_package = "org.apache.spark.connect.proto"; 24 | option go_package = "internal/generated"; 25 | 26 | // StorageLevel for persisting Datasets/Tables. 27 | message StorageLevel { 28 | // (Required) Whether the cache should use disk or not. 29 | bool use_disk = 1; 30 | // (Required) Whether the cache should use memory or not. 31 | bool use_memory = 2; 32 | // (Required) Whether the cache should use off-heap or not. 33 | bool use_off_heap = 3; 34 | // (Required) Whether the cached data is deserialized or not. 35 | bool deserialized = 4; 36 | // (Required) The number of replicas. 37 | int32 replication = 5; 38 | } 39 | 40 | 41 | // ResourceInformation to hold information about a type of Resource. 42 | // The corresponding class is 'org.apache.spark.resource.ResourceInformation' 43 | message ResourceInformation { 44 | // (Required) The name of the resource 45 | string name = 1; 46 | // (Required) An array of strings describing the addresses of the resource. 47 | repeated string addresses = 2; 48 | } 49 | -------------------------------------------------------------------------------- /crates/connect/protobuf/spark-3.5/spark/connect/expressions.proto: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | syntax = 'proto3'; 19 | 20 | import "google/protobuf/any.proto"; 21 | import "spark/connect/types.proto"; 22 | 23 | package spark.connect; 24 | 25 | option java_multiple_files = true; 26 | option java_package = "org.apache.spark.connect.proto"; 27 | option go_package = "internal/generated"; 28 | 29 | // Expression used to refer to fields, functions and similar. This can be used everywhere 30 | // expressions in SQL appear. 31 | message Expression { 32 | 33 | oneof expr_type { 34 | Literal literal = 1; 35 | UnresolvedAttribute unresolved_attribute = 2; 36 | UnresolvedFunction unresolved_function = 3; 37 | ExpressionString expression_string = 4; 38 | UnresolvedStar unresolved_star = 5; 39 | Alias alias = 6; 40 | Cast cast = 7; 41 | UnresolvedRegex unresolved_regex = 8; 42 | SortOrder sort_order = 9; 43 | LambdaFunction lambda_function = 10; 44 | Window window = 11; 45 | UnresolvedExtractValue unresolved_extract_value = 12; 46 | UpdateFields update_fields = 13; 47 | UnresolvedNamedLambdaVariable unresolved_named_lambda_variable = 14; 48 | CommonInlineUserDefinedFunction common_inline_user_defined_function = 15; 49 | CallFunction call_function = 16; 50 | 51 | // This field is used to mark extensions to the protocol. When plugins generate arbitrary 52 | // relations they can add them here. During the planning the correct resolution is done. 53 | google.protobuf.Any extension = 999; 54 | } 55 | 56 | 57 | // Expression for the OVER clause or WINDOW clause. 58 | message Window { 59 | 60 | // (Required) The window function. 61 | Expression window_function = 1; 62 | 63 | // (Optional) The way that input rows are partitioned. 64 | repeated Expression partition_spec = 2; 65 | 66 | // (Optional) Ordering of rows in a partition. 67 | repeated SortOrder order_spec = 3; 68 | 69 | // (Optional) Window frame in a partition. 70 | // 71 | // If not set, it will be treated as 'UnspecifiedFrame'. 72 | WindowFrame frame_spec = 4; 73 | 74 | // The window frame 75 | message WindowFrame { 76 | 77 | // (Required) The type of the frame. 78 | FrameType frame_type = 1; 79 | 80 | // (Required) The lower bound of the frame. 81 | FrameBoundary lower = 2; 82 | 83 | // (Required) The upper bound of the frame. 84 | FrameBoundary upper = 3; 85 | 86 | enum FrameType { 87 | FRAME_TYPE_UNDEFINED = 0; 88 | 89 | // RowFrame treats rows in a partition individually. 90 | FRAME_TYPE_ROW = 1; 91 | 92 | // RangeFrame treats rows in a partition as groups of peers. 93 | // All rows having the same 'ORDER BY' ordering are considered as peers. 94 | FRAME_TYPE_RANGE = 2; 95 | } 96 | 97 | message FrameBoundary { 98 | oneof boundary { 99 | // CURRENT ROW boundary 100 | bool current_row = 1; 101 | 102 | // UNBOUNDED boundary. 103 | // For lower bound, it will be converted to 'UnboundedPreceding'. 104 | // for upper bound, it will be converted to 'UnboundedFollowing'. 105 | bool unbounded = 2; 106 | 107 | // This is an expression for future proofing. We are expecting literals on the server side. 108 | Expression value = 3; 109 | } 110 | } 111 | } 112 | } 113 | 114 | // SortOrder is used to specify the data ordering, it is normally used in Sort and Window. 115 | // It is an unevaluable expression and cannot be evaluated, so can not be used in Projection. 116 | message SortOrder { 117 | // (Required) The expression to be sorted. 118 | Expression child = 1; 119 | 120 | // (Required) The sort direction, should be ASCENDING or DESCENDING. 121 | SortDirection direction = 2; 122 | 123 | // (Required) How to deal with NULLs, should be NULLS_FIRST or NULLS_LAST. 124 | NullOrdering null_ordering = 3; 125 | 126 | enum SortDirection { 127 | SORT_DIRECTION_UNSPECIFIED = 0; 128 | SORT_DIRECTION_ASCENDING = 1; 129 | SORT_DIRECTION_DESCENDING = 2; 130 | } 131 | 132 | enum NullOrdering { 133 | SORT_NULLS_UNSPECIFIED = 0; 134 | SORT_NULLS_FIRST = 1; 135 | SORT_NULLS_LAST = 2; 136 | } 137 | } 138 | 139 | message Cast { 140 | // (Required) the expression to be casted. 141 | Expression expr = 1; 142 | 143 | // (Required) the data type that the expr to be casted to. 144 | oneof cast_to_type { 145 | DataType type = 2; 146 | // If this is set, Server will use Catalyst parser to parse this string to DataType. 147 | string type_str = 3; 148 | } 149 | } 150 | 151 | message Literal { 152 | oneof literal_type { 153 | DataType null = 1; 154 | bytes binary = 2; 155 | bool boolean = 3; 156 | 157 | int32 byte = 4; 158 | int32 short = 5; 159 | int32 integer = 6; 160 | int64 long = 7; 161 | float float = 10; 162 | double double = 11; 163 | Decimal decimal = 12; 164 | 165 | string string = 13; 166 | 167 | // Date in units of days since the UNIX epoch. 168 | int32 date = 16; 169 | // Timestamp in units of microseconds since the UNIX epoch. 170 | int64 timestamp = 17; 171 | // Timestamp in units of microseconds since the UNIX epoch (without timezone information). 172 | int64 timestamp_ntz = 18; 173 | 174 | CalendarInterval calendar_interval = 19; 175 | int32 year_month_interval = 20; 176 | int64 day_time_interval = 21; 177 | Array array = 22; 178 | Map map = 23; 179 | Struct struct = 24; 180 | } 181 | 182 | message Decimal { 183 | // the string representation. 184 | string value = 1; 185 | // The maximum number of digits allowed in the value. 186 | // the maximum precision is 38. 187 | optional int32 precision = 2; 188 | // declared scale of decimal literal 189 | optional int32 scale = 3; 190 | } 191 | 192 | message CalendarInterval { 193 | int32 months = 1; 194 | int32 days = 2; 195 | int64 microseconds = 3; 196 | } 197 | 198 | message Array { 199 | DataType element_type = 1; 200 | repeated Literal elements = 2; 201 | } 202 | 203 | message Map { 204 | DataType key_type = 1; 205 | DataType value_type = 2; 206 | repeated Literal keys = 3; 207 | repeated Literal values = 4; 208 | } 209 | 210 | message Struct { 211 | DataType struct_type = 1; 212 | repeated Literal elements = 2; 213 | } 214 | } 215 | 216 | // An unresolved attribute that is not explicitly bound to a specific column, but the column 217 | // is resolved during analysis by name. 218 | message UnresolvedAttribute { 219 | // (Required) An identifier that will be parsed by Catalyst parser. This should follow the 220 | // Spark SQL identifier syntax. 221 | string unparsed_identifier = 1; 222 | 223 | // (Optional) The id of corresponding connect plan. 224 | optional int64 plan_id = 2; 225 | } 226 | 227 | // An unresolved function is not explicitly bound to one explicit function, but the function 228 | // is resolved during analysis following Sparks name resolution rules. 229 | message UnresolvedFunction { 230 | // (Required) name (or unparsed name for user defined function) for the unresolved function. 231 | string function_name = 1; 232 | 233 | // (Optional) Function arguments. Empty arguments are allowed. 234 | repeated Expression arguments = 2; 235 | 236 | // (Required) Indicate if this function should be applied on distinct values. 237 | bool is_distinct = 3; 238 | 239 | // (Required) Indicate if this is a user defined function. 240 | // 241 | // When it is not a user defined function, Connect will use the function name directly. 242 | // When it is a user defined function, Connect will parse the function name first. 243 | bool is_user_defined_function = 4; 244 | } 245 | 246 | // Expression as string. 247 | message ExpressionString { 248 | // (Required) A SQL expression that will be parsed by Catalyst parser. 249 | string expression = 1; 250 | } 251 | 252 | // UnresolvedStar is used to expand all the fields of a relation or struct. 253 | message UnresolvedStar { 254 | 255 | // (Optional) The target of the expansion. 256 | // 257 | // If set, it should end with '.*' and will be parsed by 'parseAttributeName' 258 | // in the server side. 259 | optional string unparsed_target = 1; 260 | } 261 | 262 | // Represents all of the input attributes to a given relational operator, for example in 263 | // "SELECT `(id)?+.+` FROM ...". 264 | message UnresolvedRegex { 265 | // (Required) The column name used to extract column with regex. 266 | string col_name = 1; 267 | 268 | // (Optional) The id of corresponding connect plan. 269 | optional int64 plan_id = 2; 270 | } 271 | 272 | // Extracts a value or values from an Expression 273 | message UnresolvedExtractValue { 274 | // (Required) The expression to extract value from, can be 275 | // Map, Array, Struct or array of Structs. 276 | Expression child = 1; 277 | 278 | // (Required) The expression to describe the extraction, can be 279 | // key of Map, index of Array, field name of Struct. 280 | Expression extraction = 2; 281 | } 282 | 283 | // Add, replace or drop a field of `StructType` expression by name. 284 | message UpdateFields { 285 | // (Required) The struct expression. 286 | Expression struct_expression = 1; 287 | 288 | // (Required) The field name. 289 | string field_name = 2; 290 | 291 | // (Optional) The expression to add or replace. 292 | // 293 | // When not set, it means this field will be dropped. 294 | Expression value_expression = 3; 295 | } 296 | 297 | message Alias { 298 | // (Required) The expression that alias will be added on. 299 | Expression expr = 1; 300 | 301 | // (Required) a list of name parts for the alias. 302 | // 303 | // Scalar columns only has one name that presents. 304 | repeated string name = 2; 305 | 306 | // (Optional) Alias metadata expressed as a JSON map. 307 | optional string metadata = 3; 308 | } 309 | 310 | message LambdaFunction { 311 | // (Required) The lambda function. 312 | // 313 | // The function body should use 'UnresolvedAttribute' as arguments, the sever side will 314 | // replace 'UnresolvedAttribute' with 'UnresolvedNamedLambdaVariable'. 315 | Expression function = 1; 316 | 317 | // (Required) Function variables. Must contains 1 ~ 3 variables. 318 | repeated Expression.UnresolvedNamedLambdaVariable arguments = 2; 319 | } 320 | 321 | message UnresolvedNamedLambdaVariable { 322 | 323 | // (Required) a list of name parts for the variable. Must not be empty. 324 | repeated string name_parts = 1; 325 | } 326 | } 327 | 328 | message CommonInlineUserDefinedFunction { 329 | // (Required) Name of the user-defined function. 330 | string function_name = 1; 331 | // (Optional) Indicate if the user-defined function is deterministic. 332 | bool deterministic = 2; 333 | // (Optional) Function arguments. Empty arguments are allowed. 334 | repeated Expression arguments = 3; 335 | // (Required) Indicate the function type of the user-defined function. 336 | oneof function { 337 | PythonUDF python_udf = 4; 338 | ScalarScalaUDF scalar_scala_udf = 5; 339 | JavaUDF java_udf = 6; 340 | } 341 | } 342 | 343 | message PythonUDF { 344 | // (Required) Output type of the Python UDF 345 | DataType output_type = 1; 346 | // (Required) EvalType of the Python UDF 347 | int32 eval_type = 2; 348 | // (Required) The encoded commands of the Python UDF 349 | bytes command = 3; 350 | // (Required) Python version being used in the client. 351 | string python_ver = 4; 352 | } 353 | 354 | message ScalarScalaUDF { 355 | // (Required) Serialized JVM object containing UDF definition, input encoders and output encoder 356 | bytes payload = 1; 357 | // (Optional) Input type(s) of the UDF 358 | repeated DataType inputTypes = 2; 359 | // (Required) Output type of the UDF 360 | DataType outputType = 3; 361 | // (Required) True if the UDF can return null value 362 | bool nullable = 4; 363 | } 364 | 365 | message JavaUDF { 366 | // (Required) Fully qualified name of Java class 367 | string class_name = 1; 368 | 369 | // (Optional) Output type of the Java UDF 370 | optional DataType output_type = 2; 371 | 372 | // (Required) Indicate if the Java user-defined function is an aggregate function 373 | bool aggregate = 3; 374 | } 375 | 376 | message CallFunction { 377 | // (Required) Unparsed name of the SQL function. 378 | string function_name = 1; 379 | 380 | // (Optional) Function arguments. Empty arguments are allowed. 381 | repeated Expression arguments = 2; 382 | } 383 | -------------------------------------------------------------------------------- /crates/connect/protobuf/spark-3.5/spark/connect/types.proto: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | syntax = 'proto3'; 19 | 20 | package spark.connect; 21 | 22 | option java_multiple_files = true; 23 | option java_package = "org.apache.spark.connect.proto"; 24 | option go_package = "internal/generated"; 25 | 26 | // This message describes the logical [[DataType]] of something. It does not carry the value 27 | // itself but only describes it. 28 | message DataType { 29 | oneof kind { 30 | NULL null = 1; 31 | 32 | Binary binary = 2; 33 | 34 | Boolean boolean = 3; 35 | 36 | // Numeric types 37 | Byte byte = 4; 38 | Short short = 5; 39 | Integer integer = 6; 40 | Long long = 7; 41 | 42 | Float float = 8; 43 | Double double = 9; 44 | Decimal decimal = 10; 45 | 46 | // String types 47 | String string = 11; 48 | Char char = 12; 49 | VarChar var_char = 13; 50 | 51 | // Datatime types 52 | Date date = 14; 53 | Timestamp timestamp = 15; 54 | TimestampNTZ timestamp_ntz = 16; 55 | 56 | // Interval types 57 | CalendarInterval calendar_interval = 17; 58 | YearMonthInterval year_month_interval = 18; 59 | DayTimeInterval day_time_interval = 19; 60 | 61 | // Complex types 62 | Array array = 20; 63 | Struct struct = 21; 64 | Map map = 22; 65 | 66 | // UserDefinedType 67 | UDT udt = 23; 68 | 69 | // UnparsedDataType 70 | Unparsed unparsed = 24; 71 | } 72 | 73 | message Boolean { 74 | uint32 type_variation_reference = 1; 75 | } 76 | 77 | message Byte { 78 | uint32 type_variation_reference = 1; 79 | } 80 | 81 | message Short { 82 | uint32 type_variation_reference = 1; 83 | } 84 | 85 | message Integer { 86 | uint32 type_variation_reference = 1; 87 | } 88 | 89 | message Long { 90 | uint32 type_variation_reference = 1; 91 | } 92 | 93 | message Float { 94 | uint32 type_variation_reference = 1; 95 | } 96 | 97 | message Double { 98 | uint32 type_variation_reference = 1; 99 | } 100 | 101 | message String { 102 | uint32 type_variation_reference = 1; 103 | } 104 | 105 | message Binary { 106 | uint32 type_variation_reference = 1; 107 | } 108 | 109 | message NULL { 110 | uint32 type_variation_reference = 1; 111 | } 112 | 113 | message Timestamp { 114 | uint32 type_variation_reference = 1; 115 | } 116 | 117 | message Date { 118 | uint32 type_variation_reference = 1; 119 | } 120 | 121 | message TimestampNTZ { 122 | uint32 type_variation_reference = 1; 123 | } 124 | 125 | message CalendarInterval { 126 | uint32 type_variation_reference = 1; 127 | } 128 | 129 | message YearMonthInterval { 130 | optional int32 start_field = 1; 131 | optional int32 end_field = 2; 132 | uint32 type_variation_reference = 3; 133 | } 134 | 135 | message DayTimeInterval { 136 | optional int32 start_field = 1; 137 | optional int32 end_field = 2; 138 | uint32 type_variation_reference = 3; 139 | } 140 | 141 | // Start compound types. 142 | message Char { 143 | int32 length = 1; 144 | uint32 type_variation_reference = 2; 145 | } 146 | 147 | message VarChar { 148 | int32 length = 1; 149 | uint32 type_variation_reference = 2; 150 | } 151 | 152 | message Decimal { 153 | optional int32 scale = 1; 154 | optional int32 precision = 2; 155 | uint32 type_variation_reference = 3; 156 | } 157 | 158 | message StructField { 159 | string name = 1; 160 | DataType data_type = 2; 161 | bool nullable = 3; 162 | optional string metadata = 4; 163 | } 164 | 165 | message Struct { 166 | repeated StructField fields = 1; 167 | uint32 type_variation_reference = 2; 168 | } 169 | 170 | message Array { 171 | DataType element_type = 1; 172 | bool contains_null = 2; 173 | uint32 type_variation_reference = 3; 174 | } 175 | 176 | message Map { 177 | DataType key_type = 1; 178 | DataType value_type = 2; 179 | bool value_contains_null = 3; 180 | uint32 type_variation_reference = 4; 181 | } 182 | 183 | message UDT { 184 | string type = 1; 185 | optional string jvm_class = 2; 186 | optional string python_class = 3; 187 | optional string serialized_python_class = 4; 188 | DataType sql_type = 5; 189 | } 190 | 191 | message Unparsed { 192 | // (Required) The unparsed data type string 193 | string data_type_string = 1; 194 | } 195 | } 196 | -------------------------------------------------------------------------------- /crates/connect/src/client/builder.rs: -------------------------------------------------------------------------------- 1 | //! Implementation of ChannelBuilder 2 | 3 | use std::collections::HashMap; 4 | use std::env; 5 | use std::str::FromStr; 6 | 7 | use crate::errors::SparkError; 8 | 9 | use url::Url; 10 | 11 | use uuid::Uuid; 12 | 13 | pub(crate) type Host = String; 14 | pub(crate) type Port = u16; 15 | pub(crate) type UrlParse = (Host, Port, Option>); 16 | 17 | /// ChannelBuilder validates a connection string 18 | /// based on the requirements from [Spark Documentation](https://github.com/apache/spark/blob/master/connector/connect/docs/client-connection-string.md) 19 | #[derive(Clone, Debug)] 20 | pub struct ChannelBuilder { 21 | pub(super) host: Host, 22 | pub(super) port: Port, 23 | pub(super) session_id: Uuid, 24 | pub(super) token: Option, 25 | pub(super) user_id: Option, 26 | pub(super) user_agent: Option, 27 | pub(super) use_ssl: bool, 28 | pub(super) headers: Option>, 29 | } 30 | 31 | impl Default for ChannelBuilder { 32 | fn default() -> Self { 33 | let connection = match env::var("SPARK_REMOTE") { 34 | Ok(conn) => conn.to_string(), 35 | Err(_) => "sc://localhost:15002".to_string(), 36 | }; 37 | 38 | ChannelBuilder::create(&connection).unwrap() 39 | } 40 | } 41 | 42 | impl ChannelBuilder { 43 | pub fn new() -> Self { 44 | ChannelBuilder::default() 45 | } 46 | 47 | pub(crate) fn endpoint(&self) -> String { 48 | let scheme = if cfg!(feature = "tls") { 49 | "https" 50 | } else { 51 | "http" 52 | }; 53 | 54 | format!("{}://{}:{}", scheme, self.host, self.port) 55 | } 56 | 57 | pub(crate) fn headers(&self) -> Option> { 58 | self.headers.to_owned() 59 | } 60 | 61 | pub(crate) fn create_user_agent(user_agent: Option<&str>) -> Option { 62 | let user_agent = user_agent.unwrap_or("_SPARK_CONNECT_RUST"); 63 | let pkg_version = env!("CARGO_PKG_VERSION"); 64 | let os = env::consts::OS.to_lowercase(); 65 | 66 | Some(format!( 67 | "{} os/{} spark_connect_rs/{}", 68 | user_agent, os, pkg_version 69 | )) 70 | } 71 | 72 | pub(crate) fn create_user_id(user_id: Option<&str>) -> Option { 73 | match user_id { 74 | Some(user_id) => Some(user_id.to_string()), 75 | None => env::var("USER").ok(), 76 | } 77 | } 78 | 79 | pub(crate) fn parse_connection_string(connection: &str) -> Result { 80 | let url = Url::parse(connection).map_err(|_| { 81 | SparkError::InvalidConnectionUrl("Failed to parse the connection URL".to_string()) 82 | })?; 83 | 84 | if url.scheme() != "sc" { 85 | return Err(SparkError::InvalidConnectionUrl( 86 | "The URL must start with 'sc://'. Please update the URL to follow the correct format, e.g., 'sc://hostname:port'".to_string(), 87 | )); 88 | }; 89 | 90 | let host = url 91 | .host_str() 92 | .ok_or_else(|| { 93 | SparkError::InvalidConnectionUrl( 94 | "The hostname must not be empty. Please update 95 | the URL to follow the correct format, e.g., 'sc://hostname:port'." 96 | .to_string(), 97 | ) 98 | })? 99 | .to_string(); 100 | 101 | let port = url.port().ok_or_else(|| { 102 | SparkError::InvalidConnectionUrl( 103 | "The port must not be empty. Please update 104 | the URL to follow the correct format, e.g., 'sc://hostname:port'." 105 | .to_string(), 106 | ) 107 | })?; 108 | 109 | let headers = ChannelBuilder::parse_headers(url); 110 | 111 | Ok((host, port, headers)) 112 | } 113 | 114 | pub(crate) fn parse_headers(url: Url) -> Option> { 115 | let path: Vec<&str> = url 116 | .path() 117 | .split(';') 118 | .filter(|&pair| (pair != "/") & (!pair.is_empty())) 119 | .collect(); 120 | 121 | if path.is_empty() || (path.len() == 1 && (path[0].is_empty() || path[0] == "/")) { 122 | return None; 123 | } 124 | 125 | let headers: HashMap = path 126 | .iter() 127 | .copied() 128 | .map(|pair| { 129 | let mut parts = pair.splitn(2, '='); 130 | ( 131 | parts.next().unwrap_or("").to_string(), 132 | parts.next().unwrap_or("").to_string(), 133 | ) 134 | }) 135 | .collect(); 136 | 137 | if headers.is_empty() { 138 | return None; 139 | } 140 | 141 | Some(headers) 142 | } 143 | 144 | /// Create and validate a connnection string 145 | #[allow(unreachable_code)] 146 | pub fn create(connection: &str) -> Result { 147 | let (host, port, headers) = ChannelBuilder::parse_connection_string(connection)?; 148 | 149 | let mut channel_builder = ChannelBuilder { 150 | host, 151 | port, 152 | session_id: Uuid::new_v4(), 153 | token: None, 154 | user_id: ChannelBuilder::create_user_id(None), 155 | user_agent: ChannelBuilder::create_user_agent(None), 156 | use_ssl: false, 157 | headers: None, 158 | }; 159 | 160 | if let Some(mut headers) = headers { 161 | channel_builder.user_id = headers 162 | .remove("user_id") 163 | .map(|user_id| ChannelBuilder::create_user_id(Some(&user_id))) 164 | .unwrap_or_else(|| ChannelBuilder::create_user_id(None)); 165 | 166 | channel_builder.user_agent = headers 167 | .remove("user_agent") 168 | .map(|user_agent| ChannelBuilder::create_user_agent(Some(&user_agent))) 169 | .unwrap_or_else(|| ChannelBuilder::create_user_agent(None)); 170 | 171 | if let Some(token) = headers.remove("token") { 172 | let token = format!("Bearer {token}"); 173 | channel_builder.token = Some(token.clone()); 174 | headers.insert("authorization".to_string(), token); 175 | } 176 | 177 | if let Some(session_id) = headers.remove("session_id") { 178 | channel_builder.session_id = Uuid::from_str(&session_id)? 179 | } 180 | 181 | if let Some(use_ssl) = headers.remove("use_ssl") { 182 | if use_ssl.to_lowercase() == "true" { 183 | #[cfg(not(feature = "tls"))] 184 | { 185 | panic!( 186 | "The 'use_ssl' option requires the 'tls' feature, but it's not enabled!" 187 | ); 188 | }; 189 | channel_builder.use_ssl = true 190 | } 191 | }; 192 | 193 | if !headers.is_empty() { 194 | channel_builder.headers = Some(headers); 195 | } 196 | } 197 | 198 | Ok(channel_builder) 199 | } 200 | } 201 | 202 | #[cfg(test)] 203 | mod tests { 204 | use super::*; 205 | 206 | #[test] 207 | fn test_channel_builder_default() { 208 | let expected_url = "http://localhost:15002".to_string(); 209 | 210 | let cb = ChannelBuilder::default(); 211 | 212 | assert_eq!(expected_url, cb.endpoint()) 213 | } 214 | 215 | #[test] 216 | fn test_panic_incorrect_url_scheme() { 217 | let connection = "http://127.0.0.1:15002"; 218 | 219 | assert!(ChannelBuilder::create(connection).is_err()) 220 | } 221 | 222 | #[test] 223 | fn test_panic_missing_url_host() { 224 | let connection = "sc://:15002"; 225 | 226 | assert!(ChannelBuilder::create(connection).is_err()) 227 | } 228 | 229 | #[test] 230 | fn test_panic_missing_url_port() { 231 | let connection = "sc://127.0.0.1"; 232 | 233 | assert!(ChannelBuilder::create(connection).is_err()) 234 | } 235 | 236 | #[test] 237 | fn test_settings_builder() { 238 | let connection = "sc://myhost.com:443/;token=ABCDEFG;user_agent=some_agent;user_id=user123"; 239 | 240 | let builder = ChannelBuilder::create(connection).unwrap(); 241 | 242 | assert_eq!("http://myhost.com:443".to_string(), builder.endpoint()); 243 | assert_eq!("Bearer ABCDEFG".to_string(), builder.token.unwrap()); 244 | assert_eq!("user123".to_string(), builder.user_id.unwrap()); 245 | } 246 | 247 | #[test] 248 | #[should_panic( 249 | expected = "The 'use_ssl' option requires the 'tls' feature, but it's not enabled!" 250 | )] 251 | fn test_panic_ssl() { 252 | let connection = "sc://127.0.0.1:443/;use_ssl=true"; 253 | 254 | ChannelBuilder::create(connection).unwrap(); 255 | } 256 | } 257 | -------------------------------------------------------------------------------- /crates/connect/src/client/config.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | use uuid::Uuid; 3 | 4 | use crate::client::builder::{Host, Port}; 5 | use crate::client::ChannelBuilder; 6 | 7 | /// Config handler to set custom SparkSessionBuilder options 8 | #[derive(Clone, Debug, Default)] 9 | pub struct Config { 10 | pub host: Host, 11 | pub port: Port, 12 | pub session_id: Uuid, 13 | pub token: Option, 14 | pub user_id: Option, 15 | pub user_agent: Option, 16 | pub use_ssl: bool, 17 | pub headers: Option>, 18 | } 19 | 20 | impl Config { 21 | pub fn new() -> Self { 22 | Config { 23 | host: "localhost".to_string(), 24 | port: 15002, 25 | token: None, 26 | session_id: Uuid::new_v4(), 27 | user_id: ChannelBuilder::create_user_id(None), 28 | user_agent: ChannelBuilder::create_user_agent(None), 29 | use_ssl: false, 30 | headers: None, 31 | } 32 | } 33 | 34 | pub fn host(mut self, val: &str) -> Self { 35 | self.host = val.to_string(); 36 | self 37 | } 38 | 39 | pub fn port(mut self, val: Port) -> Self { 40 | self.port = val; 41 | self 42 | } 43 | 44 | pub fn token(mut self, val: &str) -> Self { 45 | self.token = Some(val.to_string()); 46 | self 47 | } 48 | 49 | pub fn session_id(mut self, val: Uuid) -> Self { 50 | self.session_id = val; 51 | self 52 | } 53 | 54 | pub fn user_id(mut self, val: &str) -> Self { 55 | self.user_id = Some(val.to_string()); 56 | self 57 | } 58 | 59 | pub fn user_agent(mut self, val: &str) -> Self { 60 | self.user_agent = Some(val.to_string()); 61 | self 62 | } 63 | 64 | pub fn use_ssl(mut self, val: bool) -> Self { 65 | self.use_ssl = val; 66 | self 67 | } 68 | 69 | pub fn headers(mut self, val: HashMap) -> Self { 70 | self.headers = Some(val); 71 | self 72 | } 73 | } 74 | 75 | impl From for ChannelBuilder { 76 | fn from(config: Config) -> Self { 77 | // if there is a token, then it needs to be added to the headers 78 | // do not overwrite any existing authentication header 79 | 80 | let mut headers = config.headers.unwrap_or_default(); 81 | 82 | if let Some(token) = &config.token { 83 | headers 84 | .entry("authorization".to_string()) 85 | .or_insert_with(|| format!("Bearer {}", token)); 86 | } 87 | 88 | Self { 89 | host: config.host, 90 | port: config.port, 91 | session_id: config.session_id, 92 | token: config.token, 93 | user_id: config.user_id, 94 | user_agent: config.user_agent, 95 | use_ssl: config.use_ssl, 96 | headers: if headers.is_empty() { 97 | None 98 | } else { 99 | Some(headers) 100 | }, 101 | } 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /crates/connect/src/client/middleware.rs: -------------------------------------------------------------------------------- 1 | //! Middleware services implemented with tower.rs 2 | 3 | use std::collections::HashMap; 4 | use std::fmt::Debug; 5 | use std::str::FromStr; 6 | use std::task::{Context, Poll}; 7 | 8 | use futures_util::future::BoxFuture; 9 | use http_body::combinators::UnsyncBoxBody; 10 | 11 | use tonic::codegen::http::Request; 12 | use tonic::codegen::http::{HeaderName, HeaderValue}; 13 | 14 | use tower::Service; 15 | 16 | /// Headers to apply a gRPC request 17 | #[derive(Debug, Clone)] 18 | pub struct HeadersLayer { 19 | headers: HashMap, 20 | } 21 | 22 | impl HeadersLayer { 23 | pub fn new(headers: HashMap) -> Self { 24 | Self { headers } 25 | } 26 | } 27 | 28 | impl tower::Layer for HeadersLayer { 29 | type Service = HeadersMiddleware; 30 | 31 | fn layer(&self, inner: S) -> Self::Service { 32 | HeadersMiddleware::new(inner, self.headers.clone()) 33 | } 34 | } 35 | 36 | /// Middleware used to apply provided headers onto a gRPC request 37 | #[derive(Clone, Debug)] 38 | pub struct HeadersMiddleware { 39 | inner: S, 40 | headers: HashMap, 41 | } 42 | 43 | #[allow(dead_code)] 44 | impl HeadersMiddleware { 45 | pub fn new(inner: S, headers: HashMap) -> Self { 46 | Self { inner, headers } 47 | } 48 | } 49 | 50 | // TODO! as of now Request is not clone. So the retry logic does not work. 51 | // https://github.com/tower-rs/tower/pull/790 52 | impl Service>> for HeadersMiddleware 53 | where 54 | S: Service>> 55 | + Clone 56 | + Send 57 | + Sync 58 | + 'static, 59 | S::Future: Send + 'static, 60 | S::Response: Send + Debug + 'static, 61 | S::Error: Debug, 62 | { 63 | type Response = S::Response; 64 | type Error = S::Error; 65 | type Future = BoxFuture<'static, Result>; 66 | 67 | fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { 68 | self.inner.poll_ready(cx).map_err(Into::into) 69 | } 70 | 71 | fn call( 72 | &mut self, 73 | mut request: Request>, 74 | ) -> Self::Future { 75 | let clone = self.inner.clone(); 76 | let mut inner = std::mem::replace(&mut self.inner, clone); 77 | 78 | let headers = self.headers.clone(); 79 | 80 | Box::pin(async move { 81 | for (key, value) in &headers { 82 | let meta_key = HeaderName::from_str(key.as_str()).unwrap(); 83 | let meta_val = HeaderValue::from_str(value.as_str()).unwrap(); 84 | 85 | request.headers_mut().insert(meta_key, meta_val); 86 | } 87 | 88 | inner.call(request).await 89 | }) 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /crates/connect/src/client/mod.rs: -------------------------------------------------------------------------------- 1 | //! Implementation of the SparkConnectServiceClient 2 | 3 | use std::sync::Arc; 4 | 5 | use tokio::sync::RwLock; 6 | 7 | use tonic::codec::Streaming; 8 | use tonic::codegen::{Body, Bytes, StdError}; 9 | use tonic::transport::Channel; 10 | 11 | use crate::spark; 12 | use spark::execute_plan_response::ResponseType; 13 | use spark::spark_connect_service_client::SparkConnectServiceClient; 14 | 15 | use arrow::compute::concat_batches; 16 | use arrow::error::ArrowError; 17 | use arrow::record_batch::RecordBatch; 18 | use arrow_ipc::reader::StreamReader; 19 | 20 | use uuid::Uuid; 21 | 22 | use crate::errors::SparkError; 23 | 24 | mod builder; 25 | mod config; 26 | mod middleware; 27 | 28 | pub use builder::ChannelBuilder; 29 | pub use config::Config; 30 | pub use middleware::{HeadersLayer, HeadersMiddleware}; 31 | 32 | pub type SparkClient = SparkConnectClient>; 33 | 34 | #[allow(dead_code)] 35 | #[derive(Default, Debug, Clone)] 36 | pub(crate) struct ResponseHandler { 37 | metrics: Option, 38 | observed_metrics: Option, 39 | pub(crate) schema: Option, 40 | batches: Vec, 41 | pub(crate) sql_command_result: Option, 42 | pub(crate) write_stream_operation_start_result: Option, 43 | pub(crate) streaming_query_command_result: Option, 44 | pub(crate) get_resources_command_result: Option, 45 | pub(crate) streaming_query_manager_command_result: 46 | Option, 47 | pub(crate) result_complete: bool, 48 | total_count: isize, 49 | } 50 | 51 | #[derive(Default, Debug, Clone)] 52 | pub(crate) struct AnalyzeHandler { 53 | pub(crate) schema: Option, 54 | pub(crate) explain: Option, 55 | pub(crate) tree_string: Option, 56 | pub(crate) is_local: Option, 57 | pub(crate) is_streaming: Option, 58 | pub(crate) input_files: Option>, 59 | pub(crate) spark_version: Option, 60 | pub(crate) ddl_parse: Option, 61 | pub(crate) same_semantics: Option, 62 | pub(crate) semantic_hash: Option, 63 | pub(crate) get_storage_level: Option, 64 | } 65 | 66 | /// Client wrapper to handle submitting requests and handling responses from the [SparkConnectServiceClient] 67 | #[derive(Clone, Debug)] 68 | pub struct SparkConnectClient { 69 | stub: Arc>>, 70 | builder: ChannelBuilder, 71 | session_id: String, 72 | operation_id: Option, 73 | response_id: Option, 74 | pub(crate) handler: ResponseHandler, 75 | pub(crate) analyzer: AnalyzeHandler, 76 | pub(crate) user_context: Option, 77 | pub(crate) tags: Vec, 78 | pub(crate) use_reattachable_execute: bool, 79 | } 80 | 81 | impl SparkConnectClient 82 | where 83 | T: tonic::client::GrpcService, 84 | T::Error: Into, 85 | T::ResponseBody: Body + Send + 'static, 86 | ::Error: Into + Send, 87 | { 88 | pub fn new(stub: Arc>>, builder: ChannelBuilder) -> Self { 89 | let user_ref = builder.user_id.clone().unwrap_or("".to_string()); 90 | let session_id = builder.session_id.to_string(); 91 | 92 | SparkConnectClient { 93 | stub, 94 | builder, 95 | session_id, 96 | operation_id: None, 97 | response_id: None, 98 | handler: ResponseHandler::default(), 99 | analyzer: AnalyzeHandler::default(), 100 | user_context: Some(spark::UserContext { 101 | user_id: user_ref.clone(), 102 | user_name: user_ref, 103 | extensions: vec![], 104 | }), 105 | tags: vec![], 106 | use_reattachable_execute: true, 107 | } 108 | } 109 | 110 | /// Session ID 111 | pub fn session_id(&self) -> String { 112 | self.session_id.clone() 113 | } 114 | 115 | /// Change the reattachable execute value 116 | pub fn set_reattachable_execute(&mut self, setting: bool) -> Result<(), SparkError> { 117 | self.use_reattachable_execute = setting; 118 | Ok(()) 119 | } 120 | 121 | fn request_options(&self) -> Vec { 122 | if self.use_reattachable_execute { 123 | let reattach_opt = spark::ReattachOptions { reattachable: true }; 124 | let request_opt = spark::execute_plan_request::RequestOption { 125 | request_option: Some( 126 | spark::execute_plan_request::request_option::RequestOption::ReattachOptions( 127 | reattach_opt, 128 | ), 129 | ), 130 | }; 131 | 132 | return vec![request_opt]; 133 | }; 134 | 135 | vec![] 136 | } 137 | 138 | pub fn execute_plan_request_with_metadata(&mut self) -> spark::ExecutePlanRequest { 139 | let operation_id = Uuid::new_v4().to_string(); 140 | 141 | self.operation_id = Some(operation_id.clone()); 142 | 143 | spark::ExecutePlanRequest { 144 | session_id: self.session_id(), 145 | user_context: self.user_context.clone(), 146 | operation_id: Some(operation_id), 147 | plan: None, 148 | client_type: self.builder.user_agent.clone(), 149 | request_options: self.request_options(), 150 | tags: self.tags.clone(), 151 | } 152 | } 153 | 154 | pub fn analyze_plan_request_with_metadata(&self) -> spark::AnalyzePlanRequest { 155 | spark::AnalyzePlanRequest { 156 | session_id: self.session_id(), 157 | user_context: self.user_context.clone(), 158 | client_type: self.builder.user_agent.clone(), 159 | analyze: None, 160 | } 161 | } 162 | 163 | pub async fn execute_and_fetch( 164 | &mut self, 165 | req: spark::ExecutePlanRequest, 166 | ) -> Result<(), SparkError> { 167 | let mut client = self.stub.write().await; 168 | 169 | let mut stream = client.execute_plan(req).await?.into_inner(); 170 | drop(client); 171 | 172 | // clear out any prior responses 173 | self.handler = ResponseHandler::default(); 174 | 175 | self.process_stream(&mut stream).await?; 176 | 177 | if self.use_reattachable_execute && self.handler.result_complete { 178 | self.release_all().await? 179 | } 180 | 181 | Ok(()) 182 | } 183 | 184 | async fn reattach_execute(&mut self) -> Result<(), SparkError> { 185 | let mut client = self.stub.write().await; 186 | 187 | let req = spark::ReattachExecuteRequest { 188 | session_id: self.session_id(), 189 | user_context: self.user_context.clone(), 190 | operation_id: self.operation_id.clone().unwrap(), 191 | client_type: self.builder.user_agent.clone(), 192 | last_response_id: self.response_id.clone(), 193 | }; 194 | 195 | let mut stream = client.reattach_execute(req).await?.into_inner(); 196 | drop(client); 197 | 198 | self.process_stream(&mut stream).await?; 199 | 200 | if self.use_reattachable_execute && self.handler.result_complete { 201 | self.release_all().await? 202 | } 203 | 204 | Ok(()) 205 | } 206 | 207 | async fn process_stream( 208 | &mut self, 209 | stream: &mut Streaming, 210 | ) -> Result<(), SparkError> { 211 | while let Some(_resp) = match stream.message().await { 212 | Ok(Some(msg)) => { 213 | self.handle_response(msg.clone())?; 214 | Some(msg) 215 | } 216 | Ok(None) => { 217 | if self.use_reattachable_execute && !self.handler.result_complete { 218 | Box::pin(self.reattach_execute()).await?; 219 | } 220 | None 221 | } 222 | Err(err) => { 223 | if self.use_reattachable_execute && self.response_id.is_some() { 224 | self.release_until().await?; 225 | } 226 | return Err(err.into()); 227 | } 228 | } {} 229 | 230 | Ok(()) 231 | } 232 | 233 | async fn release_until(&mut self) -> Result<(), SparkError> { 234 | let release_until = spark::release_execute_request::ReleaseUntil { 235 | response_id: self.response_id.clone().unwrap(), 236 | }; 237 | 238 | self.release_execute(Some(spark::release_execute_request::Release::ReleaseUntil( 239 | release_until, 240 | ))) 241 | .await 242 | } 243 | 244 | async fn release_all(&mut self) -> Result<(), SparkError> { 245 | let release_all = spark::release_execute_request::ReleaseAll {}; 246 | 247 | self.release_execute(Some(spark::release_execute_request::Release::ReleaseAll( 248 | release_all, 249 | ))) 250 | .await 251 | } 252 | 253 | async fn release_execute( 254 | &mut self, 255 | release: Option, 256 | ) -> Result<(), SparkError> { 257 | let mut client = self.stub.write().await; 258 | 259 | let req = spark::ReleaseExecuteRequest { 260 | session_id: self.session_id(), 261 | user_context: self.user_context.clone(), 262 | operation_id: self.operation_id.clone().unwrap(), 263 | client_type: self.builder.user_agent.clone(), 264 | release, 265 | }; 266 | 267 | let _resp = client.release_execute(req).await?.into_inner(); 268 | 269 | Ok(()) 270 | } 271 | 272 | pub async fn analyze( 273 | &mut self, 274 | analyze: spark::analyze_plan_request::Analyze, 275 | ) -> Result<&mut Self, SparkError> { 276 | let mut req = self.analyze_plan_request_with_metadata(); 277 | 278 | req.analyze = Some(analyze); 279 | 280 | // clear out any prior responses 281 | self.analyzer = AnalyzeHandler::default(); 282 | 283 | let mut client = self.stub.write().await; 284 | let resp = client.analyze_plan(req).await?.into_inner(); 285 | drop(client); 286 | 287 | self.handle_analyze(resp) 288 | } 289 | 290 | fn validate_tag(&self, tag: &str) -> Result<(), SparkError> { 291 | if tag.contains(',') { 292 | return Err(SparkError::AnalysisException( 293 | "Spark Connect tag can not contain ',' ".to_string(), 294 | )); 295 | }; 296 | 297 | if tag.is_empty() { 298 | return Err(SparkError::AnalysisException( 299 | "Spark Connect tag can not an empty string ".to_string(), 300 | )); 301 | }; 302 | 303 | Ok(()) 304 | } 305 | 306 | pub fn add_tag(&mut self, tag: &str) -> Result<(), SparkError> { 307 | self.validate_tag(tag)?; 308 | self.tags.push(tag.to_string()); 309 | Ok(()) 310 | } 311 | 312 | pub fn remove_tag(&mut self, tag: &str) -> Result<(), SparkError> { 313 | self.validate_tag(tag)?; 314 | self.tags.retain(|t| t != tag); 315 | Ok(()) 316 | } 317 | 318 | pub fn get_tags(&self) -> &Vec { 319 | &self.tags 320 | } 321 | 322 | pub fn clear_tags(&mut self) { 323 | self.tags = vec![]; 324 | } 325 | 326 | pub async fn config_request( 327 | &self, 328 | operation: spark::config_request::Operation, 329 | ) -> Result { 330 | let operation = spark::ConfigRequest { 331 | session_id: self.session_id(), 332 | user_context: self.user_context.clone(), 333 | client_type: self.builder.user_agent.clone(), 334 | operation: Some(operation), 335 | }; 336 | 337 | let mut client = self.stub.write().await; 338 | 339 | let resp = client.config(operation).await?.into_inner(); 340 | 341 | Ok(resp) 342 | } 343 | 344 | pub async fn interrupt_request( 345 | &self, 346 | interrupt_type: spark::interrupt_request::InterruptType, 347 | id_or_tag: Option, 348 | ) -> Result { 349 | let mut req = spark::InterruptRequest { 350 | session_id: self.session_id(), 351 | user_context: self.user_context.clone(), 352 | client_type: self.builder.user_agent.clone(), 353 | interrupt_type: 0, 354 | interrupt: None, 355 | }; 356 | 357 | match interrupt_type { 358 | spark::interrupt_request::InterruptType::All => { 359 | req.interrupt_type = interrupt_type.into(); 360 | } 361 | spark::interrupt_request::InterruptType::Tag => { 362 | let tag = id_or_tag.expect("Tag can not be empty"); 363 | let interrupt = spark::interrupt_request::Interrupt::OperationTag(tag); 364 | req.interrupt_type = interrupt_type.into(); 365 | req.interrupt = Some(interrupt); 366 | } 367 | spark::interrupt_request::InterruptType::OperationId => { 368 | let op_id = id_or_tag.expect("Operation ID can not be empty"); 369 | let interrupt = spark::interrupt_request::Interrupt::OperationId(op_id); 370 | req.interrupt_type = interrupt_type.into(); 371 | req.interrupt = Some(interrupt); 372 | } 373 | spark::interrupt_request::InterruptType::Unspecified => { 374 | return Err(SparkError::AnalysisException( 375 | "Interrupt Type was not specified".to_string(), 376 | )) 377 | } 378 | }; 379 | 380 | let mut client = self.stub.write().await; 381 | 382 | let resp = client.interrupt(req).await?.into_inner(); 383 | 384 | Ok(resp) 385 | } 386 | 387 | fn handle_response(&mut self, resp: spark::ExecutePlanResponse) -> Result<(), SparkError> { 388 | self.validate_session(&resp.session_id)?; 389 | 390 | self.operation_id = Some(resp.operation_id); 391 | self.response_id = Some(resp.response_id); 392 | 393 | if let Some(schema) = &resp.schema { 394 | self.handler.schema = Some(schema.clone()); 395 | } 396 | if let Some(metrics) = &resp.metrics { 397 | self.handler.metrics = Some(metrics.clone()); 398 | } 399 | if let Some(data) = resp.response_type { 400 | match data { 401 | ResponseType::ArrowBatch(res) => { 402 | self.deserialize(res.data.as_slice(), res.row_count)? 403 | } 404 | ResponseType::SqlCommandResult(sql_cmd) => { 405 | self.handler.sql_command_result = Some(sql_cmd.clone()) 406 | } 407 | ResponseType::WriteStreamOperationStartResult(write_stream_op) => { 408 | self.handler.write_stream_operation_start_result = Some(write_stream_op) 409 | } 410 | ResponseType::StreamingQueryCommandResult(stream_qry_cmd) => { 411 | self.handler.streaming_query_command_result = Some(stream_qry_cmd) 412 | } 413 | ResponseType::GetResourcesCommandResult(resource_cmd) => { 414 | self.handler.get_resources_command_result = Some(resource_cmd) 415 | } 416 | ResponseType::StreamingQueryManagerCommandResult(stream_qry_mngr_cmd) => { 417 | self.handler.streaming_query_manager_command_result = Some(stream_qry_mngr_cmd) 418 | } 419 | ResponseType::ResultComplete(_) => self.handler.result_complete = true, 420 | ResponseType::Extension(_) => { 421 | unimplemented!("extension response types are not implemented") 422 | } 423 | } 424 | } 425 | Ok(()) 426 | } 427 | 428 | fn handle_analyze( 429 | &mut self, 430 | resp: spark::AnalyzePlanResponse, 431 | ) -> Result<&mut Self, SparkError> { 432 | self.validate_session(&resp.session_id)?; 433 | if let Some(result) = resp.result { 434 | match result { 435 | spark::analyze_plan_response::Result::Schema(schema) => { 436 | self.analyzer.schema = schema.schema 437 | } 438 | spark::analyze_plan_response::Result::Explain(explain) => { 439 | self.analyzer.explain = Some(explain.explain_string) 440 | } 441 | spark::analyze_plan_response::Result::TreeString(tree_string) => { 442 | self.analyzer.tree_string = Some(tree_string.tree_string) 443 | } 444 | spark::analyze_plan_response::Result::IsLocal(is_local) => { 445 | self.analyzer.is_local = Some(is_local.is_local) 446 | } 447 | spark::analyze_plan_response::Result::IsStreaming(is_streaming) => { 448 | self.analyzer.is_streaming = Some(is_streaming.is_streaming) 449 | } 450 | spark::analyze_plan_response::Result::InputFiles(input_files) => { 451 | self.analyzer.input_files = Some(input_files.files) 452 | } 453 | spark::analyze_plan_response::Result::SparkVersion(spark_version) => { 454 | self.analyzer.spark_version = Some(spark_version.version) 455 | } 456 | spark::analyze_plan_response::Result::DdlParse(ddl_parse) => { 457 | self.analyzer.ddl_parse = ddl_parse.parsed 458 | } 459 | spark::analyze_plan_response::Result::SameSemantics(same_semantics) => { 460 | self.analyzer.same_semantics = Some(same_semantics.result) 461 | } 462 | spark::analyze_plan_response::Result::SemanticHash(semantic_hash) => { 463 | self.analyzer.semantic_hash = Some(semantic_hash.result) 464 | } 465 | spark::analyze_plan_response::Result::Persist(_) => {} 466 | spark::analyze_plan_response::Result::Unpersist(_) => {} 467 | spark::analyze_plan_response::Result::GetStorageLevel(level) => { 468 | self.analyzer.get_storage_level = level.storage_level 469 | } 470 | } 471 | } 472 | 473 | Ok(self) 474 | } 475 | 476 | fn validate_session(&self, session_id: &str) -> Result<(), SparkError> { 477 | if self.builder.session_id.to_string() != session_id { 478 | return Err(SparkError::AnalysisException(format!( 479 | "Received incorrect session identifier for request: {0} != {1}", 480 | self.builder.session_id, session_id 481 | ))); 482 | } 483 | Ok(()) 484 | } 485 | 486 | fn deserialize(&mut self, res: &[u8], row_count: i64) -> Result<(), SparkError> { 487 | let reader = StreamReader::try_new(res, None)?; 488 | for batch in reader { 489 | let record = batch?; 490 | if record.num_rows() != row_count as usize { 491 | return Err(SparkError::ArrowError(ArrowError::IpcError(format!( 492 | "Expected {} rows in arrow batch but got {}", 493 | row_count, 494 | record.num_rows() 495 | )))); 496 | }; 497 | self.handler.batches.push(record); 498 | self.handler.total_count += row_count as isize; 499 | } 500 | Ok(()) 501 | } 502 | 503 | pub async fn execute_command(&mut self, plan: spark::Plan) -> Result<(), SparkError> { 504 | let mut req = self.execute_plan_request_with_metadata(); 505 | 506 | req.plan = Some(plan); 507 | 508 | self.execute_and_fetch(req).await?; 509 | 510 | Ok(()) 511 | } 512 | 513 | pub(crate) async fn execute_command_and_fetch( 514 | &mut self, 515 | plan: spark::Plan, 516 | ) -> Result { 517 | let mut req = self.execute_plan_request_with_metadata(); 518 | 519 | req.plan = Some(plan); 520 | 521 | self.execute_and_fetch(req).await?; 522 | 523 | Ok(self.handler.clone()) 524 | } 525 | 526 | #[allow(clippy::wrong_self_convention)] 527 | pub async fn to_arrow(&mut self, plan: spark::Plan) -> Result { 528 | let mut req = self.execute_plan_request_with_metadata(); 529 | 530 | req.plan = Some(plan); 531 | 532 | self.execute_and_fetch(req).await?; 533 | 534 | Ok(concat_batches( 535 | &self.handler.batches[0].schema(), 536 | &self.handler.batches, 537 | )?) 538 | } 539 | 540 | #[allow(clippy::wrong_self_convention)] 541 | pub(crate) async fn to_first_value(&mut self, plan: spark::Plan) -> Result { 542 | let rows = self.to_arrow(plan).await?; 543 | let col = rows.column(0); 544 | 545 | let data: &arrow::array::StringArray = match col.data_type() { 546 | arrow::datatypes::DataType::Utf8 => col.as_any().downcast_ref().unwrap(), 547 | _ => unimplemented!("only Utf8 data types are currently handled currently."), 548 | }; 549 | 550 | Ok(data.value(0).to_string()) 551 | } 552 | 553 | pub fn schema(&self) -> Result { 554 | self.analyzer 555 | .schema 556 | .to_owned() 557 | .ok_or_else(|| SparkError::AnalysisException("Schema response is empty".to_string())) 558 | } 559 | 560 | pub fn explain(&self) -> Result { 561 | self.analyzer 562 | .explain 563 | .to_owned() 564 | .ok_or_else(|| SparkError::AnalysisException("Explain response is empty".to_string())) 565 | } 566 | 567 | pub fn tree_string(&self) -> Result { 568 | self.analyzer.tree_string.to_owned().ok_or_else(|| { 569 | SparkError::AnalysisException("Tree String response is empty".to_string()) 570 | }) 571 | } 572 | 573 | pub fn is_local(&self) -> Result { 574 | self.analyzer 575 | .is_local 576 | .to_owned() 577 | .ok_or_else(|| SparkError::AnalysisException("Is Local response is empty".to_string())) 578 | } 579 | 580 | pub fn is_streaming(&self) -> Result { 581 | self.analyzer.is_streaming.to_owned().ok_or_else(|| { 582 | SparkError::AnalysisException("Is Streaming response is empty".to_string()) 583 | }) 584 | } 585 | 586 | pub fn input_files(&self) -> Result, SparkError> { 587 | self.analyzer.input_files.to_owned().ok_or_else(|| { 588 | SparkError::AnalysisException("Input Files response is empty".to_string()) 589 | }) 590 | } 591 | 592 | pub fn spark_version(&mut self) -> Result { 593 | self.analyzer.spark_version.to_owned().ok_or_else(|| { 594 | SparkError::AnalysisException("Spark Version resonse is empty".to_string()) 595 | }) 596 | } 597 | 598 | pub fn ddl_parse(&self) -> Result { 599 | self.analyzer 600 | .ddl_parse 601 | .to_owned() 602 | .ok_or_else(|| SparkError::AnalysisException("DDL parse response is empty".to_string())) 603 | } 604 | 605 | pub fn same_semantics(&self) -> Result { 606 | self.analyzer.same_semantics.to_owned().ok_or_else(|| { 607 | SparkError::AnalysisException("Same Semantics response is empty".to_string()) 608 | }) 609 | } 610 | 611 | pub fn semantic_hash(&self) -> Result { 612 | self.analyzer.semantic_hash.to_owned().ok_or_else(|| { 613 | SparkError::AnalysisException("Semantic Hash response is empty".to_string()) 614 | }) 615 | } 616 | 617 | pub fn get_storage_level(&self) -> Result { 618 | self.analyzer.get_storage_level.to_owned().ok_or_else(|| { 619 | SparkError::AnalysisException("Storage Level response is empty".to_string()) 620 | }) 621 | } 622 | } 623 | -------------------------------------------------------------------------------- /crates/connect/src/column.rs: -------------------------------------------------------------------------------- 1 | //! [Column] represents a column in a DataFrame that holds a [spark::Expression] 2 | use std::convert::From; 3 | use std::ops::{Add, BitAnd, BitOr, BitXor, Div, Mul, Neg, Not, Rem, Sub}; 4 | 5 | use crate::spark; 6 | 7 | use crate::functions::invoke_func; 8 | use crate::window::WindowSpec; 9 | 10 | use spark::expression::cast::CastToType; 11 | 12 | /// # A column in a DataFrame. 13 | /// 14 | /// A column holds a specific [spark::Expression] which will be resolved once an action is called. 15 | /// The columns are resolved by the Spark Connect server of the remote session. 16 | /// 17 | /// A column instance can be created by in a similar way as to the Spark API. A column with created 18 | /// with `col("*")` or `col("name.*")` is created as an unresolved star attribute which will select 19 | /// all columns or references in the specified column. 20 | /// 21 | /// ```rust 22 | /// use spark_connect_rs::{SparkSession, SparkSessionBuilder}; 23 | /// 24 | /// let spark: SparkSession = SparkSessionBuilder::remote("sc://127.0.0.1:15002/;user_id=example_rs".to_string()) 25 | /// .build() 26 | /// .await?; 27 | /// 28 | /// // As a &str representing an unresolved column in the dataframe 29 | /// spark.range(None, 1, 1, Some(1)).select(["id"]); 30 | /// 31 | /// // By using the `col` function 32 | /// spark.range(None, 1, 1, Some(1)).select([col("id")]); 33 | /// 34 | /// // By using the `lit` function to return a literal value 35 | /// spark.range(None, 1, 1, Some(1)).select([lit(4.0).alias("num_col")]); 36 | /// ``` 37 | #[derive(Clone, Debug)] 38 | pub struct Column { 39 | /// a [spark::Expression] containing any unresolved value to be leveraged in a [spark::Plan] 40 | pub expression: spark::Expression, 41 | } 42 | 43 | impl Column { 44 | #[allow(clippy::should_implement_trait)] 45 | pub fn from_str(s: &str) -> Self { 46 | Self::from(s) 47 | } 48 | 49 | pub fn from_string(s: String) -> Self { 50 | Self::from(s.as_str()) 51 | } 52 | 53 | /// Returns the column with a new name 54 | /// 55 | /// # Example: 56 | /// ```rust 57 | /// let cols = [ 58 | /// col("name").alias("new_name"), 59 | /// col("age").alias("new_age") 60 | /// ]; 61 | /// 62 | /// df.select(cols); 63 | /// ``` 64 | pub fn alias(self, value: &str) -> Column { 65 | let alias = spark::expression::Alias { 66 | expr: Some(Box::new(self.expression)), 67 | name: vec![value.to_string()], 68 | metadata: None, 69 | }; 70 | 71 | let expression = spark::Expression { 72 | expr_type: Some(spark::expression::ExprType::Alias(Box::new(alias))), 73 | }; 74 | 75 | Column::from(expression) 76 | } 77 | 78 | /// An alias for the function `alias` 79 | pub fn name(self, value: &str) -> Column { 80 | self.alias(value) 81 | } 82 | 83 | /// Returns a sorted expression based on the ascending order of the column 84 | /// 85 | /// # Example: 86 | /// ```rust 87 | /// let df: DataFrame = df.sort([col("id").asc()]); 88 | /// 89 | /// let df: DataFrame = df.sort([asc(col("id"))]); 90 | /// ``` 91 | pub fn asc(self) -> Column { 92 | self.asc_nulls_first() 93 | } 94 | 95 | pub fn asc_nulls_first(self) -> Column { 96 | let asc = spark::expression::SortOrder { 97 | child: Some(Box::new(self.expression)), 98 | direction: 1, 99 | null_ordering: 1, 100 | }; 101 | 102 | let expression = spark::Expression { 103 | expr_type: Some(spark::expression::ExprType::SortOrder(Box::new(asc))), 104 | }; 105 | 106 | Column::from(expression) 107 | } 108 | 109 | pub fn asc_nulls_last(self) -> Column { 110 | let asc = spark::expression::SortOrder { 111 | child: Some(Box::new(self.expression)), 112 | direction: 1, 113 | null_ordering: 2, 114 | }; 115 | 116 | let expression = spark::Expression { 117 | expr_type: Some(spark::expression::ExprType::SortOrder(Box::new(asc))), 118 | }; 119 | 120 | Column::from(expression) 121 | } 122 | 123 | /// Returns a sorted expression based on the ascending order of the column 124 | /// 125 | /// # Example: 126 | /// ```rust 127 | /// let df: DataFrame = df.sort(col("id").desc()); 128 | /// 129 | /// let df: DataFrame = df.sort(desc(col("id"))); 130 | /// ``` 131 | pub fn desc(self) -> Column { 132 | self.desc_nulls_first() 133 | } 134 | 135 | pub fn desc_nulls_first(self) -> Column { 136 | let asc = spark::expression::SortOrder { 137 | child: Some(Box::new(self.expression)), 138 | direction: 2, 139 | null_ordering: 1, 140 | }; 141 | 142 | let expression = spark::Expression { 143 | expr_type: Some(spark::expression::ExprType::SortOrder(Box::new(asc))), 144 | }; 145 | 146 | Column::from(expression) 147 | } 148 | 149 | pub fn desc_nulls_last(self) -> Column { 150 | let asc = spark::expression::SortOrder { 151 | child: Some(Box::new(self.expression)), 152 | direction: 2, 153 | null_ordering: 2, 154 | }; 155 | 156 | let expression = spark::Expression { 157 | expr_type: Some(spark::expression::ExprType::SortOrder(Box::new(asc))), 158 | }; 159 | 160 | Column::from(expression) 161 | } 162 | 163 | pub fn drop_fields(self, field_names: I) -> Column 164 | where 165 | I: IntoIterator>, 166 | { 167 | let mut parent_col = self.expression; 168 | 169 | for field in field_names { 170 | parent_col = spark::Expression { 171 | expr_type: Some(spark::expression::ExprType::UpdateFields(Box::new( 172 | spark::expression::UpdateFields { 173 | struct_expression: Some(Box::new(parent_col)), 174 | field_name: field.as_ref().to_string(), 175 | value_expression: None, 176 | }, 177 | ))), 178 | }; 179 | } 180 | 181 | Column::from(parent_col) 182 | } 183 | 184 | pub fn with_field(self, field_name: &str, col: impl Into) -> Column { 185 | let update_field = spark::Expression { 186 | expr_type: Some(spark::expression::ExprType::UpdateFields(Box::new( 187 | spark::expression::UpdateFields { 188 | struct_expression: Some(Box::new(self.expression)), 189 | field_name: field_name.to_string(), 190 | value_expression: Some(Box::new(col.into().expression)), 191 | }, 192 | ))), 193 | }; 194 | 195 | Column::from(update_field) 196 | } 197 | 198 | pub fn substr(self, start_pos: impl Into, length: impl Into) -> Column { 199 | invoke_func("substr", vec![self, start_pos.into(), length.into()]) 200 | } 201 | 202 | /// Casts the column into the Spark DataType 203 | /// 204 | /// # Arguments: 205 | /// 206 | /// * `to_type` is a string or [crate::types::DataType] of the target type 207 | /// 208 | /// # Example: 209 | /// ```rust 210 | /// use crate::types::DataType; 211 | /// 212 | /// let df = df.select([ 213 | /// col("age").cast("int"), 214 | /// col("name").cast("string") 215 | /// ]) 216 | /// 217 | /// // Using DataTypes 218 | /// let df = df.select([ 219 | /// col("age").cast(DataType::Integer), 220 | /// col("name").cast(DataType::String) 221 | /// ]) 222 | /// ``` 223 | pub fn cast(self, to_type: impl Into) -> Column { 224 | let cast = spark::expression::Cast { 225 | expr: Some(Box::new(self.expression)), 226 | cast_to_type: Some(to_type.into()), 227 | }; 228 | 229 | let expression = spark::Expression { 230 | expr_type: Some(spark::expression::ExprType::Cast(Box::new(cast))), 231 | }; 232 | 233 | Column::from(expression) 234 | } 235 | 236 | /// A boolean expression that is evaluated to `true` if the value of the expression is 237 | /// contained by the evaluated values of the arguments 238 | /// 239 | /// # Arguments: 240 | /// 241 | /// * `cols` a vector of Columns 242 | /// 243 | /// # Example: 244 | /// ```rust 245 | /// df.filter(col("name").isin([lit("Jorge"), lit("Bob")])); 246 | /// ``` 247 | pub fn isin(self, cols: Vec) -> Column { 248 | let mut val = cols.clone(); 249 | 250 | val.insert(0, self); 251 | 252 | invoke_func("in", val) 253 | } 254 | 255 | /// A boolean expression that is evaluated to `true` if the value is in the Column 256 | /// 257 | /// # Arguments: 258 | /// 259 | /// * `cols`: a col reference that is translated into an [spark::Expression] 260 | /// 261 | /// # Example: 262 | /// ```rust 263 | /// df.filter(col("name").contains("ge")); 264 | /// ``` 265 | pub fn contains(self, other: impl Into) -> Column { 266 | invoke_func("contains", vec![self, other.into()]) 267 | } 268 | 269 | /// A filter expression that evaluates if the column startswith a string literal 270 | pub fn startswith(self, other: impl Into) -> Column { 271 | invoke_func("startswith", vec![self, other.into()]) 272 | } 273 | 274 | /// A filter expression that evaluates if the column endswith a string literal 275 | pub fn endswith(self, other: impl Into) -> Column { 276 | invoke_func("endswith", vec![self, other.into()]) 277 | } 278 | 279 | /// A SQL LIKE filter expression that evaluates the column based on a case sensitive match 280 | pub fn like(self, other: impl Into) -> Column { 281 | invoke_func("like", vec![self, other.into()]) 282 | } 283 | 284 | /// A SQL ILIKE filter expression that evaluates the column based on a case insensitive match 285 | pub fn ilike(self, other: impl Into) -> Column { 286 | invoke_func("ilike", vec![self, other.into()]) 287 | } 288 | 289 | /// A SQL RLIKE filter expression that evaluates the column based on a regex match 290 | pub fn rlike(self, other: impl Into) -> Column { 291 | invoke_func("rlike", vec![self, other.into()]) 292 | } 293 | 294 | /// Equality comparion. Cannot overload the '==' and return something other 295 | /// than a bool 296 | pub fn eq(self, other: impl Into) -> Column { 297 | invoke_func("==", vec![self, other.into()]) 298 | } 299 | 300 | /// Logical AND comparion. Cannot overload the '&&' and return something other 301 | /// than a bool 302 | pub fn and(self, other: impl Into) -> Column { 303 | invoke_func("and", vec![self, other.into()]) 304 | } 305 | 306 | /// Logical OR comparion. 307 | pub fn or(self, other: impl Into) -> Column { 308 | invoke_func("or", vec![self, other.into()]) 309 | } 310 | 311 | /// A filter expression that evaluates to true is the expression is null 312 | pub fn is_null(self) -> Column { 313 | invoke_func("isnull", vec![self]) 314 | } 315 | 316 | /// A filter expression that evaluates to true is the expression is NOT null 317 | pub fn is_not_null(self) -> Column { 318 | invoke_func("isnotnull", vec![self]) 319 | } 320 | 321 | pub fn is_nan(self) -> Column { 322 | invoke_func("isNaN", vec![self]) 323 | } 324 | 325 | /// Defines a windowing column 326 | /// # Arguments: 327 | /// 328 | /// * `window`: a [WindowSpec] 329 | /// 330 | /// # Example 331 | /// 332 | /// ``` 333 | /// let window = Window::new() 334 | /// .partition_by([col("name")]) 335 | /// .order_by([col("age")]) 336 | /// .range_between(Window::unbounded_preceding(), Window::current_row()); 337 | /// 338 | /// let df = df.with_column("rank", rank().over(window.clone())) 339 | /// .with_column("min", min("age").over(window)); 340 | /// ``` 341 | pub fn over(self, window: WindowSpec) -> Column { 342 | let window_expr = spark::expression::Window { 343 | window_function: Some(Box::new(self.expression)), 344 | partition_spec: window.partition_spec, 345 | order_spec: window.order_spec, 346 | frame_spec: window.frame_spec, 347 | }; 348 | 349 | let expression = spark::Expression { 350 | expr_type: Some(spark::expression::ExprType::Window(Box::new(window_expr))), 351 | }; 352 | 353 | Column::from(expression) 354 | } 355 | } 356 | 357 | impl From for Column { 358 | /// Used for creating columns from a [spark::Expression] 359 | fn from(expression: spark::Expression) -> Self { 360 | Self { expression } 361 | } 362 | } 363 | 364 | impl From for Column { 365 | /// Used for creating columns from a [spark::Expression] 366 | fn from(expression: spark::expression::Literal) -> Self { 367 | Self::from(spark::Expression { 368 | expr_type: Some(spark::expression::ExprType::Literal(expression)), 369 | }) 370 | } 371 | } 372 | 373 | impl From for Column { 374 | fn from(value: String) -> Self { 375 | Column::from_string(value) 376 | } 377 | } 378 | 379 | impl From<&String> for Column { 380 | fn from(value: &String) -> Self { 381 | Column::from_str(value.as_str()) 382 | } 383 | } 384 | 385 | impl From<&str> for Column { 386 | /// `&str` values containing a `*` will be created as an unresolved star expression 387 | /// Otherwise, the value is created as an unresolved attribute 388 | fn from(value: &str) -> Self { 389 | let expression = match value { 390 | "*" => spark::Expression { 391 | expr_type: Some(spark::expression::ExprType::UnresolvedStar( 392 | spark::expression::UnresolvedStar { 393 | unparsed_target: None, 394 | }, 395 | )), 396 | }, 397 | value if value.ends_with(".*") => spark::Expression { 398 | expr_type: Some(spark::expression::ExprType::UnresolvedStar( 399 | spark::expression::UnresolvedStar { 400 | unparsed_target: Some(value.to_string()), 401 | }, 402 | )), 403 | }, 404 | _ => spark::Expression { 405 | expr_type: Some(spark::expression::ExprType::UnresolvedAttribute( 406 | spark::expression::UnresolvedAttribute { 407 | unparsed_identifier: value.to_string(), 408 | plan_id: None, 409 | }, 410 | )), 411 | }, 412 | }; 413 | 414 | Column::from(expression) 415 | } 416 | } 417 | 418 | impl Add for Column { 419 | type Output = Self; 420 | 421 | fn add(self, other: Self) -> Self { 422 | invoke_func("+", vec![self, other]) 423 | } 424 | } 425 | 426 | impl Neg for Column { 427 | type Output = Self; 428 | 429 | fn neg(self) -> Self { 430 | invoke_func("negative", vec![self]) 431 | } 432 | } 433 | 434 | impl Sub for Column { 435 | type Output = Self; 436 | 437 | fn sub(self, other: Self) -> Self { 438 | invoke_func("-", vec![self, other]) 439 | } 440 | } 441 | 442 | impl Mul for Column { 443 | type Output = Self; 444 | 445 | fn mul(self, other: Self) -> Self { 446 | invoke_func("*", vec![self, other]) 447 | } 448 | } 449 | 450 | impl Div for Column { 451 | type Output = Self; 452 | 453 | fn div(self, other: Self) -> Self { 454 | invoke_func("/", vec![self, other]) 455 | } 456 | } 457 | 458 | impl Rem for Column { 459 | type Output = Self; 460 | 461 | fn rem(self, other: Self) -> Self { 462 | invoke_func("%", vec![self, other]) 463 | } 464 | } 465 | 466 | impl BitOr for Column { 467 | type Output = Self; 468 | 469 | fn bitor(self, other: Self) -> Self { 470 | invoke_func("|", vec![self, other]) 471 | } 472 | } 473 | 474 | impl BitAnd for Column { 475 | type Output = Self; 476 | 477 | fn bitand(self, other: Self) -> Self { 478 | invoke_func("&", vec![self, other]) 479 | } 480 | } 481 | 482 | impl BitXor for Column { 483 | type Output = Self; 484 | 485 | fn bitxor(self, other: Self) -> Self { 486 | invoke_func("^", vec![self, other]) 487 | } 488 | } 489 | 490 | impl Not for Column { 491 | type Output = Self; 492 | 493 | fn not(self) -> Self::Output { 494 | invoke_func("not", vec![self]) 495 | } 496 | } 497 | -------------------------------------------------------------------------------- /crates/connect/src/conf.rs: -------------------------------------------------------------------------------- 1 | //! Configuration for a Spark application. Used to set various Spark parameters as key-value pairs. 2 | 3 | use std::collections::HashMap; 4 | 5 | use crate::spark; 6 | 7 | use crate::client::SparkClient; 8 | use crate::errors::SparkError; 9 | 10 | /// User-facing configuration API, accessible through SparkSession.conf. 11 | pub struct RunTimeConfig { 12 | pub(crate) client: SparkClient, 13 | } 14 | 15 | /// User-facing configuration API, accessible through SparkSession.conf. 16 | /// 17 | /// Options set here are automatically propagated to the Hadoop configuration during I/O. 18 | /// 19 | /// # Example 20 | /// ```rust 21 | /// spark 22 | /// .conf() 23 | /// .set("spark.sql.shuffle.partitions", "42") 24 | /// .await?; 25 | /// ``` 26 | impl RunTimeConfig { 27 | pub fn new(client: &SparkClient) -> RunTimeConfig { 28 | RunTimeConfig { 29 | client: client.clone(), 30 | } 31 | } 32 | 33 | pub(crate) async fn set_configs( 34 | &mut self, 35 | map: &HashMap, 36 | ) -> Result<(), SparkError> { 37 | for (key, value) in map { 38 | self.set(key.as_str(), value.as_str()).await? 39 | } 40 | Ok(()) 41 | } 42 | 43 | /// Sets the given Spark runtime configuration property. 44 | pub async fn set(&mut self, key: &str, value: &str) -> Result<(), SparkError> { 45 | let op_type = spark::config_request::operation::OpType::Set(spark::config_request::Set { 46 | pairs: vec![spark::KeyValue { 47 | key: key.into(), 48 | value: Some(value.into()), 49 | }], 50 | }); 51 | let operation = spark::config_request::Operation { 52 | op_type: Some(op_type), 53 | }; 54 | 55 | let _ = self.client.config_request(operation).await?; 56 | 57 | Ok(()) 58 | } 59 | 60 | /// Resets the configuration property for the given key. 61 | pub async fn unset(&mut self, key: &str) -> Result<(), SparkError> { 62 | let op_type = 63 | spark::config_request::operation::OpType::Unset(spark::config_request::Unset { 64 | keys: vec![key.to_string()], 65 | }); 66 | let operation = spark::config_request::Operation { 67 | op_type: Some(op_type), 68 | }; 69 | 70 | let _ = self.client.config_request(operation).await?; 71 | 72 | Ok(()) 73 | } 74 | 75 | /// Indicates whether the configuration property with the given key is modifiable in the current session. 76 | pub async fn get(&mut self, key: &str, default: Option<&str>) -> Result { 77 | let operation = match default { 78 | Some(default) => { 79 | let op_type = spark::config_request::operation::OpType::GetWithDefault( 80 | spark::config_request::GetWithDefault { 81 | pairs: vec![spark::KeyValue { 82 | key: key.into(), 83 | value: Some(default.into()), 84 | }], 85 | }, 86 | ); 87 | spark::config_request::Operation { 88 | op_type: Some(op_type), 89 | } 90 | } 91 | None => { 92 | let op_type = 93 | spark::config_request::operation::OpType::Get(spark::config_request::Get { 94 | keys: vec![key.to_string()], 95 | }); 96 | spark::config_request::Operation { 97 | op_type: Some(op_type), 98 | } 99 | } 100 | }; 101 | 102 | let resp = self.client.config_request(operation).await?; 103 | 104 | let val = resp.pairs.first().unwrap().value().to_string(); 105 | 106 | Ok(val) 107 | } 108 | 109 | /// Indicates whether the configuration property with the given key is modifiable in the current session. 110 | pub async fn is_modifable(&mut self, key: &str) -> Result { 111 | let op_type = spark::config_request::operation::OpType::IsModifiable( 112 | spark::config_request::IsModifiable { 113 | keys: vec![key.to_string()], 114 | }, 115 | ); 116 | let operation = spark::config_request::Operation { 117 | op_type: Some(op_type), 118 | }; 119 | 120 | let resp = self.client.config_request(operation).await?; 121 | 122 | let val = resp.pairs.first().unwrap().value(); 123 | 124 | match val { 125 | "true" => Ok(true), 126 | "false" => Ok(false), 127 | _ => Err(SparkError::AnalysisException( 128 | "Unexpected response value for boolean".to_string(), 129 | )), 130 | } 131 | } 132 | } 133 | -------------------------------------------------------------------------------- /crates/connect/src/errors.rs: -------------------------------------------------------------------------------- 1 | //! Defines a [SparkError] for representing failures in various Spark operations. 2 | //! Most of these are wrappers for tonic or arrow error messages 3 | use std::error::Error; 4 | use std::fmt::Debug; 5 | use std::io::Write; 6 | 7 | use arrow::error::ArrowError; 8 | use thiserror::Error; 9 | 10 | use tonic::Code; 11 | 12 | #[cfg(feature = "datafusion")] 13 | use datafusion::error::DataFusionError; 14 | #[cfg(feature = "polars")] 15 | use polars::error::PolarsError; 16 | 17 | /// Different `Spark` Error types 18 | #[derive(Error, Debug)] 19 | pub enum SparkError { 20 | #[error("Aborted: {0}")] 21 | Aborted(String), 22 | 23 | #[error("Already Exists: {0}")] 24 | AlreadyExists(String), 25 | 26 | #[error("Analysis Exception: {0}")] 27 | AnalysisException(String), 28 | 29 | #[error("Apache Arrow Error: {0}")] 30 | ArrowError(#[from] ArrowError), 31 | 32 | #[error("Cancelled: {0}")] 33 | Cancelled(String), 34 | 35 | #[error("Data Loss Exception: {0}")] 36 | DataLoss(String), 37 | 38 | #[error("Deadline Exceeded: {0}")] 39 | DeadlineExceeded(String), 40 | 41 | #[error("External Error: {0}")] 42 | ExternalError(Box), 43 | 44 | #[error("Failed Precondition: {0}")] 45 | FailedPrecondition(String), 46 | 47 | #[error("Invalid Connection Url: {0}")] 48 | InvalidConnectionUrl(String), 49 | 50 | #[error("Invalid Argument: {0}")] 51 | InvalidArgument(String), 52 | 53 | #[error("Io Error: {0}")] 54 | IoError(String, std::io::Error), 55 | 56 | #[error("Not Found: {0}")] 57 | NotFound(String), 58 | 59 | #[error("Not Yet Implemented: {0}")] 60 | NotYetImplemented(String), 61 | 62 | #[error("Permission Denied: {0}")] 63 | PermissionDenied(String), 64 | 65 | #[error("Resource Exhausted: {0}")] 66 | ResourceExhausted(String), 67 | 68 | #[error("Spark Session ID is not the same: {0}")] 69 | SessionNotSameException(String), 70 | 71 | #[error("Unauthenticated: {0}")] 72 | Unauthenticated(String), 73 | 74 | #[error("Unavailable: {0}")] 75 | Unavailable(String), 76 | 77 | #[error("Unkown: {0}")] 78 | Unknown(String), 79 | 80 | #[error("Unimplemented; {0}")] 81 | Unimplemented(String), 82 | 83 | #[error("Invalid UUID")] 84 | Uuid(#[from] uuid::Error), 85 | 86 | #[error("Out of Range: {0}")] 87 | OutOfRange(String), 88 | } 89 | 90 | impl SparkError { 91 | /// Wraps an external error in an `SparkError`. 92 | pub fn from_external_error(error: Box) -> Self { 93 | Self::ExternalError(error) 94 | } 95 | } 96 | 97 | impl From for SparkError { 98 | fn from(error: std::io::Error) -> Self { 99 | SparkError::IoError(error.to_string(), error) 100 | } 101 | } 102 | 103 | impl From for SparkError { 104 | fn from(error: std::str::Utf8Error) -> Self { 105 | SparkError::AnalysisException(error.to_string()) 106 | } 107 | } 108 | 109 | impl From for SparkError { 110 | fn from(error: std::string::FromUtf8Error) -> Self { 111 | SparkError::AnalysisException(error.to_string()) 112 | } 113 | } 114 | 115 | impl From for SparkError { 116 | fn from(status: tonic::Status) -> Self { 117 | match status.code() { 118 | Code::Ok => SparkError::AnalysisException(status.message().to_string()), 119 | Code::Unknown => SparkError::Unknown(status.message().to_string()), 120 | Code::Aborted => SparkError::Aborted(status.message().to_string()), 121 | Code::NotFound => SparkError::NotFound(status.message().to_string()), 122 | Code::Internal => SparkError::AnalysisException(status.message().to_string()), 123 | Code::DataLoss => SparkError::DataLoss(status.message().to_string()), 124 | Code::Cancelled => SparkError::Cancelled(status.message().to_string()), 125 | Code::OutOfRange => SparkError::OutOfRange(status.message().to_string()), 126 | Code::Unavailable => SparkError::Unavailable(status.message().to_string()), 127 | Code::AlreadyExists => SparkError::AnalysisException(status.message().to_string()), 128 | Code::InvalidArgument => SparkError::InvalidArgument(status.message().to_string()), 129 | Code::DeadlineExceeded => SparkError::DeadlineExceeded(status.message().to_string()), 130 | Code::Unimplemented => SparkError::Unimplemented(status.message().to_string()), 131 | Code::Unauthenticated => SparkError::Unauthenticated(status.message().to_string()), 132 | Code::PermissionDenied => SparkError::PermissionDenied(status.message().to_string()), 133 | Code::ResourceExhausted => SparkError::ResourceExhausted(status.message().to_string()), 134 | Code::FailedPrecondition => { 135 | SparkError::FailedPrecondition(status.message().to_string()) 136 | } 137 | } 138 | } 139 | } 140 | 141 | impl From for SparkError { 142 | fn from(value: serde_json::Error) -> Self { 143 | SparkError::AnalysisException(value.to_string()) 144 | } 145 | } 146 | 147 | #[cfg(feature = "datafusion")] 148 | impl From for SparkError { 149 | fn from(_value: DataFusionError) -> Self { 150 | SparkError::AnalysisException("Error converting to DataFusion DataFrame".to_string()) 151 | } 152 | } 153 | 154 | #[cfg(feature = "polars")] 155 | impl From for SparkError { 156 | fn from(_value: PolarsError) -> Self { 157 | SparkError::AnalysisException("Error converting to Polars DataFrame".to_string()) 158 | } 159 | } 160 | 161 | impl From for SparkError { 162 | fn from(value: tonic::codegen::http::uri::InvalidUri) -> Self { 163 | SparkError::InvalidConnectionUrl(value.to_string()) 164 | } 165 | } 166 | 167 | impl From for SparkError { 168 | fn from(value: tonic::transport::Error) -> Self { 169 | SparkError::InvalidConnectionUrl(value.to_string()) 170 | } 171 | } 172 | 173 | impl From> for SparkError { 174 | fn from(error: std::io::IntoInnerError) -> Self { 175 | SparkError::IoError(error.to_string(), error.into()) 176 | } 177 | } 178 | -------------------------------------------------------------------------------- /crates/connect/src/expressions.rs: -------------------------------------------------------------------------------- 1 | //! Traits for converting Rust Types to Spark Connect Expression Types 2 | //! 3 | //! Spark Connect has a few different ways of creating expressions and different gRPC methods 4 | //! require expressions in different forms. These traits are used to either translate a value into 5 | //! a [spark::Expression] or into a [spark::expression::Literal]. 6 | 7 | use chrono::NaiveDateTime; 8 | 9 | use crate::spark; 10 | 11 | use crate::column::Column; 12 | use crate::types::DataType; 13 | 14 | pub struct VecExpression { 15 | pub(super) expr: Vec, 16 | } 17 | 18 | impl FromIterator for VecExpression 19 | where 20 | T: Into, 21 | { 22 | fn from_iter>(iter: I) -> Self { 23 | let expr = iter 24 | .into_iter() 25 | .map(Into::into) 26 | .map(|col| col.expression) 27 | .collect(); 28 | 29 | VecExpression { expr } 30 | } 31 | } 32 | 33 | impl From for Vec { 34 | fn from(value: VecExpression) -> Self { 35 | value.expr 36 | } 37 | } 38 | 39 | impl<'a> From<&'a str> for VecExpression { 40 | fn from(value: &'a str) -> Self { 41 | VecExpression { 42 | expr: vec![Column::from_str(value).expression], 43 | } 44 | } 45 | } 46 | 47 | impl From for VecExpression { 48 | fn from(value: String) -> Self { 49 | VecExpression { 50 | expr: vec![Column::from_string(value).expression], 51 | } 52 | } 53 | } 54 | 55 | impl From for spark::Expression { 56 | fn from(value: String) -> Self { 57 | Column::from(value).expression 58 | } 59 | } 60 | 61 | impl<'a> From<&'a str> for spark::Expression { 62 | fn from(value: &'a str) -> Self { 63 | Column::from(value).expression 64 | } 65 | } 66 | 67 | impl From for spark::Expression { 68 | fn from(value: Column) -> Self { 69 | value.expression 70 | } 71 | } 72 | 73 | /// Create a filter expression 74 | pub trait ToFilterExpr { 75 | fn to_filter_expr(&self) -> Option; 76 | } 77 | 78 | impl ToFilterExpr for Column { 79 | fn to_filter_expr(&self) -> Option { 80 | Some(self.expression.clone()) 81 | } 82 | } 83 | 84 | impl ToFilterExpr for &str { 85 | fn to_filter_expr(&self) -> Option { 86 | let expr_type = Some(spark::expression::ExprType::ExpressionString( 87 | spark::expression::ExpressionString { 88 | expression: self.to_string(), 89 | }, 90 | )); 91 | 92 | Some(spark::Expression { expr_type }) 93 | } 94 | } 95 | 96 | /// Translate a rust value into a literal type 97 | pub trait ToLiteral { 98 | fn to_literal(&self) -> spark::expression::Literal; 99 | } 100 | 101 | macro_rules! impl_to_literal { 102 | ($type:ty, $inner_type:ident) => { 103 | impl From<$type> for spark::expression::Literal { 104 | fn from(value: $type) -> spark::expression::Literal { 105 | spark::expression::Literal { 106 | literal_type: Some(spark::expression::literal::LiteralType::$inner_type(value)), 107 | } 108 | } 109 | } 110 | }; 111 | } 112 | 113 | impl_to_literal!(bool, Boolean); 114 | impl_to_literal!(i32, Integer); 115 | impl_to_literal!(i64, Long); 116 | impl_to_literal!(f32, Float); 117 | impl_to_literal!(f64, Double); 118 | impl_to_literal!(String, String); 119 | 120 | impl From<&[u8]> for spark::expression::Literal { 121 | fn from(value: &[u8]) -> Self { 122 | spark::expression::Literal { 123 | literal_type: Some(spark::expression::literal::LiteralType::Binary(Vec::from( 124 | value, 125 | ))), 126 | } 127 | } 128 | } 129 | 130 | impl From for spark::expression::Literal { 131 | fn from(value: i16) -> Self { 132 | spark::expression::Literal { 133 | literal_type: Some(spark::expression::literal::LiteralType::Short(value as i32)), 134 | } 135 | } 136 | } 137 | 138 | impl<'a> From<&'a str> for spark::expression::Literal { 139 | fn from(value: &'a str) -> Self { 140 | spark::expression::Literal { 141 | literal_type: Some(spark::expression::literal::LiteralType::String( 142 | value.to_string(), 143 | )), 144 | } 145 | } 146 | } 147 | 148 | impl From> for spark::expression::Literal { 149 | fn from(value: chrono::DateTime) -> Self { 150 | // timestamps for spark have to be the microsends since 1/1/1970 151 | let timestamp = value.timestamp_micros(); 152 | 153 | spark::expression::Literal { 154 | literal_type: Some(spark::expression::literal::LiteralType::Timestamp( 155 | timestamp, 156 | )), 157 | } 158 | } 159 | } 160 | 161 | impl From for spark::expression::Literal { 162 | fn from(value: NaiveDateTime) -> Self { 163 | // timestamps for spark have to be the microsends since 1/1/1970 164 | let timestamp = value.and_utc().timestamp_micros(); 165 | 166 | spark::expression::Literal { 167 | literal_type: Some(spark::expression::literal::LiteralType::TimestampNtz( 168 | timestamp, 169 | )), 170 | } 171 | } 172 | } 173 | 174 | impl From for spark::expression::Literal { 175 | fn from(value: chrono::NaiveDate) -> Self { 176 | // Spark works based on unix time. I.e. seconds since 1/1/1970 177 | // to get dates to work you have to do this math 178 | let days_since_unix_epoch = 179 | value.signed_duration_since(chrono::NaiveDate::from_ymd_opt(1970, 1, 1).unwrap()); 180 | 181 | spark::expression::Literal { 182 | literal_type: Some(spark::expression::literal::LiteralType::Date( 183 | days_since_unix_epoch.num_days() as i32, 184 | )), 185 | } 186 | } 187 | } 188 | 189 | impl From> for spark::expression::Literal 190 | where 191 | T: Into + Clone, 192 | spark::DataType: From, 193 | { 194 | fn from(value: Vec) -> Self { 195 | let element_type = Some(spark::DataType::from( 196 | value.first().expect("Array can not be empty").clone(), 197 | )); 198 | 199 | let elements = value.iter().map(|val| val.clone().into()).collect(); 200 | 201 | let array_type = spark::expression::literal::Array { 202 | element_type, 203 | elements, 204 | }; 205 | 206 | spark::expression::Literal { 207 | literal_type: Some(spark::expression::literal::LiteralType::Array(array_type)), 208 | } 209 | } 210 | } 211 | 212 | impl From<[T; N]> for spark::expression::Literal 213 | where 214 | T: Into + Clone, 215 | spark::DataType: From, 216 | { 217 | fn from(value: [T; N]) -> Self { 218 | let element_type = Some(spark::DataType::from( 219 | value.first().expect("Array can not be empty").clone(), 220 | )); 221 | 222 | let elements = value.iter().map(|val| val.clone().into()).collect(); 223 | 224 | let array_type = spark::expression::literal::Array { 225 | element_type, 226 | elements, 227 | }; 228 | 229 | spark::expression::Literal { 230 | literal_type: Some(spark::expression::literal::LiteralType::Array(array_type)), 231 | } 232 | } 233 | } 234 | 235 | impl From<&str> for spark::expression::cast::CastToType { 236 | fn from(value: &str) -> Self { 237 | spark::expression::cast::CastToType::TypeStr(value.to_string()) 238 | } 239 | } 240 | 241 | impl From for spark::expression::cast::CastToType { 242 | fn from(value: String) -> Self { 243 | spark::expression::cast::CastToType::TypeStr(value) 244 | } 245 | } 246 | 247 | impl From for spark::expression::cast::CastToType { 248 | fn from(value: DataType) -> spark::expression::cast::CastToType { 249 | spark::expression::cast::CastToType::Type(value.into()) 250 | } 251 | } 252 | -------------------------------------------------------------------------------- /crates/connect/src/group.rs: -------------------------------------------------------------------------------- 1 | //! A DataFrame created with an aggregate statement 2 | 3 | use crate::column::Column; 4 | use crate::dataframe::DataFrame; 5 | use crate::plan::LogicalPlanBuilder; 6 | 7 | use crate::functions::{invoke_func, lit}; 8 | 9 | use crate::spark; 10 | use crate::spark::aggregate::GroupType; 11 | 12 | /// A set of methods for aggregations on a [DataFrame], created by DataFrame.groupBy(). 13 | #[derive(Clone, Debug)] 14 | pub struct GroupedData { 15 | df: DataFrame, 16 | group_type: GroupType, 17 | grouping_cols: Vec, 18 | pivot_col: Option, 19 | pivot_vals: Option>, 20 | } 21 | 22 | impl GroupedData { 23 | pub fn new( 24 | df: DataFrame, 25 | group_type: GroupType, 26 | grouping_cols: Vec, 27 | pivot_col: Option, 28 | pivot_vals: Option>, 29 | ) -> GroupedData { 30 | Self { 31 | df, 32 | group_type, 33 | grouping_cols, 34 | pivot_col, 35 | pivot_vals, 36 | } 37 | } 38 | 39 | /// Compute aggregates and returns the result as a [DataFrame] 40 | pub fn agg(self, exprs: I) -> DataFrame 41 | where 42 | I: IntoIterator, 43 | S: Into, 44 | { 45 | let plan = LogicalPlanBuilder::aggregate( 46 | self.df.plan, 47 | self.group_type, 48 | self.grouping_cols, 49 | exprs, 50 | self.pivot_col, 51 | self.pivot_vals, 52 | ); 53 | 54 | DataFrame { 55 | spark_session: self.df.spark_session, 56 | plan, 57 | } 58 | } 59 | 60 | /// Computes average values for each numeric columns for each group. 61 | pub fn avg(self, cols: I) -> DataFrame 62 | where 63 | I: IntoIterator, 64 | S: Into, 65 | { 66 | self.agg([invoke_func("avg", cols)]) 67 | } 68 | 69 | /// Computes the min value for each numeric column for each group. 70 | pub fn min(self, cols: I) -> DataFrame 71 | where 72 | I: IntoIterator, 73 | S: Into, 74 | { 75 | self.agg([invoke_func("min", cols)]) 76 | } 77 | 78 | /// Computes the max value for each numeric columns for each group. 79 | pub fn max(self, cols: I) -> DataFrame 80 | where 81 | I: IntoIterator, 82 | S: Into, 83 | { 84 | self.agg([invoke_func("max", cols)]) 85 | } 86 | 87 | /// Computes the sum for each numeric columns for each group. 88 | pub fn sum(self, cols: I) -> DataFrame 89 | where 90 | I: IntoIterator, 91 | S: Into, 92 | { 93 | self.agg([invoke_func("sum", cols)]) 94 | } 95 | 96 | /// Counts the number of records for each group. 97 | pub fn count(self) -> DataFrame { 98 | self.agg([invoke_func("count", [lit(1).alias("count")])]) 99 | } 100 | 101 | /// Pivots a column of the current [DataFrame] and perform the specified aggregation 102 | pub fn pivot(self, col: &str, values: Option>) -> GroupedData { 103 | let pivot_vals = values.map(|vals| vals.iter().map(|val| val.to_string().into()).collect()); 104 | 105 | GroupedData::new( 106 | self.df, 107 | GroupType::Pivot, 108 | self.grouping_cols, 109 | Some(Column::from(col).into()), 110 | pivot_vals, 111 | ) 112 | } 113 | } 114 | 115 | #[cfg(test)] 116 | mod tests { 117 | 118 | use arrow::array::{ArrayRef, Int64Array, StringArray}; 119 | use arrow::datatypes::{DataType, Field, Schema}; 120 | use arrow::record_batch::RecordBatch; 121 | use std::sync::Arc; 122 | 123 | use crate::errors::SparkError; 124 | use crate::SparkSession; 125 | use crate::SparkSessionBuilder; 126 | 127 | use crate::functions::col; 128 | 129 | use crate::column::Column; 130 | 131 | async fn setup() -> SparkSession { 132 | println!("SparkSession Setup"); 133 | 134 | let connection = 135 | "sc://127.0.0.1:15002/;user_id=rust_group;session_id=02c25694-e875-4a25-9955-bc5bc56c4ade"; 136 | 137 | SparkSessionBuilder::remote(connection) 138 | .build() 139 | .await 140 | .unwrap() 141 | } 142 | 143 | #[tokio::test] 144 | async fn test_group_count() -> Result<(), SparkError> { 145 | let spark = setup().await; 146 | 147 | let df = spark.range(None, 100, 1, Some(8)); 148 | 149 | let res = df.group_by::>(None).count().collect().await?; 150 | 151 | let a: ArrayRef = Arc::new(Int64Array::from(vec![100])); 152 | 153 | let expected = RecordBatch::try_from_iter(vec![("count(1 AS count)", a)])?; 154 | 155 | assert_eq!(expected, res); 156 | Ok(()) 157 | } 158 | 159 | #[tokio::test] 160 | async fn test_group_pivot() -> Result<(), SparkError> { 161 | let spark = setup().await; 162 | 163 | let course: ArrayRef = Arc::new(StringArray::from(vec![ 164 | "dotNET", "Java", "dotNET", "dotNET", "Java", 165 | ])); 166 | let year: ArrayRef = Arc::new(Int64Array::from(vec![2012, 2012, 2012, 2013, 2013])); 167 | let earnings: ArrayRef = Arc::new(Int64Array::from(vec![10000, 20000, 5000, 48000, 30000])); 168 | 169 | let data = RecordBatch::try_from_iter(vec![ 170 | ("course", course), 171 | ("year", year), 172 | ("earnings", earnings), 173 | ])?; 174 | 175 | let df = spark.create_dataframe(&data)?; 176 | 177 | let res = df 178 | .clone() 179 | .group_by(Some([col("year")])) 180 | .pivot("course", Some(vec!["Java"])) 181 | .sum(["earnings"]) 182 | .collect() 183 | .await?; 184 | 185 | let year: ArrayRef = Arc::new(Int64Array::from(vec![2012, 2013])); 186 | let earnings: ArrayRef = Arc::new(Int64Array::from(vec![20000, 30000])); 187 | 188 | let schema = Schema::new(vec![ 189 | Field::new("year", DataType::Int64, false), 190 | Field::new("Java", DataType::Int64, true), 191 | ]); 192 | 193 | let expected = RecordBatch::try_new(Arc::new(schema), vec![year, earnings])?; 194 | 195 | assert_eq!(expected, res); 196 | 197 | let res = df 198 | .group_by(Some([col("year")])) 199 | .pivot("course", None) 200 | .sum(["earnings"]) 201 | .collect() 202 | .await?; 203 | 204 | let year: ArrayRef = Arc::new(Int64Array::from(vec![2012, 2013])); 205 | let java_earnings: ArrayRef = Arc::new(Int64Array::from(vec![20000, 30000])); 206 | let dnet_earnings: ArrayRef = Arc::new(Int64Array::from(vec![15000, 48000])); 207 | 208 | let schema = Schema::new(vec![ 209 | Field::new("year", DataType::Int64, false), 210 | Field::new("Java", DataType::Int64, true), 211 | Field::new("dotNET", DataType::Int64, true), 212 | ]); 213 | 214 | let expected = 215 | RecordBatch::try_new(Arc::new(schema), vec![year, java_earnings, dnet_earnings])?; 216 | 217 | assert_eq!(expected, res); 218 | 219 | Ok(()) 220 | } 221 | } 222 | -------------------------------------------------------------------------------- /crates/connect/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! Spark Connection Client for Rust 2 | //! 3 | //! Currently, the Spark Connect client for Rust is **highly experimental** and **should 4 | //! not be used in any production setting**. This is currently a "proof of concept" to identify the methods 5 | //! of interacting with Spark cluster from rust. 6 | //! 7 | //! # Quickstart 8 | //! 9 | //! Create a Spark Session and create a [DataFrame] from a [arrow::array::RecordBatch]. 10 | //! 11 | //! ```rust 12 | //! use spark_connect_rs::{SparkSession, SparkSessionBuilder}; 13 | //! use spark_connect_rs::functions::{col, lit} 14 | //! 15 | //! #[tokio::main] 16 | //! async fn main() -> Result<(), Box> { 17 | //! 18 | //! let spark: SparkSession = SparkSessionBuilder::remote("sc://127.0.0.1:15002/;user_id=example_rs") 19 | //! .build() 20 | //! .await?; 21 | //! 22 | //! let name: ArrayRef = Arc::new(StringArray::from(vec!["Tom", "Alice", "Bob"])); 23 | //! let age: ArrayRef = Arc::new(Int64Array::from(vec![14, 23, 16])); 24 | //! 25 | //! let data = RecordBatch::try_from_iter(vec![("name", name), ("age", age)])? 26 | //! 27 | //! let df = spark.create_dataframe(&data).await? 28 | //! 29 | //! // 2 records total 30 | //! let records = df.select(["*"]) 31 | //! .with_column("age_plus", col("age") + lit(4)) 32 | //! .filter(col("name").contains("o")) 33 | //! .count() 34 | //! .await?; 35 | //! 36 | //! Ok(()) 37 | //! }; 38 | //!``` 39 | //! 40 | //! Create a Spark Session and create a DataFrame from a SQL statement: 41 | //! 42 | //! ```rust 43 | //! use spark_connect_rs::{SparkSession, SparkSessionBuilder}; 44 | //! 45 | //! #[tokio::main] 46 | //! async fn main() -> Result<(), Box> { 47 | //! 48 | //! let spark: SparkSession = SparkSessionBuilder::remote("sc://127.0.0.1:15002/;user_id=example_rs") 49 | //! .build() 50 | //! .await?; 51 | //! 52 | //! let df = spark.sql("SELECT * FROM json.`/datasets/employees.json`").await?; 53 | //! 54 | //! // Show the first 5 records 55 | //! df.filter("salary > 3000").show(Some(5), None, None).await?; 56 | //! 57 | //! Ok(()) 58 | //! }; 59 | //!``` 60 | //! 61 | //! Create a Spark Session, read a CSV file into a DataFrame, apply function transformations, and write the results: 62 | //! 63 | //! ```rust 64 | //! use spark_connect_rs::{SparkSession, SparkSessionBuilder}; 65 | //! 66 | //! use spark_connect_rs::functions as F; 67 | //! 68 | //! #[tokio::main] 69 | //! async fn main() -> Result<(), Box> { 70 | //! 71 | //! let spark: SparkSession = SparkSessionBuilder::remote("sc://127.0.0.1:15002/;user_id=example_rs") 72 | //! .build() 73 | //! .await?; 74 | //! 75 | //! let paths = ["/datasets/people.csv"]; 76 | //! 77 | //! let df = spark 78 | //! .read() 79 | //! .format("csv") 80 | //! .option("header", "True") 81 | //! .option("delimiter", ";") 82 | //! .load(paths)?; 83 | //! 84 | //! let df = df 85 | //! .filter("age > 30") 86 | //! .select([ 87 | //! F::col("name"), 88 | //! F::col("age").cast("int") 89 | //! ]); 90 | //! 91 | //! df.write() 92 | //! .format("csv") 93 | //! .option("header", "true") 94 | //! .save("/opt/spark/examples/src/main/rust/people/") 95 | //! .await?; 96 | //! 97 | //! Ok(()) 98 | //! }; 99 | //!``` 100 | //! 101 | //! ## Databricks Connection 102 | //! 103 | //! Spark Connect is enabled for Databricks Runtime 13.3 LTS and above, and requires the feature 104 | //! flag `feature = "tls"`. The connection string for the remote session must contain the following 105 | //! values in the string; 106 | //! 107 | //! ```bash 108 | //! "sc://:443/;token=;x-databricks-cluster-id=" 109 | //! ``` 110 | //! 111 | //! 112 | 113 | /// Spark Connect gRPC protobuf translated using [tonic] 114 | pub mod spark { 115 | tonic::include_proto!("spark.connect"); 116 | } 117 | 118 | pub mod catalog; 119 | pub mod client; 120 | pub mod column; 121 | pub mod conf; 122 | pub mod dataframe; 123 | pub mod errors; 124 | pub mod expressions; 125 | pub mod functions; 126 | pub mod group; 127 | pub mod plan; 128 | pub mod readwriter; 129 | pub mod session; 130 | pub mod storage; 131 | pub mod streaming; 132 | pub mod types; 133 | pub mod window; 134 | 135 | pub use dataframe::{DataFrame, DataFrameReader, DataFrameWriter}; 136 | pub use session::{SparkSession, SparkSessionBuilder}; 137 | -------------------------------------------------------------------------------- /crates/connect/src/session.rs: -------------------------------------------------------------------------------- 1 | //! Spark Session containing the remote gRPC client 2 | 3 | use std::collections::HashMap; 4 | use std::sync::Arc; 5 | 6 | use crate::client::{ChannelBuilder, Config, HeadersLayer, SparkClient, SparkConnectClient}; 7 | 8 | use crate::catalog::Catalog; 9 | use crate::conf::RunTimeConfig; 10 | use crate::dataframe::{DataFrame, DataFrameReader}; 11 | use crate::errors::SparkError; 12 | use crate::plan::LogicalPlanBuilder; 13 | use crate::streaming::{DataStreamReader, StreamingQueryManager}; 14 | 15 | use crate::spark; 16 | use spark::spark_connect_service_client::SparkConnectServiceClient; 17 | 18 | use arrow::record_batch::RecordBatch; 19 | 20 | use tokio::sync::RwLock; 21 | 22 | use tower::ServiceBuilder; 23 | 24 | use tonic::transport::Channel; 25 | 26 | /// SparkSessionBuilder creates a remote Spark Session a connection string. 27 | /// 28 | /// The connection string is define based on the requirements from [Spark Documentation](https://github.com/apache/spark/blob/master/connector/connect/docs/client-connection-string.md) 29 | #[derive(Clone, Debug)] 30 | pub struct SparkSessionBuilder { 31 | pub channel_builder: ChannelBuilder, 32 | configs: HashMap, 33 | } 34 | 35 | /// Default connects a Spark cluster running at `sc://127.0.0.1:15002/` 36 | impl Default for SparkSessionBuilder { 37 | fn default() -> Self { 38 | let channel_builder = ChannelBuilder::default(); 39 | 40 | Self { 41 | channel_builder, 42 | configs: HashMap::new(), 43 | } 44 | } 45 | } 46 | 47 | impl SparkSessionBuilder { 48 | fn new(connection: &str) -> Self { 49 | let channel_builder = ChannelBuilder::create(connection).unwrap(); 50 | 51 | Self { 52 | channel_builder, 53 | configs: HashMap::new(), 54 | } 55 | } 56 | 57 | /// Create a new Spark Session from a [Config] object 58 | pub fn from_config(config: Config) -> Self { 59 | Self { 60 | channel_builder: config.into(), 61 | configs: HashMap::new(), 62 | } 63 | } 64 | 65 | /// Validate a connect string for a remote Spark Session 66 | /// 67 | /// String must conform to the [Spark Documentation](https://github.com/apache/spark/blob/master/connector/connect/docs/client-connection-string.md) 68 | pub fn remote(connection: &str) -> Self { 69 | Self::new(connection) 70 | } 71 | 72 | /// Sets a config option. 73 | pub fn config(mut self, key: &str, value: &str) -> Self { 74 | self.configs.insert(key.into(), value.into()); 75 | self 76 | } 77 | 78 | /// Sets a name for the application, which will be shown in the Spark web UI. 79 | pub fn app_name(mut self, name: &str) -> Self { 80 | self.configs 81 | .insert("spark.app.name".to_string(), name.into()); 82 | self 83 | } 84 | 85 | async fn create_client(&self) -> Result { 86 | let channel = Channel::from_shared(self.channel_builder.endpoint())? 87 | .connect() 88 | .await?; 89 | 90 | let channel = ServiceBuilder::new() 91 | .layer(HeadersLayer::new( 92 | self.channel_builder.headers().unwrap_or_default(), 93 | )) 94 | .service(channel); 95 | 96 | let client = SparkConnectServiceClient::new(channel); 97 | 98 | let spark_connnect_client = 99 | SparkConnectClient::new(Arc::new(RwLock::new(client)), self.channel_builder.clone()); 100 | 101 | let mut rt_config = RunTimeConfig::new(&spark_connnect_client); 102 | 103 | rt_config.set_configs(&self.configs).await?; 104 | 105 | Ok(SparkSession::new(spark_connnect_client)) 106 | } 107 | 108 | /// Attempt to connect to a remote Spark Session 109 | /// 110 | /// and return a [SparkSession] 111 | pub async fn build(&self) -> Result { 112 | self.create_client().await 113 | } 114 | } 115 | 116 | /// The entry point to connecting to a Spark Cluster 117 | /// using the Spark Connection gRPC protocol. 118 | #[derive(Clone, Debug)] 119 | pub struct SparkSession { 120 | client: SparkClient, 121 | session_id: String, 122 | } 123 | 124 | impl SparkSession { 125 | pub fn new(client: SparkClient) -> Self { 126 | Self { 127 | session_id: client.session_id(), 128 | client, 129 | } 130 | } 131 | 132 | pub fn session(&self) -> SparkSession { 133 | self.clone() 134 | } 135 | 136 | /// Create a [DataFrame] with a spingle column named `id`, 137 | /// containing elements in a range from `start` (default 0) to 138 | /// `end` (exclusive) with a step value `step`, and control the number 139 | /// of partitions with `num_partitions` 140 | pub fn range( 141 | &self, 142 | start: Option, 143 | end: i64, 144 | step: i64, 145 | num_partitions: Option, 146 | ) -> DataFrame { 147 | let range_relation = spark::relation::RelType::Range(spark::Range { 148 | start, 149 | end, 150 | step, 151 | num_partitions, 152 | }); 153 | 154 | DataFrame::new(self.session(), LogicalPlanBuilder::from(range_relation)) 155 | } 156 | 157 | /// Returns a [DataFrameReader] that can be used to read datra in as a [DataFrame] 158 | pub fn read(&self) -> DataFrameReader { 159 | DataFrameReader::new(self.session()) 160 | } 161 | 162 | /// Returns a [DataFrameReader] that can be used to read datra in as a [DataFrame] 163 | pub fn read_stream(&self) -> DataStreamReader { 164 | DataStreamReader::new(self.session()) 165 | } 166 | 167 | pub fn table(&self, name: &str) -> Result { 168 | DataFrameReader::new(self.session()).table(name, None) 169 | } 170 | 171 | /// Interface through which the user may create, drop, alter or query underlying databases, 172 | /// tables, functions, etc. 173 | pub fn catalog(&self) -> Catalog { 174 | Catalog::new(self.session()) 175 | } 176 | 177 | /// Returns a [DataFrame] representing the result of the given query 178 | pub async fn sql(&self, sql_query: &str) -> Result { 179 | let sql_cmd = spark::command::CommandType::SqlCommand(spark::SqlCommand { 180 | sql: sql_query.to_string(), 181 | args: HashMap::default(), 182 | pos_args: vec![], 183 | }); 184 | 185 | let plan = LogicalPlanBuilder::plan_cmd(sql_cmd); 186 | 187 | let resp = self 188 | .clone() 189 | .client() 190 | .execute_command_and_fetch(plan) 191 | .await?; 192 | 193 | let relation = resp.sql_command_result.to_owned().unwrap().relation; 194 | 195 | let logical_plan = LogicalPlanBuilder::new(relation.unwrap()); 196 | 197 | Ok(DataFrame::new(self.session(), logical_plan)) 198 | } 199 | 200 | pub fn create_dataframe(&self, data: &RecordBatch) -> Result { 201 | let logical_plan = LogicalPlanBuilder::local_relation(data)?; 202 | 203 | Ok(DataFrame::new(self.session(), logical_plan)) 204 | } 205 | 206 | /// Return the session ID 207 | pub fn session_id(&self) -> &str { 208 | &self.session_id 209 | } 210 | 211 | /// Spark Connection gRPC client interface 212 | pub fn client(self) -> SparkClient { 213 | self.client 214 | } 215 | 216 | /// Interrupt all operations of this session currently running on the connected server. 217 | pub async fn interrupt_all(&self) -> Result, SparkError> { 218 | let resp = self 219 | .client 220 | .interrupt_request(spark::interrupt_request::InterruptType::All, None) 221 | .await?; 222 | 223 | Ok(resp.interrupted_ids) 224 | } 225 | 226 | /// Interrupt all operations of this session with the given operation tag. 227 | pub async fn interrupt_tag(&self, tag: &str) -> Result, SparkError> { 228 | let resp = self 229 | .client 230 | .interrupt_request( 231 | spark::interrupt_request::InterruptType::Tag, 232 | Some(tag.to_string()), 233 | ) 234 | .await?; 235 | 236 | Ok(resp.interrupted_ids) 237 | } 238 | 239 | /// Interrupt an operation of this session with the given operationId. 240 | pub async fn interrupt_operation(&self, op_id: &str) -> Result, SparkError> { 241 | let resp = self 242 | .client 243 | .interrupt_request( 244 | spark::interrupt_request::InterruptType::OperationId, 245 | Some(op_id.to_string()), 246 | ) 247 | .await?; 248 | 249 | Ok(resp.interrupted_ids) 250 | } 251 | 252 | /// Add a tag to be assigned to all the operations started by this thread in this session. 253 | pub fn add_tag(&mut self, tag: &str) -> Result<(), SparkError> { 254 | self.client.add_tag(tag) 255 | } 256 | 257 | /// Remove a tag previously added to be assigned to all the operations started by this thread in this session. 258 | pub fn remove_tag(&mut self, tag: &str) -> Result<(), SparkError> { 259 | self.client.remove_tag(tag) 260 | } 261 | 262 | /// Get the tags that are currently set to be assigned to all the operations started by this thread. 263 | pub fn get_tags(&mut self) -> &Vec { 264 | self.client.get_tags() 265 | } 266 | 267 | /// Clear the current thread’s operation tags. 268 | pub fn clear_tags(&mut self) { 269 | self.client.clear_tags() 270 | } 271 | 272 | /// The version of Spark on which this application is running. 273 | pub async fn version(&self) -> Result { 274 | let version = spark::analyze_plan_request::Analyze::SparkVersion( 275 | spark::analyze_plan_request::SparkVersion {}, 276 | ); 277 | 278 | let mut client = self.client.clone(); 279 | 280 | client.analyze(version).await?.spark_version() 281 | } 282 | 283 | /// [RunTimeConfig] configuration interface for Spark. 284 | pub fn conf(&self) -> RunTimeConfig { 285 | RunTimeConfig::new(&self.client) 286 | } 287 | 288 | /// Returns a [StreamingQueryManager] that allows managing all the StreamingQuery instances active on this context. 289 | pub fn streams(&self) -> StreamingQueryManager { 290 | StreamingQueryManager::new(self) 291 | } 292 | } 293 | 294 | #[cfg(test)] 295 | mod tests { 296 | use super::*; 297 | 298 | use arrow::{ 299 | array::{ArrayRef, StringArray}, 300 | record_batch::RecordBatch, 301 | }; 302 | 303 | use regex::Regex; 304 | 305 | async fn setup() -> SparkSession { 306 | println!("SparkSession Setup"); 307 | 308 | let connection = "sc://127.0.0.1:15002/;user_id=rust_test;session_id=0d2af2a9-cc3c-4d4b-bf27-e2fefeaca233"; 309 | 310 | SparkSessionBuilder::remote(connection) 311 | .build() 312 | .await 313 | .unwrap() 314 | } 315 | 316 | #[tokio::test] 317 | async fn test_spark_range() -> Result<(), SparkError> { 318 | let spark = setup().await; 319 | 320 | let df = spark.range(None, 100, 1, Some(8)); 321 | 322 | let records = df.collect().await?; 323 | 324 | assert_eq!(records.num_rows(), 100); 325 | Ok(()) 326 | } 327 | 328 | #[tokio::test] 329 | async fn test_spark_create_dataframe() -> Result<(), SparkError> { 330 | let spark = setup().await; 331 | 332 | let a: ArrayRef = Arc::new(StringArray::from(vec!["hello", "world"])); 333 | 334 | let record_batch = RecordBatch::try_from_iter(vec![("a", a)])?; 335 | 336 | let df = spark.create_dataframe(&record_batch)?; 337 | 338 | let rows = df.collect().await?; 339 | 340 | assert_eq!(record_batch, rows); 341 | Ok(()) 342 | } 343 | 344 | #[tokio::test] 345 | async fn test_spark_session_create() { 346 | let connection = 347 | "sc://localhost:15002/;token=ABCDEFG;user_agent=some_agent;user_id=user123"; 348 | 349 | let spark = SparkSessionBuilder::remote(connection).build().await; 350 | 351 | assert!(spark.is_ok()); 352 | } 353 | 354 | #[tokio::test] 355 | async fn test_session_tags() -> Result<(), SparkError> { 356 | let mut spark = SparkSessionBuilder::default().build().await?; 357 | 358 | spark.add_tag("hello-tag")?; 359 | 360 | spark.add_tag("hello-tag-2")?; 361 | 362 | let expected = vec!["hello-tag".to_string(), "hello-tag-2".to_string()]; 363 | 364 | let res = spark.get_tags(); 365 | 366 | assert_eq!(&expected, res); 367 | 368 | spark.clear_tags(); 369 | let res = spark.get_tags(); 370 | 371 | let expected: Vec = vec![]; 372 | 373 | assert_eq!(&expected, res); 374 | 375 | Ok(()) 376 | } 377 | 378 | #[tokio::test] 379 | async fn test_session_tags_panic() -> Result<(), SparkError> { 380 | let mut spark = SparkSessionBuilder::default().build().await?; 381 | 382 | assert!(spark.add_tag("bad,tag").is_err()); 383 | assert!(spark.add_tag("").is_err()); 384 | 385 | assert!(spark.remove_tag("bad,tag").is_err()); 386 | assert!(spark.remove_tag("").is_err()); 387 | 388 | Ok(()) 389 | } 390 | 391 | #[tokio::test] 392 | async fn test_session_version() -> Result<(), SparkError> { 393 | let spark = SparkSessionBuilder::default().build().await?; 394 | 395 | let version = spark.version().await?; 396 | 397 | let version_pattern = Regex::new(r"^\d+\.\d+\.\d+$").unwrap(); 398 | assert!( 399 | version_pattern.is_match(&version), 400 | "Version {} does not match X.X.X format", 401 | version 402 | ); 403 | 404 | Ok(()) 405 | } 406 | 407 | #[tokio::test] 408 | async fn test_session_config() -> Result<(), SparkError> { 409 | let value = "rust-test-app"; 410 | 411 | let spark = SparkSessionBuilder::default() 412 | .app_name("rust-test-app") 413 | .build() 414 | .await?; 415 | 416 | let name = spark.conf().get("spark.app.name", None).await?; 417 | 418 | assert_eq!(value, &name); 419 | 420 | // validate set 421 | spark 422 | .conf() 423 | .set("spark.sql.shuffle.partitions", "42") 424 | .await?; 425 | 426 | // validate get 427 | let val = spark 428 | .conf() 429 | .get("spark.sql.shuffle.partitions", None) 430 | .await?; 431 | 432 | assert_eq!("42", &val); 433 | 434 | // validate unset 435 | spark.conf().unset("spark.sql.shuffle.partitions").await?; 436 | 437 | let val = spark 438 | .conf() 439 | .get("spark.sql.shuffle.partitions", None) 440 | .await?; 441 | 442 | assert_eq!("200", &val); 443 | 444 | // not a modifable setting 445 | let val = spark 446 | .conf() 447 | .is_modifable("spark.executor.instances") 448 | .await?; 449 | assert!(!val); 450 | 451 | // a modifable setting 452 | let val = spark 453 | .conf() 454 | .is_modifable("spark.sql.shuffle.partitions") 455 | .await?; 456 | assert!(val); 457 | 458 | Ok(()) 459 | } 460 | } 461 | -------------------------------------------------------------------------------- /crates/connect/src/storage.rs: -------------------------------------------------------------------------------- 1 | //! Enum for handling Spark Storage representations 2 | 3 | use crate::spark; 4 | 5 | #[derive(Clone, Copy, Debug)] 6 | pub enum StorageLevel { 7 | None, 8 | DiskOnly, 9 | DiskOnly2, 10 | DiskOnly3, 11 | MemoryOnly, 12 | MemoryOnly2, 13 | MemoryAndDisk, 14 | MemoryAndDisk2, 15 | OffHeap, 16 | MemoryAndDiskDeser, 17 | } 18 | 19 | impl From for StorageLevel { 20 | fn from(spark_level: spark::StorageLevel) -> Self { 21 | match ( 22 | spark_level.use_disk, 23 | spark_level.use_memory, 24 | spark_level.use_off_heap, 25 | spark_level.deserialized, 26 | spark_level.replication, 27 | ) { 28 | (false, false, false, false, _) => StorageLevel::None, 29 | (true, false, false, false, 1) => StorageLevel::DiskOnly, 30 | (true, false, false, false, 2) => StorageLevel::DiskOnly2, 31 | (true, false, false, false, 3) => StorageLevel::DiskOnly3, 32 | (false, true, false, false, 1) => StorageLevel::MemoryOnly, 33 | (false, true, false, false, 2) => StorageLevel::MemoryOnly2, 34 | (true, true, false, false, 1) => StorageLevel::MemoryAndDisk, 35 | (true, true, false, false, 2) => StorageLevel::MemoryAndDisk2, 36 | (true, true, true, false, 1) => StorageLevel::OffHeap, 37 | (true, true, false, true, 1) => StorageLevel::MemoryAndDiskDeser, 38 | _ => unimplemented!(), 39 | } 40 | } 41 | } 42 | 43 | impl From for spark::StorageLevel { 44 | fn from(storage: StorageLevel) -> spark::StorageLevel { 45 | match storage { 46 | StorageLevel::None => spark::StorageLevel { 47 | use_disk: false, 48 | use_memory: false, 49 | use_off_heap: false, 50 | deserialized: false, 51 | replication: 1, 52 | }, 53 | StorageLevel::DiskOnly => spark::StorageLevel { 54 | use_disk: true, 55 | use_memory: false, 56 | use_off_heap: false, 57 | deserialized: false, 58 | replication: 1, 59 | }, 60 | StorageLevel::DiskOnly2 => spark::StorageLevel { 61 | use_disk: true, 62 | use_memory: false, 63 | use_off_heap: false, 64 | deserialized: false, 65 | replication: 2, 66 | }, 67 | StorageLevel::DiskOnly3 => spark::StorageLevel { 68 | use_disk: true, 69 | use_memory: false, 70 | use_off_heap: false, 71 | deserialized: false, 72 | replication: 3, 73 | }, 74 | StorageLevel::MemoryOnly => spark::StorageLevel { 75 | use_disk: false, 76 | use_memory: true, 77 | use_off_heap: false, 78 | deserialized: false, 79 | replication: 1, 80 | }, 81 | StorageLevel::MemoryOnly2 => spark::StorageLevel { 82 | use_disk: false, 83 | use_memory: true, 84 | use_off_heap: false, 85 | deserialized: false, 86 | replication: 2, 87 | }, 88 | StorageLevel::MemoryAndDisk => spark::StorageLevel { 89 | use_disk: true, 90 | use_memory: true, 91 | use_off_heap: false, 92 | deserialized: false, 93 | replication: 1, 94 | }, 95 | StorageLevel::MemoryAndDisk2 => spark::StorageLevel { 96 | use_disk: true, 97 | use_memory: true, 98 | use_off_heap: false, 99 | deserialized: false, 100 | replication: 2, 101 | }, 102 | StorageLevel::OffHeap => spark::StorageLevel { 103 | use_disk: true, 104 | use_memory: true, 105 | use_off_heap: true, 106 | deserialized: false, 107 | replication: 1, 108 | }, 109 | StorageLevel::MemoryAndDiskDeser => spark::StorageLevel { 110 | use_disk: true, 111 | use_memory: true, 112 | use_off_heap: false, 113 | deserialized: true, 114 | replication: 1, 115 | }, 116 | } 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /crates/connect/src/window.rs: -------------------------------------------------------------------------------- 1 | //! Utility structs for defining a window over a DataFrame 2 | 3 | use crate::column::Column; 4 | use crate::expressions::VecExpression; 5 | use crate::functions::lit; 6 | use crate::plan::sort_order; 7 | 8 | use crate::spark; 9 | use crate::spark::expression::window; 10 | 11 | /// A window specification that defines the partitioning, ordering, and frame boundaries. 12 | /// 13 | /// **Recommended to create a WindowSpec using [Window] and not directly** 14 | #[derive(Debug, Default, Clone)] 15 | pub struct WindowSpec { 16 | pub partition_spec: Vec, 17 | pub order_spec: Vec, 18 | pub frame_spec: Option>, 19 | } 20 | 21 | impl WindowSpec { 22 | pub fn new( 23 | partition_spec: Vec, 24 | order_spec: Vec, 25 | frame_spec: Option>, 26 | ) -> WindowSpec { 27 | WindowSpec { 28 | partition_spec, 29 | order_spec, 30 | frame_spec, 31 | } 32 | } 33 | 34 | pub fn partition_by(self, cols: I) -> WindowSpec 35 | where 36 | I: IntoIterator, 37 | S: Into, 38 | { 39 | WindowSpec::new( 40 | VecExpression::from_iter(cols).into(), 41 | self.order_spec, 42 | self.frame_spec, 43 | ) 44 | } 45 | 46 | pub fn order_by(self, cols: I) -> WindowSpec 47 | where 48 | I: IntoIterator, 49 | S: Into, 50 | { 51 | let order_spec = sort_order(cols); 52 | 53 | WindowSpec::new(self.partition_spec, order_spec, self.frame_spec) 54 | } 55 | 56 | pub fn rows_between(self, start: i64, end: i64) -> WindowSpec { 57 | let frame_spec = WindowSpec::window_frame(true, start, end); 58 | 59 | WindowSpec::new(self.partition_spec, self.order_spec, frame_spec) 60 | } 61 | 62 | pub fn range_between(self, start: i64, end: i64) -> WindowSpec { 63 | let frame_spec = WindowSpec::window_frame(false, start, end); 64 | 65 | WindowSpec::new(self.partition_spec, self.order_spec, frame_spec) 66 | } 67 | 68 | fn frame_boundary(value: i64) -> Option> { 69 | match value { 70 | 0 => { 71 | let boundary = Some(window::window_frame::frame_boundary::Boundary::CurrentRow( 72 | true, 73 | )); 74 | 75 | Some(Box::new(window::window_frame::FrameBoundary { boundary })) 76 | } 77 | i64::MIN => { 78 | let boundary = Some(window::window_frame::frame_boundary::Boundary::Unbounded( 79 | true, 80 | )); 81 | 82 | Some(Box::new(window::window_frame::FrameBoundary { boundary })) 83 | } 84 | _ => { 85 | // !TODO - I don't like casting this to i32 86 | // however, the window boundary is expecting an INT and not a BIGINT 87 | // i64 is a BIGINT (i.e. Long) 88 | let expr = lit(value as i32); 89 | 90 | let boundary = Some(window::window_frame::frame_boundary::Boundary::Value( 91 | Box::new(expr.into()), 92 | )); 93 | 94 | Some(Box::new(window::window_frame::FrameBoundary { boundary })) 95 | } 96 | } 97 | } 98 | 99 | fn window_frame(row_frame: bool, start: i64, end: i64) -> Option> { 100 | let frame_type = match row_frame { 101 | true => 1, 102 | false => 2, 103 | }; 104 | 105 | let lower = WindowSpec::frame_boundary(start); 106 | let upper = WindowSpec::frame_boundary(end); 107 | 108 | Some(Box::new(window::WindowFrame { 109 | frame_type, 110 | lower, 111 | upper, 112 | })) 113 | } 114 | } 115 | 116 | /// Primary utility struct for defining window in DataFrames 117 | #[derive(Debug, Default, Clone)] 118 | pub struct Window { 119 | spec: WindowSpec, 120 | } 121 | 122 | impl Window { 123 | /// Creates a new empty [WindowSpec] 124 | pub fn new() -> Self { 125 | Window { 126 | spec: WindowSpec::default(), 127 | } 128 | } 129 | 130 | /// Returns 0 131 | pub fn current_row() -> i64 { 132 | 0 133 | } 134 | 135 | /// Returns [i64::MAX] 136 | pub fn unbounded_following() -> i64 { 137 | i64::MAX 138 | } 139 | 140 | /// Returns [i64::MIN] 141 | pub fn unbounded_preceding() -> i64 { 142 | i64::MIN 143 | } 144 | 145 | /// Creates a [WindowSpec] with the partitioning defined 146 | pub fn partition_by(mut self, cols: I) -> WindowSpec 147 | where 148 | I: IntoIterator, 149 | S: Into, 150 | { 151 | self.spec = self.spec.partition_by(cols); 152 | 153 | self.spec 154 | } 155 | 156 | /// Creates a [WindowSpec] with the ordering defined 157 | pub fn order_by(mut self, cols: I) -> WindowSpec 158 | where 159 | I: IntoIterator, 160 | S: Into, 161 | { 162 | self.spec = self.spec.order_by(cols); 163 | 164 | self.spec 165 | } 166 | 167 | /// Creates a [WindowSpec] with the frame boundaries defined, from start (inclusive) to end (inclusive). 168 | /// 169 | /// Both start and end are relative from the current row. For example, “0” means “current row”, 170 | /// while “-1” means one off before the current row, and “5” means the five off after the current row. 171 | /// 172 | /// Recommended to use [Window::unbounded_preceding], [Window::unbounded_following], and [Window::current_row] 173 | /// to specify special boundary values, rather than using integral values directly. 174 | /// 175 | /// # Example 176 | /// 177 | /// ``` 178 | /// let window = Window::new() 179 | /// .partition_by(col("name")) 180 | /// .order_by([col("age")]) 181 | /// .range_between(Window::unbounded_preceding(), Window::current_row()); 182 | /// 183 | /// let df = df.with_column("rank", rank().over(window.clone())) 184 | /// .with_column("min", min("age").over(window)); 185 | /// ``` 186 | pub fn range_between(mut self, start: i64, end: i64) -> WindowSpec { 187 | self.spec = self.spec.range_between(start, end); 188 | 189 | self.spec 190 | } 191 | 192 | /// Creates a [WindowSpec] with the frame boundaries defined, from start (inclusive) to end (inclusive). 193 | /// 194 | /// Both start and end are relative from the current row. For example, “0” means “current row”, 195 | /// while “-1” means one off before the current row, and “5” means the five off after the current row. 196 | /// 197 | /// Recommended to use [Window::unbounded_preceding], [Window::unbounded_following], and [Window::current_row] 198 | /// to specify special boundary values, rather than using integral values directly. 199 | /// 200 | /// # Example 201 | /// 202 | /// ``` 203 | /// let window = Window::new() 204 | /// .partition_by(col("name")) 205 | /// .order_by([col("age")]) 206 | /// .rows_between(Window::unbounded_preceding(), Window::current_row()); 207 | /// 208 | /// let df = df.with_column("rank", rank().over(window.clone())) 209 | /// .with_column("min", min("age").over(window)); 210 | /// ``` 211 | pub fn rows_between(mut self, start: i64, end: i64) -> WindowSpec { 212 | self.spec = self.spec.rows_between(start, end); 213 | 214 | self.spec 215 | } 216 | } 217 | 218 | #[cfg(test)] 219 | mod tests { 220 | 221 | use arrow::{ 222 | array::{ArrayRef, Int32Array, Int64Array, StringArray}, 223 | datatypes::{DataType, Field, Schema}, 224 | record_batch::RecordBatch, 225 | }; 226 | 227 | use std::sync::Arc; 228 | 229 | use super::*; 230 | 231 | use crate::errors::SparkError; 232 | use crate::functions::*; 233 | use crate::SparkSession; 234 | use crate::SparkSessionBuilder; 235 | 236 | async fn setup() -> SparkSession { 237 | println!("SparkSession Setup"); 238 | 239 | let connection = "sc://127.0.0.1:15002/;user_id=rust_window"; 240 | 241 | SparkSessionBuilder::remote(connection) 242 | .build() 243 | .await 244 | .unwrap() 245 | } 246 | 247 | fn mock_data() -> RecordBatch { 248 | let id: ArrayRef = Arc::new(Int64Array::from(vec![1, 1, 2, 1, 2, 3])); 249 | let category: ArrayRef = Arc::new(StringArray::from(vec!["a", "a", "a", "b", "b", "b"])); 250 | 251 | RecordBatch::try_from_iter(vec![("id", id), ("category", category)]).unwrap() 252 | } 253 | 254 | #[tokio::test] 255 | async fn test_window_over() -> Result<(), SparkError> { 256 | let spark = setup().await; 257 | 258 | let name: ArrayRef = Arc::new(StringArray::from(vec!["Alice", "Bob"])); 259 | let age: ArrayRef = Arc::new(Int64Array::from(vec![2, 5])); 260 | 261 | let data = RecordBatch::try_from_iter(vec![("name", name), ("age", age)])?; 262 | 263 | let df = spark.create_dataframe(&data)?; 264 | 265 | let window = Window::new() 266 | .partition_by([col("name")]) 267 | .order_by([col("age")]) 268 | .rows_between(Window::unbounded_preceding(), Window::current_row()); 269 | 270 | let res = df 271 | .with_column("rank", rank().over(window.clone())) 272 | .with_column("min", min("age").over(window)) 273 | .collect() 274 | .await?; 275 | 276 | let name: ArrayRef = Arc::new(StringArray::from(vec!["Alice", "Bob"])); 277 | let age: ArrayRef = Arc::new(Int64Array::from(vec![2, 5])); 278 | let rank: ArrayRef = Arc::new(Int32Array::from(vec![1, 1])); 279 | let min = age.clone(); 280 | 281 | let schema = Schema::new(vec![ 282 | Field::new("name", DataType::Utf8, false), 283 | Field::new("age", DataType::Int64, false), 284 | Field::new("rank", DataType::Int32, false), 285 | Field::new("min", DataType::Int64, true), 286 | ]); 287 | 288 | let expected = RecordBatch::try_new(Arc::new(schema), vec![name, age, rank, min])?; 289 | 290 | assert_eq!(expected, res); 291 | 292 | Ok(()) 293 | } 294 | 295 | #[tokio::test] 296 | async fn test_window_orderby() -> Result<(), SparkError> { 297 | let spark = setup().await; 298 | 299 | let data = mock_data(); 300 | 301 | let df = spark.create_dataframe(&data)?; 302 | 303 | let window = Window::new() 304 | .partition_by([col("id")]) 305 | .order_by([col("category")]); 306 | 307 | let res = df 308 | .with_column("row_number", row_number().over(window)) 309 | .collect() 310 | .await?; 311 | 312 | let id: ArrayRef = Arc::new(Int64Array::from(vec![1, 1, 1, 2, 2, 3])); 313 | let category: ArrayRef = Arc::new(StringArray::from(vec!["a", "a", "b", "a", "b", "b"])); 314 | let row_number: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 1, 2, 1])); 315 | 316 | let expected = RecordBatch::try_from_iter(vec![ 317 | ("id", id), 318 | ("category", category), 319 | ("row_number", row_number), 320 | ])?; 321 | 322 | assert_eq!(expected, res); 323 | 324 | Ok(()) 325 | } 326 | 327 | #[tokio::test] 328 | async fn test_window_partitionby() -> Result<(), SparkError> { 329 | let spark = setup().await; 330 | 331 | let data = mock_data(); 332 | 333 | let df = spark.create_dataframe(&data)?; 334 | 335 | let window = Window::new() 336 | .partition_by([col("category")]) 337 | .order_by([col("id")]); 338 | 339 | let res = df 340 | .with_column("row_number", row_number().over(window)) 341 | .collect() 342 | .await?; 343 | 344 | let id: ArrayRef = Arc::new(Int64Array::from(vec![1, 1, 2, 1, 2, 3])); 345 | let category: ArrayRef = Arc::new(StringArray::from(vec!["a", "a", "a", "b", "b", "b"])); 346 | let row_number: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 1, 2, 3])); 347 | 348 | let expected = RecordBatch::try_from_iter(vec![ 349 | ("id", id), 350 | ("category", category), 351 | ("row_number", row_number), 352 | ])?; 353 | 354 | assert_eq!(expected, res); 355 | 356 | Ok(()) 357 | } 358 | 359 | #[tokio::test] 360 | async fn test_window_rangebetween() -> Result<(), SparkError> { 361 | let spark = setup().await; 362 | 363 | let data = mock_data(); 364 | 365 | let df = spark.create_dataframe(&data)?; 366 | 367 | let window = Window::new() 368 | .partition_by([col("category")]) 369 | .order_by([col("id")]) 370 | .range_between(Window::current_row(), 1); 371 | 372 | let res = df 373 | .with_column("sum", sum("id").over(window)) 374 | .sort([col("id"), col("category")]) 375 | .collect() 376 | .await?; 377 | 378 | let id: ArrayRef = Arc::new(Int64Array::from(vec![1, 1, 1, 2, 2, 3])); 379 | let category: ArrayRef = Arc::new(StringArray::from(vec!["a", "a", "b", "a", "b", "b"])); 380 | let sum: ArrayRef = Arc::new(Int64Array::from(vec![4, 4, 3, 2, 5, 3])); 381 | 382 | let expected = RecordBatch::try_from_iter_with_nullable(vec![ 383 | ("id", id, false), 384 | ("category", category, false), 385 | ("sum", sum, true), 386 | ])?; 387 | 388 | assert_eq!(expected, res); 389 | 390 | Ok(()) 391 | } 392 | 393 | #[tokio::test] 394 | async fn test_window_rowsbetween() -> Result<(), SparkError> { 395 | let spark = setup().await; 396 | 397 | let data = mock_data(); 398 | 399 | let df = spark.create_dataframe(&data)?; 400 | 401 | let window = Window::new() 402 | .partition_by([col("category")]) 403 | .order_by([col("id")]) 404 | .rows_between(Window::current_row(), 1); 405 | 406 | let res = df 407 | .with_column("sum", sum("id").over(window)) 408 | .sort([col("id"), col("category"), col("sum")]) 409 | .collect() 410 | .await?; 411 | 412 | let id: ArrayRef = Arc::new(Int64Array::from(vec![1, 1, 1, 2, 2, 3])); 413 | let category: ArrayRef = Arc::new(StringArray::from(vec!["a", "a", "b", "a", "b", "b"])); 414 | let sum: ArrayRef = Arc::new(Int64Array::from(vec![2, 3, 3, 2, 5, 3])); 415 | 416 | let expected = RecordBatch::try_from_iter_with_nullable(vec![ 417 | ("id", id, false), 418 | ("category", category, false), 419 | ("sum", sum, true), 420 | ])?; 421 | 422 | assert_eq!(expected, res); 423 | 424 | Ok(()) 425 | } 426 | } 427 | -------------------------------------------------------------------------------- /datasets/dir1/dir2/file2.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjrusso8/spark-connect-rs/251ebce2005c3b24a30d0e3ac9dd52089e8afcaa/datasets/dir1/dir2/file2.parquet -------------------------------------------------------------------------------- /datasets/dir1/file1.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjrusso8/spark-connect-rs/251ebce2005c3b24a30d0e3ac9dd52089e8afcaa/datasets/dir1/file1.parquet -------------------------------------------------------------------------------- /datasets/dir1/file3.json: -------------------------------------------------------------------------------- 1 | {"file":"corrupt.json"} 2 | -------------------------------------------------------------------------------- /datasets/employees.json: -------------------------------------------------------------------------------- 1 | {"name":"Michael", "salary":3000} 2 | {"name":"Andy", "salary":4500} 3 | {"name":"Justin", "salary":3500} 4 | {"name":"Berta", "salary":4000} 5 | -------------------------------------------------------------------------------- /datasets/full_user.avsc: -------------------------------------------------------------------------------- 1 | {"type": "record", "namespace": "example.avro", "name": "User", "fields": [{"type": "string", "name": "name"}, {"type": ["string", "null"], "name": "favorite_color"}, {"type": {"items": "int", "type": "array"}, "name": "favorite_numbers"}]} 2 | -------------------------------------------------------------------------------- /datasets/kv1.txt: -------------------------------------------------------------------------------- 1 | 238val_238 2 | 86val_86 3 | 311val_311 4 | 27val_27 5 | 165val_165 6 | 409val_409 7 | 255val_255 8 | 278val_278 9 | 98val_98 10 | 484val_484 11 | 265val_265 12 | 193val_193 13 | 401val_401 14 | 150val_150 15 | 273val_273 16 | 224val_224 17 | 369val_369 18 | 66val_66 19 | 128val_128 20 | 213val_213 21 | 146val_146 22 | 406val_406 23 | 429val_429 24 | 374val_374 25 | 152val_152 26 | 469val_469 27 | 145val_145 28 | 495val_495 29 | 37val_37 30 | 327val_327 31 | 281val_281 32 | 277val_277 33 | 209val_209 34 | 15val_15 35 | 82val_82 36 | 403val_403 37 | 166val_166 38 | 417val_417 39 | 430val_430 40 | 252val_252 41 | 292val_292 42 | 219val_219 43 | 287val_287 44 | 153val_153 45 | 193val_193 46 | 338val_338 47 | 446val_446 48 | 459val_459 49 | 394val_394 50 | 237val_237 51 | 482val_482 52 | 174val_174 53 | 413val_413 54 | 494val_494 55 | 207val_207 56 | 199val_199 57 | 466val_466 58 | 208val_208 59 | 174val_174 60 | 399val_399 61 | 396val_396 62 | 247val_247 63 | 417val_417 64 | 489val_489 65 | 162val_162 66 | 377val_377 67 | 397val_397 68 | 309val_309 69 | 365val_365 70 | 266val_266 71 | 439val_439 72 | 342val_342 73 | 367val_367 74 | 325val_325 75 | 167val_167 76 | 195val_195 77 | 475val_475 78 | 17val_17 79 | 113val_113 80 | 155val_155 81 | 203val_203 82 | 339val_339 83 | 0val_0 84 | 455val_455 85 | 128val_128 86 | 311val_311 87 | 316val_316 88 | 57val_57 89 | 302val_302 90 | 205val_205 91 | 149val_149 92 | 438val_438 93 | 345val_345 94 | 129val_129 95 | 170val_170 96 | 20val_20 97 | 489val_489 98 | 157val_157 99 | 378val_378 100 | 221val_221 101 | 92val_92 102 | 111val_111 103 | 47val_47 104 | 72val_72 105 | 4val_4 106 | 280val_280 107 | 35val_35 108 | 427val_427 109 | 277val_277 110 | 208val_208 111 | 356val_356 112 | 399val_399 113 | 169val_169 114 | 382val_382 115 | 498val_498 116 | 125val_125 117 | 386val_386 118 | 437val_437 119 | 469val_469 120 | 192val_192 121 | 286val_286 122 | 187val_187 123 | 176val_176 124 | 54val_54 125 | 459val_459 126 | 51val_51 127 | 138val_138 128 | 103val_103 129 | 239val_239 130 | 213val_213 131 | 216val_216 132 | 430val_430 133 | 278val_278 134 | 176val_176 135 | 289val_289 136 | 221val_221 137 | 65val_65 138 | 318val_318 139 | 332val_332 140 | 311val_311 141 | 275val_275 142 | 137val_137 143 | 241val_241 144 | 83val_83 145 | 333val_333 146 | 180val_180 147 | 284val_284 148 | 12val_12 149 | 230val_230 150 | 181val_181 151 | 67val_67 152 | 260val_260 153 | 404val_404 154 | 384val_384 155 | 489val_489 156 | 353val_353 157 | 373val_373 158 | 272val_272 159 | 138val_138 160 | 217val_217 161 | 84val_84 162 | 348val_348 163 | 466val_466 164 | 58val_58 165 | 8val_8 166 | 411val_411 167 | 230val_230 168 | 208val_208 169 | 348val_348 170 | 24val_24 171 | 463val_463 172 | 431val_431 173 | 179val_179 174 | 172val_172 175 | 42val_42 176 | 129val_129 177 | 158val_158 178 | 119val_119 179 | 496val_496 180 | 0val_0 181 | 322val_322 182 | 197val_197 183 | 468val_468 184 | 393val_393 185 | 454val_454 186 | 100val_100 187 | 298val_298 188 | 199val_199 189 | 191val_191 190 | 418val_418 191 | 96val_96 192 | 26val_26 193 | 165val_165 194 | 327val_327 195 | 230val_230 196 | 205val_205 197 | 120val_120 198 | 131val_131 199 | 51val_51 200 | 404val_404 201 | 43val_43 202 | 436val_436 203 | 156val_156 204 | 469val_469 205 | 468val_468 206 | 308val_308 207 | 95val_95 208 | 196val_196 209 | 288val_288 210 | 481val_481 211 | 457val_457 212 | 98val_98 213 | 282val_282 214 | 197val_197 215 | 187val_187 216 | 318val_318 217 | 318val_318 218 | 409val_409 219 | 470val_470 220 | 137val_137 221 | 369val_369 222 | 316val_316 223 | 169val_169 224 | 413val_413 225 | 85val_85 226 | 77val_77 227 | 0val_0 228 | 490val_490 229 | 87val_87 230 | 364val_364 231 | 179val_179 232 | 118val_118 233 | 134val_134 234 | 395val_395 235 | 282val_282 236 | 138val_138 237 | 238val_238 238 | 419val_419 239 | 15val_15 240 | 118val_118 241 | 72val_72 242 | 90val_90 243 | 307val_307 244 | 19val_19 245 | 435val_435 246 | 10val_10 247 | 277val_277 248 | 273val_273 249 | 306val_306 250 | 224val_224 251 | 309val_309 252 | 389val_389 253 | 327val_327 254 | 242val_242 255 | 369val_369 256 | 392val_392 257 | 272val_272 258 | 331val_331 259 | 401val_401 260 | 242val_242 261 | 452val_452 262 | 177val_177 263 | 226val_226 264 | 5val_5 265 | 497val_497 266 | 402val_402 267 | 396val_396 268 | 317val_317 269 | 395val_395 270 | 58val_58 271 | 35val_35 272 | 336val_336 273 | 95val_95 274 | 11val_11 275 | 168val_168 276 | 34val_34 277 | 229val_229 278 | 233val_233 279 | 143val_143 280 | 472val_472 281 | 322val_322 282 | 498val_498 283 | 160val_160 284 | 195val_195 285 | 42val_42 286 | 321val_321 287 | 430val_430 288 | 119val_119 289 | 489val_489 290 | 458val_458 291 | 78val_78 292 | 76val_76 293 | 41val_41 294 | 223val_223 295 | 492val_492 296 | 149val_149 297 | 449val_449 298 | 218val_218 299 | 228val_228 300 | 138val_138 301 | 453val_453 302 | 30val_30 303 | 209val_209 304 | 64val_64 305 | 468val_468 306 | 76val_76 307 | 74val_74 308 | 342val_342 309 | 69val_69 310 | 230val_230 311 | 33val_33 312 | 368val_368 313 | 103val_103 314 | 296val_296 315 | 113val_113 316 | 216val_216 317 | 367val_367 318 | 344val_344 319 | 167val_167 320 | 274val_274 321 | 219val_219 322 | 239val_239 323 | 485val_485 324 | 116val_116 325 | 223val_223 326 | 256val_256 327 | 263val_263 328 | 70val_70 329 | 487val_487 330 | 480val_480 331 | 401val_401 332 | 288val_288 333 | 191val_191 334 | 5val_5 335 | 244val_244 336 | 438val_438 337 | 128val_128 338 | 467val_467 339 | 432val_432 340 | 202val_202 341 | 316val_316 342 | 229val_229 343 | 469val_469 344 | 463val_463 345 | 280val_280 346 | 2val_2 347 | 35val_35 348 | 283val_283 349 | 331val_331 350 | 235val_235 351 | 80val_80 352 | 44val_44 353 | 193val_193 354 | 321val_321 355 | 335val_335 356 | 104val_104 357 | 466val_466 358 | 366val_366 359 | 175val_175 360 | 403val_403 361 | 483val_483 362 | 53val_53 363 | 105val_105 364 | 257val_257 365 | 406val_406 366 | 409val_409 367 | 190val_190 368 | 406val_406 369 | 401val_401 370 | 114val_114 371 | 258val_258 372 | 90val_90 373 | 203val_203 374 | 262val_262 375 | 348val_348 376 | 424val_424 377 | 12val_12 378 | 396val_396 379 | 201val_201 380 | 217val_217 381 | 164val_164 382 | 431val_431 383 | 454val_454 384 | 478val_478 385 | 298val_298 386 | 125val_125 387 | 431val_431 388 | 164val_164 389 | 424val_424 390 | 187val_187 391 | 382val_382 392 | 5val_5 393 | 70val_70 394 | 397val_397 395 | 480val_480 396 | 291val_291 397 | 24val_24 398 | 351val_351 399 | 255val_255 400 | 104val_104 401 | 70val_70 402 | 163val_163 403 | 438val_438 404 | 119val_119 405 | 414val_414 406 | 200val_200 407 | 491val_491 408 | 237val_237 409 | 439val_439 410 | 360val_360 411 | 248val_248 412 | 479val_479 413 | 305val_305 414 | 417val_417 415 | 199val_199 416 | 444val_444 417 | 120val_120 418 | 429val_429 419 | 169val_169 420 | 443val_443 421 | 323val_323 422 | 325val_325 423 | 277val_277 424 | 230val_230 425 | 478val_478 426 | 178val_178 427 | 468val_468 428 | 310val_310 429 | 317val_317 430 | 333val_333 431 | 493val_493 432 | 460val_460 433 | 207val_207 434 | 249val_249 435 | 265val_265 436 | 480val_480 437 | 83val_83 438 | 136val_136 439 | 353val_353 440 | 172val_172 441 | 214val_214 442 | 462val_462 443 | 233val_233 444 | 406val_406 445 | 133val_133 446 | 175val_175 447 | 189val_189 448 | 454val_454 449 | 375val_375 450 | 401val_401 451 | 421val_421 452 | 407val_407 453 | 384val_384 454 | 256val_256 455 | 26val_26 456 | 134val_134 457 | 67val_67 458 | 384val_384 459 | 379val_379 460 | 18val_18 461 | 462val_462 462 | 492val_492 463 | 100val_100 464 | 298val_298 465 | 9val_9 466 | 341val_341 467 | 498val_498 468 | 146val_146 469 | 458val_458 470 | 362val_362 471 | 186val_186 472 | 285val_285 473 | 348val_348 474 | 167val_167 475 | 18val_18 476 | 273val_273 477 | 183val_183 478 | 281val_281 479 | 344val_344 480 | 97val_97 481 | 469val_469 482 | 315val_315 483 | 84val_84 484 | 28val_28 485 | 37val_37 486 | 448val_448 487 | 152val_152 488 | 348val_348 489 | 307val_307 490 | 194val_194 491 | 414val_414 492 | 477val_477 493 | 222val_222 494 | 126val_126 495 | 90val_90 496 | 169val_169 497 | 403val_403 498 | 400val_400 499 | 200val_200 500 | 97val_97 501 | -------------------------------------------------------------------------------- /datasets/people.csv: -------------------------------------------------------------------------------- 1 | name;age;job 2 | Jorge;30;Developer 3 | Bob;32;Developer 4 | -------------------------------------------------------------------------------- /datasets/people.json: -------------------------------------------------------------------------------- 1 | {"name":"Michael"} 2 | {"name":"Andy", "age":30} 3 | {"name":"Justin", "age":19} 4 | -------------------------------------------------------------------------------- /datasets/people.txt: -------------------------------------------------------------------------------- 1 | Michael, 29 2 | Andy, 30 3 | Justin, 19 4 | -------------------------------------------------------------------------------- /datasets/user.avsc: -------------------------------------------------------------------------------- 1 | {"namespace": "example.avro", 2 | "type": "record", 3 | "name": "User", 4 | "fields": [ 5 | {"name": "name", "type": "string"}, 6 | {"name": "favorite_color", "type": ["string", "null"]} 7 | ] 8 | } 9 | -------------------------------------------------------------------------------- /datasets/users.avro: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjrusso8/spark-connect-rs/251ebce2005c3b24a30d0e3ac9dd52089e8afcaa/datasets/users.avro -------------------------------------------------------------------------------- /datasets/users.orc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjrusso8/spark-connect-rs/251ebce2005c3b24a30d0e3ac9dd52089e8afcaa/datasets/users.orc -------------------------------------------------------------------------------- /datasets/users.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjrusso8/spark-connect-rs/251ebce2005c3b24a30d0e3ac9dd52089e8afcaa/datasets/users.parquet -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | services: 2 | spark: 3 | image: "apache/spark:3.5.3-scala2.12-java11-r-ubuntu" 4 | command: > 5 | /opt/spark/sbin/start-connect-server.sh 6 | --packages "org.apache.spark:spark-connect_2.12:3.5.3,io.delta:delta-spark_2.12:3.0.0" 7 | --conf "spark.driver.extraJavaOptions=-Divy.cache.dir=/tmp -Divy.home=/tmp" 8 | --conf "spark.sql.extensions=io.delta.sql.DeltaSparkSessionExtension" 9 | --conf "spark.sql.catalog.spark_catalog=org.apache.spark.sql.delta.catalog.DeltaCatalog" 10 | environment: 11 | - SPARK_NO_DAEMONIZE=true 12 | ports: 13 | - "4040:4040" 14 | - "15002:15002" 15 | volumes: 16 | - ./datasets:/opt/spark/work-dir/datasets 17 | -------------------------------------------------------------------------------- /examples/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "examples" 3 | version = "0.0.0" 4 | authors.workspace = true 5 | edition.workspace = true 6 | license.workspace = true 7 | publish = false 8 | 9 | [[bin]] 10 | name = "sql" 11 | path = "src/sql.rs" 12 | 13 | [[bin]] 14 | name = "deltalake" 15 | path = "src/deltalake.rs" 16 | 17 | [[bin]] 18 | name = "reader" 19 | path = "src/reader.rs" 20 | 21 | [[bin]] 22 | name = "writer" 23 | path = "src/writer.rs" 24 | 25 | [[bin]] 26 | name = "readstream" 27 | path = "src/readstream.rs" 28 | 29 | [[bin]] 30 | name = "databricks" 31 | path = "src/databricks.rs" 32 | required-feature = ["tls"] 33 | 34 | [dependencies] 35 | spark-connect-rs = { version = "0.0.2", path = "../crates/connect" } 36 | tokio = { workspace = true, features = ["rt-multi-thread"] } 37 | 38 | [features] 39 | tls = ["spark-connect-rs/tls"] 40 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | Set of examples that show off different features provided by `spark-connect-rs` client. 4 | 5 | In order to build these examples, you must have the `protoc` protocol buffers compilter 6 | installed, along with the git submodule synced. 7 | 8 | ```bash 9 | git clone https://github.com/sjrusso8/spark-connect-rs.git 10 | git submodule update --init --recursive 11 | ``` 12 | 13 | ### sql 14 | 15 | Write a simple SQL statement and save the dataframe as a parquet 16 | 17 | ```bash 18 | $ cargo run --bin sql 19 | ``` 20 | 21 | ### reader 22 | 23 | Read a CSV file, select specific columns, and display the results 24 | 25 | ```bash 26 | $ cargo run --bin reader 27 | ``` 28 | 29 | ### writer 30 | 31 | Create a dataframe, and save the results to a file 32 | 33 | ```bash 34 | $ cargo run --bin writer 35 | ``` 36 | 37 | ### readstream 38 | 39 | Create a streaming query, and monitor the progress of the stream 40 | 41 | ```bash 42 | $ cargo run --bin readstream 43 | ``` 44 | 45 | ### deltalake 46 | 47 | Read a file into a dataframe, save the result as a deltalake table, and append a new record to the table. 48 | 49 | **Prerequisite** the spark cluster must be started with the deltalake package. The `docker-compose.yml` provided in the repo has deltalake pre-installed. 50 | Or if you are running a spark connect server location, run the below scripts first 51 | 52 | If you are running a local spark connect server. The Delta Lake jars need to be added onto the server before it starts. 53 | 54 | ```bash 55 | $ $SPARK_HOME/sbin/start-connect-server.sh --packages "org.apache.spark:spark-connect_2.12:3.5.1,io.delta:delta-spark_2.12:3.0.0" \ 56 | --conf "spark.driver.extraJavaOptions=-Divy.cache.dir=/tmp -Divy.home=/tmp" \ 57 | --conf "spark.sql.extensions=io.delta.sql.DeltaSparkSessionExtension" \ 58 | --conf "spark.sql.catalog.spark_catalog=org.apache.spark.sql.delta.catalog.DeltaCatalog" 59 | ``` 60 | 61 | ```bash 62 | $ cargo run --bin deltalake 63 | ``` 64 | 65 | ### databricks 66 | 67 | Read a Unity Catalog table, perform an aggregation, and display the results. 68 | 69 | **Prerequisite** must have access to a Databricks workspace, a personal access token, and cluster running >=13.3LTS. 70 | 71 | ```bash 72 | $ cargo run --bin databricks --features=tls 73 | ``` 74 | -------------------------------------------------------------------------------- /examples/src/databricks.rs: -------------------------------------------------------------------------------- 1 | // This example demonstrates connecting to a Databricks Cluster via a tls connection. 2 | // 3 | // This demo requires access to a Databricks Workspace, a personal access token, 4 | // and a cluster id. The cluster should be running a 13.3LTS runtime or greater. Populate 5 | // the remote URL string between the `<>` with the appropriate details. 6 | // 7 | // The Databricks workspace instance name is the same as the Server Hostname value for your cluster. 8 | // Get connection details for a Databricks compute resource via https://docs.databricks.com/en/integrations/compute-details.html 9 | // 10 | // To view the connected Spark Session, go to the cluster Spark UI and select the 'Connect' tab. 11 | 12 | use spark_connect_rs::functions::{avg, col}; 13 | use spark_connect_rs::{SparkSession, SparkSessionBuilder}; 14 | 15 | #[tokio::main] 16 | async fn main() -> Result<(), Box> { 17 | let conn_str = "sc://:443/;token=;x-databricks-cluster-id="; 18 | 19 | // connect the databricks cluster 20 | let spark: SparkSession = SparkSessionBuilder::remote(conn_str).build().await?; 21 | 22 | // read unity catalog table 23 | let df = spark.read().table("samples.nyctaxi.trips", None)?; 24 | 25 | // apply a filter 26 | let filter = "trip_distance BETWEEN 0 AND 10 AND fare_amount BETWEEN 0 AND 50"; 27 | let df = df.filter(filter); 28 | 29 | // groupby the pickup 30 | let df = df 31 | .select(["pickup_zip", "fare_amount"]) 32 | .group_by(Some(["pickup_zip"])); 33 | 34 | // average the fare amount and order by the top 10 zip codes 35 | let df = df 36 | .agg([avg(col("fare_amount")).alias("avg_fare_amount")]) 37 | .order_by([col("avg_fare_amount").desc()]); 38 | 39 | df.show(Some(10), None, None).await?; 40 | 41 | // +---------------------------------+ 42 | // | show_string | 43 | // +---------------------------------+ 44 | // | +----------+------------------+ | 45 | // | |pickup_zip|avg_fare_amount | | 46 | // | +----------+------------------+ | 47 | // | |7086 |40.0 | | 48 | // | |7030 |40.0 | | 49 | // | |11424 |34.25 | | 50 | // | |7087 |31.0 | | 51 | // | |10470 |28.0 | | 52 | // | |11371 |25.532619926199263| | 53 | // | |11375 |25.5 | | 54 | // | |11370 |22.452380952380953| | 55 | // | |11207 |20.5 | | 56 | // | |11218 |20.0 | | 57 | // | +----------+------------------+ | 58 | // | only showing top 10 rows | 59 | // | | 60 | // +---------------------------------+ 61 | 62 | Ok(()) 63 | } 64 | -------------------------------------------------------------------------------- /examples/src/deltalake.rs: -------------------------------------------------------------------------------- 1 | // This example demonstrates creating a Spark DataFrame from a CSV with read options 2 | // and then adding transformations for 'select' & 'sort' 3 | // The resulting dataframe is saved in the `delta` format as a `managed` table 4 | // and `spark.sql` queries are run against the delta table 5 | // 6 | // The remote spark session must have the spark package `io.delta:delta-spark_2.12:{DELTA_VERSION}` enabled. 7 | // Where the `DELTA_VERSION` is the specified Delta Lake version. 8 | 9 | use spark_connect_rs::{SparkSession, SparkSessionBuilder}; 10 | 11 | use spark_connect_rs::dataframe::SaveMode; 12 | 13 | #[tokio::main] 14 | async fn main() -> Result<(), Box> { 15 | let spark: SparkSession = SparkSessionBuilder::remote("sc://127.0.0.1:15002/") 16 | .build() 17 | .await?; 18 | 19 | // path might vary based on where you started your spark cluster 20 | // the `/datasets/` folder of spark contains dummy data 21 | let paths = ["./datasets/people.csv"]; 22 | 23 | // Load a CSV file from the spark server 24 | let df = spark 25 | .read() 26 | .format("csv") 27 | .option("header", "True") 28 | .option("delimiter", ";") 29 | .option("inferSchema", "True") 30 | .load(paths)?; 31 | 32 | // write as a delta table and register it as a table 33 | df.write() 34 | .format("delta") 35 | .mode(SaveMode::Overwrite) 36 | .save_as_table("default.people_delta") 37 | .await?; 38 | 39 | // view the history of the table 40 | spark 41 | .sql("DESCRIBE HISTORY default.people_delta") 42 | .await? 43 | .show(Some(1), None, Some(true)) 44 | .await?; 45 | 46 | // create another dataframe 47 | let df = spark 48 | .sql("SELECT 'john' as name, 40 as age, 'engineer' as job") 49 | .await?; 50 | 51 | // append to the delta table 52 | df.write() 53 | .format("delta") 54 | .mode(SaveMode::Append) 55 | .save_as_table("default.people_delta") 56 | .await?; 57 | 58 | // view history 59 | spark 60 | .sql("DESCRIBE HISTORY default.people_delta") 61 | .await? 62 | .show(Some(2), None, Some(true)) 63 | .await?; 64 | 65 | // +-------------------------------------------------------------------------------------------------------+ 66 | // | show_string | 67 | // +-------------------------------------------------------------------------------------------------------+ 68 | // | -RECORD 0-------------------------------------------------------------------------------------------- | 69 | // | version | 1 | 70 | // | timestamp | 2024-05-17 14:27:34.462 | 71 | // | userId | NULL | 72 | // | userName | NULL | 73 | // | operation | WRITE | 74 | // | operationParameters | {mode -> Append, partitionBy -> []} | 75 | // | job | NULL | 76 | // | notebook | NULL | 77 | // | clusterId | NULL | 78 | // | readVersion | 0 | 79 | // | isolationLevel | Serializable | 80 | // | isBlindAppend | true | 81 | // | operationMetrics | {numFiles -> 1, numOutputRows -> 1, numOutputBytes -> 947} | 82 | // | userMetadata | NULL | 83 | // | engineInfo | Apache-Spark/3.5.1 Delta-Lake/3.0.0 | 84 | // | -RECORD 1-------------------------------------------------------------------------------------------- | 85 | // | version | 0 | 86 | // | timestamp | 2024-05-17 14:27:30.726 | 87 | // | userId | NULL | 88 | // | userName | NULL | 89 | // | operation | CREATE OR REPLACE TABLE AS SELECT | 90 | // | operationParameters | {isManaged -> true, description -> NULL, partitionBy -> [], properties -> {}} | 91 | // | job | NULL | 92 | // | notebook | NULL | 93 | // | clusterId | NULL | 94 | // | readVersion | NULL | 95 | // | isolationLevel | Serializable | 96 | // | isBlindAppend | false | 97 | // | operationMetrics | {numFiles -> 1, numOutputRows -> 2, numOutputBytes -> 988} | 98 | // | userMetadata | NULL | 99 | // | engineInfo | Apache-Spark/3.5.1 Delta-Lake/3.0.0 | 100 | // | | 101 | // +-------------------------------------------------------------------------------------------------------+ 102 | 103 | Ok(()) 104 | } 105 | -------------------------------------------------------------------------------- /examples/src/reader.rs: -------------------------------------------------------------------------------- 1 | // This example demonstrates creating a Spark DataFrame from a CSV with read options 2 | // and then adding transformations for 'select' & 'sort' 3 | // printing the results as "show(...)" 4 | 5 | use spark_connect_rs::{SparkSession, SparkSessionBuilder}; 6 | 7 | use spark_connect_rs::functions as F; 8 | use spark_connect_rs::types::DataType; 9 | 10 | #[tokio::main] 11 | async fn main() -> Result<(), Box> { 12 | let spark: SparkSession = SparkSessionBuilder::default().build().await?; 13 | 14 | let path = "./datasets/people.csv"; 15 | 16 | let df = spark 17 | .read() 18 | .format("csv") 19 | .option("header", "True") 20 | .option("delimiter", ";") 21 | .load([path])?; 22 | 23 | // select columns and perform data manipulations 24 | let df = df 25 | .select([ 26 | F::col("name"), 27 | F::col("age").cast(DataType::Integer).alias("age_int"), 28 | (F::lit(3.0) + F::col("age_int")).alias("addition"), 29 | ]) 30 | .sort([F::col("name").desc()]); 31 | 32 | df.show(Some(5), None, None).await?; 33 | 34 | // print results 35 | // +--------------------------+ 36 | // | show_string | 37 | // +--------------------------+ 38 | // | +-----+-------+--------+ | 39 | // | |name |age_int|addition| | 40 | // | +-----+-------+--------+ | 41 | // | |Jorge|30 |33.0 | | 42 | // | |Bob |32 |35.0 | | 43 | // | +-----+-------+--------+ | 44 | // | | 45 | // +--------------------------+ 46 | 47 | Ok(()) 48 | } 49 | -------------------------------------------------------------------------------- /examples/src/readstream.rs: -------------------------------------------------------------------------------- 1 | use spark_connect_rs::streaming::{OutputMode, Trigger}; 2 | use spark_connect_rs::{SparkSession, SparkSessionBuilder}; 3 | 4 | use std::{thread, time}; 5 | 6 | // This example demonstrates creating a Spark Stream and monitoring the progress 7 | #[tokio::main] 8 | async fn main() -> Result<(), Box> { 9 | let spark: SparkSession = 10 | SparkSessionBuilder::remote("sc://127.0.0.1:15002/;user_id=stream_example") 11 | .build() 12 | .await?; 13 | 14 | let df = spark 15 | .read_stream() 16 | .format("rate") 17 | .option("rowsPerSecond", "5") 18 | .load(None)?; 19 | 20 | let query = df 21 | .write_stream() 22 | .format("console") 23 | .query_name("example_stream") 24 | .output_mode(OutputMode::Append) 25 | .trigger(Trigger::ProcessingTimeInterval("1 seconds".to_string())) 26 | .start(None) 27 | .await?; 28 | 29 | // loop to get multiple progression stats 30 | for _ in 1..5 { 31 | thread::sleep(time::Duration::from_secs(5)); 32 | let val = query.last_progress().await?; 33 | println!("{}", val); 34 | } 35 | 36 | // stop the active stream 37 | query.stop().await?; 38 | 39 | Ok(()) 40 | } 41 | -------------------------------------------------------------------------------- /examples/src/sql.rs: -------------------------------------------------------------------------------- 1 | // This example demonstrates creating a Spark DataFrame from a SQL command 2 | // and saving the results as a parquet and reading the new parquet file 3 | 4 | use spark_connect_rs::dataframe::SaveMode; 5 | use spark_connect_rs::{SparkSession, SparkSessionBuilder}; 6 | 7 | #[tokio::main] 8 | async fn main() -> Result<(), Box> { 9 | let spark: SparkSession = SparkSessionBuilder::remote("sc://127.0.0.1:15002/") 10 | .build() 11 | .await?; 12 | 13 | let df = spark.sql("select 'apple' as word, 123 as count").await?; 14 | 15 | df.write() 16 | .mode(SaveMode::Overwrite) 17 | .format("parquet") 18 | .save("file:///tmp/spark-connect-write-example-output.parquet") 19 | .await?; 20 | 21 | let df = spark 22 | .read() 23 | .format("parquet") 24 | .load(["file:///tmp/spark-connect-write-example-output.parquet"])?; 25 | 26 | df.show(Some(100), None, None).await?; 27 | 28 | // +---------------+ 29 | // | show_string | 30 | // +---------------+ 31 | // | +-----+-----+ | 32 | // | |word |count| | 33 | // | +-----+-----+ | 34 | // | |apple|123 | | 35 | // | +-----+-----+ | 36 | // | | 37 | // +---------------+ 38 | 39 | Ok(()) 40 | } 41 | -------------------------------------------------------------------------------- /examples/src/writer.rs: -------------------------------------------------------------------------------- 1 | use spark_connect_rs::{SparkSession, SparkSessionBuilder}; 2 | 3 | use spark_connect_rs::functions::col; 4 | 5 | use spark_connect_rs::dataframe::SaveMode; 6 | 7 | // This example demonstrates creating a Spark DataFrame from range() 8 | // alias the column name, writing the results to a CSV 9 | // then reading the csv file back 10 | #[tokio::main] 11 | async fn main() -> Result<(), Box> { 12 | let spark: SparkSession = SparkSessionBuilder::remote("sc://127.0.0.1:15002/") 13 | .build() 14 | .await?; 15 | 16 | let df = spark 17 | .range(None, 1000, 1, Some(16)) 18 | .select([col("id").alias("range_id")]); 19 | 20 | let path = "file:///tmp/range_table/"; 21 | 22 | df.write() 23 | .format("csv") 24 | .mode(SaveMode::Overwrite) 25 | .option("header", "true") 26 | .save(path) 27 | .await?; 28 | 29 | let df = spark 30 | .read() 31 | .format("csv") 32 | .option("header", "true") 33 | .load([path])?; 34 | 35 | df.show(Some(10), None, None).await?; 36 | 37 | // print results may slighty vary but should be close to the below 38 | // +--------------------------+ 39 | // | show_string | 40 | // +--------------------------+ 41 | // | +--------+ | 42 | // | |range_id| | 43 | // | +--------+ | 44 | // | |312 | | 45 | // | |313 | | 46 | // | |314 | | 47 | // | |315 | | 48 | // | |316 | | 49 | // | |317 | | 50 | // | |318 | | 51 | // | |319 | | 52 | // | |320 | | 53 | // | |321 | | 54 | // | +--------+ | 55 | // | only showing top 10 rows | 56 | // | | 57 | // +--------------------------+ 58 | 59 | Ok(()) 60 | } 61 | -------------------------------------------------------------------------------- /pre-commit.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | 20 | # This file is git pre-commit hook. 21 | # 22 | # Soft link it as git hook under top dir of apache arrow git repository: 23 | # $ ln -s ../../pre-commit.sh .git/hooks/pre-commit 24 | # 25 | # This file be run directly: 26 | # $ ./pre-commit.sh 27 | 28 | function RED() { 29 | echo "\033[0;31m$@\033[0m" 30 | } 31 | 32 | function GREEN() { 33 | echo "\033[0;32m$@\033[0m" 34 | } 35 | 36 | function BYELLOW() { 37 | echo "\033[1;33m$@\033[0m" 38 | } 39 | 40 | # env GIT_DIR is set by git when run a pre-commit hook. 41 | if [ -z "${GIT_DIR}" ]; then 42 | GIT_DIR=$(git rev-parse --show-toplevel) 43 | fi 44 | 45 | cd ${GIT_DIR} 46 | 47 | NUM_CHANGES=$(git diff --cached --name-only . | 48 | grep -e ".*/*.rs$" | 49 | awk '{print $1}' | 50 | wc -l) 51 | 52 | if [ ${NUM_CHANGES} -eq 0 ]; then 53 | echo -e "$(GREEN INFO): no staged changes in *.rs, $(GREEN skip cargo fmt/clippy)" 54 | exit 0 55 | fi 56 | 57 | # 1. cargo clippy 58 | 59 | echo -e "$(GREEN INFO): cargo clippy ..." 60 | 61 | # Cargo clippy always return exit code 0, and `tee` doesn't work. 62 | # So let's just run cargo clippy. 63 | cargo clippy 64 | echo -e "$(GREEN INFO): cargo clippy done" 65 | 66 | # 2. cargo fmt: format with nightly and stable. 67 | 68 | CHANGED_BY_CARGO_FMT=false 69 | echo -e "$(GREEN INFO): cargo fmt with nightly and stable ..." 70 | 71 | for version in nightly stable; do 72 | CMD="cargo +${version} fmt" 73 | ${CMD} --all -q -- --check 2>/dev/null 74 | if [ $? -ne 0 ]; then 75 | ${CMD} --all 76 | echo -e "$(BYELLOW WARN): ${CMD} changed some files" 77 | CHANGED_BY_CARGO_FMT=true 78 | fi 79 | done 80 | 81 | if ${CHANGED_BY_CARGO_FMT}; then 82 | echo -e "$(RED FAIL): git commit $(RED ABORTED), please have a look and run git add/commit again" 83 | exit 1 84 | fi 85 | 86 | exit 0 87 | --------------------------------------------------------------------------------