├── datasets ├── dir1 │ ├── file3.json │ ├── file1.parquet │ └── dir2 │ │ └── file2.parquet ├── people.txt ├── people.csv ├── users.orc ├── users.avro ├── people.json ├── users.parquet ├── employees.json ├── user.avsc ├── full_user.avsc └── kv1.txt ├── .gitignore ├── crates └── connect │ ├── protobuf │ └── spark-3.5 │ │ ├── buf.yaml │ │ └── spark │ │ └── connect │ │ ├── common.proto │ │ ├── types.proto │ │ ├── catalog.proto │ │ ├── expressions.proto │ │ └── commands.proto │ ├── build.rs │ ├── Cargo.toml │ └── src │ ├── client │ ├── middleware.rs │ ├── config.rs │ └── builder.rs │ ├── lib.rs │ ├── storage.rs │ ├── conf.rs │ ├── errors.rs │ ├── group.rs │ ├── expressions.rs │ ├── window.rs │ ├── session.rs │ └── column.rs ├── .github ├── pull_request_template.md └── workflows │ ├── release.yml │ └── build.yml ├── .pre-commit-config.yaml ├── docker-compose.yml ├── examples ├── Cargo.toml ├── src │ ├── readstream.rs │ ├── sql.rs │ ├── reader.rs │ ├── writer.rs │ ├── databricks.rs │ └── deltalake.rs └── README.md ├── Cargo.toml ├── pre-commit.sh └── LICENSE /datasets/dir1/file3.json: -------------------------------------------------------------------------------- 1 | {"file":"corrupt.json"} 2 | -------------------------------------------------------------------------------- /datasets/people.txt: -------------------------------------------------------------------------------- 1 | Michael, 29 2 | Andy, 30 3 | Justin, 19 4 | -------------------------------------------------------------------------------- /datasets/people.csv: -------------------------------------------------------------------------------- 1 | name;age;job 2 | Jorge;30;Developer 3 | Bob;32;Developer 4 | -------------------------------------------------------------------------------- /datasets/users.orc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjrusso8/spark-connect-rs/HEAD/datasets/users.orc -------------------------------------------------------------------------------- /datasets/users.avro: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjrusso8/spark-connect-rs/HEAD/datasets/users.avro -------------------------------------------------------------------------------- /datasets/people.json: -------------------------------------------------------------------------------- 1 | {"name":"Michael"} 2 | {"name":"Andy", "age":30} 3 | {"name":"Justin", "age":19} 4 | -------------------------------------------------------------------------------- /datasets/users.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjrusso8/spark-connect-rs/HEAD/datasets/users.parquet -------------------------------------------------------------------------------- /datasets/dir1/file1.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjrusso8/spark-connect-rs/HEAD/datasets/dir1/file1.parquet -------------------------------------------------------------------------------- /datasets/dir1/dir2/file2.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjrusso8/spark-connect-rs/HEAD/datasets/dir1/dir2/file2.parquet -------------------------------------------------------------------------------- /datasets/employees.json: -------------------------------------------------------------------------------- 1 | 
{"name":"Michael", "salary":3000} 2 | {"name":"Andy", "salary":4500} 3 | {"name":"Justin", "salary":3500} 4 | {"name":"Berta", "salary":4000} 5 | -------------------------------------------------------------------------------- /datasets/user.avsc: -------------------------------------------------------------------------------- 1 | {"namespace": "example.avro", 2 | "type": "record", 3 | "name": "User", 4 | "fields": [ 5 | {"name": "name", "type": "string"}, 6 | {"name": "favorite_color", "type": ["string", "null"]} 7 | ] 8 | } 9 | -------------------------------------------------------------------------------- /datasets/full_user.avsc: -------------------------------------------------------------------------------- 1 | {"type": "record", "namespace": "example.avro", "name": "User", "fields": [{"type": "string", "name": "name"}, {"type": ["string", "null"], "name": "favorite_color"}, {"type": {"items": "int", "type": "array"}, "name": "favorite_numbers"}]} 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 17 | 18 | /target 19 | 20 | .vscode 21 | *.ipynb 22 | 23 | /spark-warehouse 24 | /artifacts 25 | -------------------------------------------------------------------------------- /crates/connect/protobuf/spark-3.5/buf.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | version: v1 18 | breaking: 19 | use: 20 | - FILE 21 | except: 22 | - FILE_SAME_GO_PACKAGE 23 | lint: 24 | use: 25 | - DEFAULT 26 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | 19 | 20 | # Description 21 | 22 | The description of the main changes of your pull request 23 | 24 | ## Related Issue(s) 25 | 30 | 31 | ## Documentation 32 | 33 | 36 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 17 | 18 | repos: 19 | - repo: https://github.com/pre-commit/pre-commit-hooks 20 | rev: v4.4.0 21 | hooks: 22 | - id: check-byte-order-marker 23 | - id: check-case-conflict 24 | - id: check-merge-conflict 25 | - id: check-symlinks 26 | - id: check-yaml 27 | - id: end-of-file-fixer 28 | - id: mixed-line-ending 29 | - id: trailing-whitespace 30 | - repo: https://github.com/doublify/pre-commit-rust 31 | rev: v1.0 32 | hooks: 33 | - id: fmt 34 | - id: cargo-check 35 | - id: clippy 36 | -------------------------------------------------------------------------------- /crates/connect/build.rs: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 
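// Build script: collects every .proto file under ./protobuf/spark-3.5/spark/connect/
// and compiles it with tonic-build into the gRPC client bindings exposed through the
// crate's `spark` module (only the client side is generated, no server code).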
17 | 18 | use std::fs; 19 | 20 | fn main() -> Result<(), Box<dyn std::error::Error>> { 21 | let files = fs::read_dir("./protobuf/spark-3.5/spark/connect/")?; 22 | 23 | let mut file_paths: Vec<String> = vec![]; 24 | 25 | for file in files { 26 | let entry = file?.path(); 27 | file_paths.push(entry.to_str().unwrap().to_string()); 28 | } 29 | 30 | tonic_build::configure() 31 | .protoc_arg("--experimental_allow_proto3_optional") 32 | .build_server(false) 33 | .build_client(true) 34 | .build_transport(true) 35 | .compile(file_paths.as_ref(), &["./protobuf/spark-3.5/"])?; 36 | 37 | Ok(()) 38 | } 39 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 17 | 18 | services: 19 | spark: 20 | image: "apache/spark:3.5.3-scala2.12-java11-r-ubuntu" 21 | command: > 22 | /opt/spark/sbin/start-connect-server.sh 23 | --packages "org.apache.spark:spark-connect_2.12:3.5.3,io.delta:delta-spark_2.12:3.0.0" 24 | --conf "spark.driver.extraJavaOptions=-Divy.cache.dir=/tmp -Divy.home=/tmp" 25 | --conf "spark.sql.extensions=io.delta.sql.DeltaSparkSessionExtension" 26 | --conf "spark.sql.catalog.spark_catalog=org.apache.spark.sql.delta.catalog.DeltaCatalog" 27 | environment: 28 | - SPARK_NO_DAEMONIZE=true 29 | ports: 30 | - "4040:4040" 31 | - "15002:15002" 32 | volumes: 33 | - ./datasets:/opt/spark/work-dir/datasets 34 | -------------------------------------------------------------------------------- /examples/Cargo.toml: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 
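# Each example below is declared as its own [[bin]] target and can be run with
# `cargo run --bin <name>`; the databricks example additionally expects the `tls`
# feature to be enabled.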
17 | 18 | [package] 19 | name = "examples" 20 | version = "0.0.0" 21 | authors.workspace = true 22 | edition.workspace = true 23 | license.workspace = true 24 | publish = false 25 | 26 | [[bin]] 27 | name = "sql" 28 | path = "src/sql.rs" 29 | 30 | [[bin]] 31 | name = "deltalake" 32 | path = "src/deltalake.rs" 33 | 34 | [[bin]] 35 | name = "reader" 36 | path = "src/reader.rs" 37 | 38 | [[bin]] 39 | name = "writer" 40 | path = "src/writer.rs" 41 | 42 | [[bin]] 43 | name = "readstream" 44 | path = "src/readstream.rs" 45 | 46 | [[bin]] 47 | name = "databricks" 48 | path = "src/databricks.rs" 49 | required-feature = ["tls"] 50 | 51 | [dependencies] 52 | spark-connect-rs = { version = "0.0.2", path = "../crates/connect" } 53 | tokio = { workspace = true, features = ["rt-multi-thread"] } 54 | 55 | [features] 56 | tls = ["spark-connect-rs/tls"] 57 | -------------------------------------------------------------------------------- /crates/connect/protobuf/spark-3.5/spark/connect/common.proto: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | syntax = 'proto3'; 19 | 20 | package spark.connect; 21 | 22 | option java_multiple_files = true; 23 | option java_package = "org.apache.spark.connect.proto"; 24 | option go_package = "internal/generated"; 25 | 26 | // StorageLevel for persisting Datasets/Tables. 27 | message StorageLevel { 28 | // (Required) Whether the cache should use disk or not. 29 | bool use_disk = 1; 30 | // (Required) Whether the cache should use memory or not. 31 | bool use_memory = 2; 32 | // (Required) Whether the cache should use off-heap or not. 33 | bool use_off_heap = 3; 34 | // (Required) Whether the cached data is deserialized or not. 35 | bool deserialized = 4; 36 | // (Required) The number of replicas. 37 | int32 replication = 5; 38 | } 39 | 40 | 41 | // ResourceInformation to hold information about a type of Resource. 42 | // The corresponding class is 'org.apache.spark.resource.ResourceInformation' 43 | message ResourceInformation { 44 | // (Required) The name of the resource 45 | string name = 1; 46 | // (Required) An array of strings describing the addresses of the resource. 47 | repeated string addresses = 2; 48 | } 49 | -------------------------------------------------------------------------------- /examples/src/readstream.rs: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. 
The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 17 | 18 | use spark_connect_rs::streaming::{OutputMode, Trigger}; 19 | use spark_connect_rs::{SparkSession, SparkSessionBuilder}; 20 | 21 | use std::{thread, time}; 22 | 23 | // This example demonstrates creating a Spark Stream and monitoring the progress 24 | #[tokio::main] 25 | async fn main() -> Result<(), Box> { 26 | let spark: SparkSession = 27 | SparkSessionBuilder::remote("sc://127.0.0.1:15002/;user_id=stream_example") 28 | .build() 29 | .await?; 30 | 31 | let df = spark 32 | .read_stream() 33 | .format("rate") 34 | .option("rowsPerSecond", "5") 35 | .load(None)?; 36 | 37 | let query = df 38 | .write_stream() 39 | .format("console") 40 | .query_name("example_stream") 41 | .output_mode(OutputMode::Append) 42 | .trigger(Trigger::ProcessingTimeInterval("1 seconds".to_string())) 43 | .start(None) 44 | .await?; 45 | 46 | // loop to get multiple progression stats 47 | for _ in 1..5 { 48 | thread::sleep(time::Duration::from_secs(5)); 49 | let val = query.last_progress().await?; 50 | println!("{}", val); 51 | } 52 | 53 | // stop the active stream 54 | query.stop().await?; 55 | 56 | Ok(()) 57 | } 58 | -------------------------------------------------------------------------------- /examples/src/sql.rs: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 
17 | 18 | // This example demonstrates creating a Spark DataFrame from a SQL command 19 | // and saving the results as a parquet and reading the new parquet file 20 | 21 | use spark_connect_rs::dataframe::SaveMode; 22 | use spark_connect_rs::{SparkSession, SparkSessionBuilder}; 23 | 24 | #[tokio::main] 25 | async fn main() -> Result<(), Box> { 26 | let spark: SparkSession = SparkSessionBuilder::remote("sc://127.0.0.1:15002/") 27 | .build() 28 | .await?; 29 | 30 | let df = spark.sql("select 'apple' as word, 123 as count").await?; 31 | 32 | df.write() 33 | .mode(SaveMode::Overwrite) 34 | .format("parquet") 35 | .save("file:///tmp/spark-connect-write-example-output.parquet") 36 | .await?; 37 | 38 | let df = spark 39 | .read() 40 | .format("parquet") 41 | .load(["file:///tmp/spark-connect-write-example-output.parquet"])?; 42 | 43 | df.show(Some(100), None, None).await?; 44 | 45 | // +---------------+ 46 | // | show_string | 47 | // +---------------+ 48 | // | +-----+-----+ | 49 | // | |word |count| | 50 | // | +-----+-----+ | 51 | // | |apple|123 | | 52 | // | +-----+-----+ | 53 | // | | 54 | // +---------------+ 55 | 56 | Ok(()) 57 | } 58 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 17 | 18 | name: release to cargo 19 | 20 | on: 21 | push: 22 | tags: ["v*"] 23 | 24 | jobs: 25 | validate-release-tag: 26 | name: Validate git tag 27 | runs-on: ubuntu-20.04 28 | steps: 29 | - uses: actions/checkout@v4 30 | - name: compare git tag with cargo metadata 31 | run: | 32 | PUSHED_TAG=${GITHUB_REF##*/} 33 | CURR_VER=$( grep version Cargo.toml | head -n 1 | awk '{print $3}' | tr -d '"' ) 34 | if [[ "${PUSHED_TAG}" != "v${CURR_VER}" ]]; then 35 | echo "Cargo metadata has version set to ${CURR_VER}, but got pushed tag ${PUSHED_TAG}." 
36 | exit 1 37 | fi 38 | working-directory: ./crates 39 | 40 | release-crate: 41 | needs: validate-release-tag 42 | name: Release crate 43 | runs-on: ubuntu-20.04 44 | steps: 45 | - uses: actions/checkout@v4 46 | 47 | - uses: actions-rs/toolchain@v1 48 | with: 49 | profile: minimal 50 | toolchain: stable 51 | override: true 52 | 53 | - name: install protoc 54 | uses: arduino/setup-protoc@v2 55 | with: 56 | version: 23.x 57 | 58 | - name: cargo publish rust 59 | uses: actions-rs/cargo@v1 60 | env: 61 | CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }} 62 | with: 63 | command: publish 64 | args: --token "${CARGO_REGISTRY_TOKEN}" --package spark-connect-rs --manifest-path ./Cargo.toml 65 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 17 | 18 | [workspace] 19 | members = ["crates/*", "examples"] 20 | resolver = "2" 21 | 22 | [workspace.package] 23 | authors = ["Steve Russo <64294847+sjrusso8@users.noreply.github.com>"] 24 | keywords = ["spark", "spark_connect"] 25 | readme = "README.md" 26 | edition = "2021" 27 | homepage = "https://github.com/sjrusso8/spark-connect-rs" 28 | description = "Apache Spark Connect Client for Rust" 29 | license = "Apache-2.0" 30 | documentation = "https://docs.rs/spark-connect-rs" 31 | repository = "https://github.com/sjrusso8/spark-connect-rs" 32 | rust-version = "1.81" 33 | 34 | [workspace.dependencies] 35 | tonic = { version ="0.11", default-features = false } 36 | 37 | tokio = { version = "1.44", default-features = false, features = ["macros"] } 38 | tower = { version = "0.5" } 39 | 40 | futures-util = { version = "0.3" } 41 | thiserror = { version = "2.0" } 42 | 43 | http-body = { version = "0.4.6" } 44 | 45 | arrow = { version = "55", features = ["prettyprint"] } 46 | arrow-ipc = { version = "55" } 47 | 48 | serde_json = { version = "1" } 49 | 50 | prost = { version = "0.12" } 51 | prost-types = { version = "0.12" } 52 | 53 | rand = { version = "0.9" } 54 | uuid = { version = "1.16", features = ["v4"] } 55 | url = { version = "2.5" } 56 | regex = { version = "1" } 57 | 58 | chrono = { version = "0.4" } 59 | 60 | datafusion = { version = "47.0", default-features = false } 61 | polars = { version = "0.43", default-features = false } 62 | polars-arrow = { version = "0.43", default-features = false, features = ["arrow_rs"] } 63 | -------------------------------------------------------------------------------- /examples/src/reader.rs: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more 
contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 17 | 18 | // This example demonstrates creating a Spark DataFrame from a CSV with read options 19 | // and then adding transformations for 'select' & 'sort' 20 | // printing the results as "show(...)" 21 | 22 | use spark_connect_rs::{SparkSession, SparkSessionBuilder}; 23 | 24 | use spark_connect_rs::functions as F; 25 | use spark_connect_rs::types::DataType; 26 | 27 | #[tokio::main] 28 | async fn main() -> Result<(), Box> { 29 | let spark: SparkSession = SparkSessionBuilder::default().build().await?; 30 | 31 | let path = "./datasets/people.csv"; 32 | 33 | let df = spark 34 | .read() 35 | .format("csv") 36 | .option("header", "True") 37 | .option("delimiter", ";") 38 | .load([path])?; 39 | 40 | // select columns and perform data manipulations 41 | let df = df 42 | .select([ 43 | F::col("name"), 44 | F::col("age").cast(DataType::Integer).alias("age_int"), 45 | (F::lit(3.0) + F::col("age_int")).alias("addition"), 46 | ]) 47 | .sort([F::col("name").desc()]); 48 | 49 | df.show(Some(5), None, None).await?; 50 | 51 | // print results 52 | // +--------------------------+ 53 | // | show_string | 54 | // +--------------------------+ 55 | // | +-----+-------+--------+ | 56 | // | |name |age_int|addition| | 57 | // | +-----+-------+--------+ | 58 | // | |Jorge|30 |33.0 | | 59 | // | |Bob |32 |35.0 | | 60 | // | +-----+-------+--------+ | 61 | // | | 62 | // +--------------------------+ 63 | 64 | Ok(()) 65 | } 66 | -------------------------------------------------------------------------------- /pre-commit.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | 20 | # This file is git pre-commit hook. 
21 | # 22 | # Soft link it as git hook under top dir of apache arrow git repository: 23 | # $ ln -s ../../pre-commit.sh .git/hooks/pre-commit 24 | # 25 | # This file be run directly: 26 | # $ ./pre-commit.sh 27 | 28 | function RED() { 29 | echo "\033[0;31m$@\033[0m" 30 | } 31 | 32 | function GREEN() { 33 | echo "\033[0;32m$@\033[0m" 34 | } 35 | 36 | function BYELLOW() { 37 | echo "\033[1;33m$@\033[0m" 38 | } 39 | 40 | # env GIT_DIR is set by git when run a pre-commit hook. 41 | if [ -z "${GIT_DIR}" ]; then 42 | GIT_DIR=$(git rev-parse --show-toplevel) 43 | fi 44 | 45 | cd ${GIT_DIR} 46 | 47 | NUM_CHANGES=$(git diff --cached --name-only . | 48 | grep -e ".*/*.rs$" | 49 | awk '{print $1}' | 50 | wc -l) 51 | 52 | if [ ${NUM_CHANGES} -eq 0 ]; then 53 | echo -e "$(GREEN INFO): no staged changes in *.rs, $(GREEN skip cargo fmt/clippy)" 54 | exit 0 55 | fi 56 | 57 | # 1. cargo clippy 58 | 59 | echo -e "$(GREEN INFO): cargo clippy ..." 60 | 61 | # Cargo clippy always return exit code 0, and `tee` doesn't work. 62 | # So let's just run cargo clippy. 63 | cargo clippy 64 | echo -e "$(GREEN INFO): cargo clippy done" 65 | 66 | # 2. cargo fmt: format with nightly and stable. 67 | 68 | CHANGED_BY_CARGO_FMT=false 69 | echo -e "$(GREEN INFO): cargo fmt with nightly and stable ..." 70 | 71 | for version in nightly stable; do 72 | CMD="cargo +${version} fmt" 73 | ${CMD} --all -q -- --check 2>/dev/null 74 | if [ $? -ne 0 ]; then 75 | ${CMD} --all 76 | echo -e "$(BYELLOW WARN): ${CMD} changed some files" 77 | CHANGED_BY_CARGO_FMT=true 78 | fi 79 | done 80 | 81 | if ${CHANGED_BY_CARGO_FMT}; then 82 | echo -e "$(RED FAIL): git commit $(RED ABORTED), please have a look and run git add/commit again" 83 | exit 1 84 | fi 85 | 86 | exit 0 87 | -------------------------------------------------------------------------------- /crates/connect/Cargo.toml: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 
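# Manifest for the spark-connect-rs client crate. Dependencies are inherited from the
# workspace; `tls`, `datafusion`, and `polars` are optional feature-gated integrations,
# and the protobuf sources are shipped with the crate so build.rs can regenerate the
# gRPC bindings when the crate is built.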
17 | 18 | [package] 19 | name = "spark-connect-rs" 20 | version = "0.0.2" 21 | authors.workspace = true 22 | keywords.workspace = true 23 | readme.workspace = true 24 | edition.workspace = true 25 | homepage.workspace = true 26 | description.workspace = true 27 | license.workspace = true 28 | documentation.workspace = true 29 | repository.workspace = true 30 | rust-version.workspace = true 31 | include = [ 32 | "build.rs", 33 | "src/**/*", 34 | "protobuf/**/*", 35 | ] 36 | 37 | [dependencies] 38 | tonic = { workspace = true, default-features = false, optional = true } 39 | 40 | tower = { workspace = true } 41 | tokio = { workspace = true, optional = true } 42 | 43 | futures-util = { workspace = true } 44 | thiserror = { workspace = true } 45 | 46 | http-body = { workspace = true } 47 | 48 | arrow = { workspace = true } 49 | arrow-ipc = { workspace = true } 50 | 51 | serde_json = { workspace = true } 52 | 53 | prost = { workspace = true } 54 | prost-types = { workspace = true } 55 | 56 | rand = { workspace = true } 57 | uuid = { workspace = true } 58 | url = { workspace = true } 59 | regex = { workspace = true } 60 | 61 | chrono = { workspace = true } 62 | 63 | datafusion = { workspace = true, optional = true } 64 | 65 | polars = { workspace = true, optional = true } 66 | polars-arrow = { workspace = true, optional = true } 67 | 68 | [dev-dependencies] 69 | futures = "0.3" 70 | tokio = { workspace = true, features = ["rt-multi-thread"] } 71 | 72 | [build-dependencies] 73 | tonic-build = "0.11" 74 | 75 | [lib] 76 | doctest = false 77 | 78 | [features] 79 | default = [ 80 | "tokio", 81 | "tonic/codegen", 82 | "tonic/prost", 83 | "tonic/transport", 84 | ] 85 | 86 | tls = [ 87 | "tonic/tls", 88 | "tonic/tls-roots" 89 | ] 90 | 91 | datafusion = [ 92 | "dep:datafusion" 93 | ] 94 | 95 | polars = [ 96 | "dep:polars", 97 | "dep:polars-arrow" 98 | ] 99 | -------------------------------------------------------------------------------- /examples/src/writer.rs: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 
17 | 18 | use spark_connect_rs::{SparkSession, SparkSessionBuilder}; 19 | 20 | use spark_connect_rs::functions::col; 21 | 22 | use spark_connect_rs::dataframe::SaveMode; 23 | 24 | // This example demonstrates creating a Spark DataFrame from range() 25 | // alias the column name, writing the results to a CSV 26 | // then reading the csv file back 27 | #[tokio::main] 28 | async fn main() -> Result<(), Box> { 29 | let spark: SparkSession = SparkSessionBuilder::remote("sc://127.0.0.1:15002/") 30 | .build() 31 | .await?; 32 | 33 | let df = spark 34 | .range(None, 1000, 1, Some(16)) 35 | .select([col("id").alias("range_id")]); 36 | 37 | let path = "file:///tmp/range_table/"; 38 | 39 | df.write() 40 | .format("csv") 41 | .mode(SaveMode::Overwrite) 42 | .option("header", "true") 43 | .save(path) 44 | .await?; 45 | 46 | let df = spark 47 | .read() 48 | .format("csv") 49 | .option("header", "true") 50 | .load([path])?; 51 | 52 | df.show(Some(10), None, None).await?; 53 | 54 | // print results may slighty vary but should be close to the below 55 | // +--------------------------+ 56 | // | show_string | 57 | // +--------------------------+ 58 | // | +--------+ | 59 | // | |range_id| | 60 | // | +--------+ | 61 | // | |312 | | 62 | // | |313 | | 63 | // | |314 | | 64 | // | |315 | | 65 | // | |316 | | 66 | // | |317 | | 67 | // | |318 | | 68 | // | |319 | | 69 | // | |320 | | 70 | // | |321 | | 71 | // | +--------+ | 72 | // | only showing top 10 rows | 73 | // | | 74 | // +--------------------------+ 75 | 76 | Ok(()) 77 | } 78 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | 19 | 20 | # Examples 21 | 22 | Set of examples that show off different features provided by `spark-connect-rs` client. 23 | 24 | In order to build these examples, you must have the `protoc` protocol buffers compilter 25 | installed, along with the git submodule synced. 26 | 27 | ```bash 28 | git clone https://github.com/sjrusso8/spark-connect-rs.git 29 | git submodule update --init --recursive 30 | ``` 31 | 32 | ### sql 33 | 34 | Write a simple SQL statement and save the dataframe as a parquet 35 | 36 | ```bash 37 | cargo run --bin sql 38 | ``` 39 | 40 | ### reader 41 | 42 | Read a CSV file, select specific columns, and display the results 43 | 44 | ```bash 45 | cargo run --bin reader 46 | ``` 47 | 48 | ### writer 49 | 50 | Create a dataframe, and save the results to a file 51 | 52 | ```bash 53 | cargo run --bin writer 54 | ``` 55 | 56 | ### readstream 57 | 58 | Create a streaming query, and monitor the progress of the stream 59 | 60 | ```bash 61 | cargo run --bin readstream 62 | ``` 63 | 64 | ### deltalake 65 | 66 | Read a file into a dataframe, save the result as a deltalake table, and append a new record to the table. 67 | 68 | **Prerequisite** the spark cluster must be started with the deltalake package. The `docker-compose.yml` provided in the repo has deltalake pre-installed. 69 | Or if you are running a spark connect server location, run the below scripts first 70 | 71 | If you are running a local spark connect server. The Delta Lake jars need to be added onto the server before it starts. 
72 | 73 | ```bash 74 | $ $SPARK_HOME/sbin/start-connect-server.sh --packages "org.apache.spark:spark-connect_2.12:3.5.1,io.delta:delta-spark_2.12:3.0.0" \ 75 | --conf "spark.driver.extraJavaOptions=-Divy.cache.dir=/tmp -Divy.home=/tmp" \ 76 | --conf "spark.sql.extensions=io.delta.sql.DeltaSparkSessionExtension" \ 77 | --conf "spark.sql.catalog.spark_catalog=org.apache.spark.sql.delta.catalog.DeltaCatalog" 78 | ``` 79 | 80 | ```bash 81 | cargo run --bin deltalake 82 | ``` 83 | 84 | ### databricks 85 | 86 | Read a Unity Catalog table, perform an aggregation, and display the results. 87 | 88 | **Prerequisite** must have access to a Databricks workspace, a personal access token, and a cluster running >= 13.3 LTS. 89 | 90 | ```bash 91 | cargo run --bin databricks --features=tls 92 | ``` 93 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 
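# CI workflow: checks formatting with rustfmt, builds and lints with clippy,
# verifies the docs build, and runs the integration test suite against a local
# Spark Connect server started via docker compose.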
17 | 18 | name: build 19 | 20 | on: 21 | push: 22 | branches: [main, "v*"] 23 | pull_request: 24 | branches: [main, "v*"] 25 | 26 | jobs: 27 | format: 28 | runs-on: ubuntu-latest 29 | steps: 30 | - uses: actions/checkout@v4 31 | with: 32 | submodules: "true" 33 | 34 | - name: install protoc 35 | uses: arduino/setup-protoc@v2 36 | with: 37 | version: 23.x 38 | 39 | - name: Install minimal stable with clippy and rustfmt 40 | uses: actions-rs/toolchain@v1 41 | with: 42 | profile: default 43 | toolchain: stable 44 | override: true 45 | 46 | - name: Format 47 | run: cargo fmt -- --check 48 | 49 | build: 50 | runs-on: ubuntu-latest 51 | 52 | steps: 53 | - uses: actions/checkout@v4 54 | with: 55 | submodules: "true" 56 | 57 | - name: install protoc 58 | uses: arduino/setup-protoc@v2 59 | with: 60 | version: 23.x 61 | 62 | - name: install minimal stable with clippy and rustfmt 63 | uses: actions-rs/toolchain@v1 64 | with: 65 | profile: default 66 | toolchain: stable 67 | override: true 68 | 69 | - uses: Swatinem/rust-cache@v2 70 | 71 | - name: build and lint with clippy 72 | run: cargo clippy 73 | 74 | - name: Check docs 75 | run: cargo doc 76 | 77 | - name: Check no default features (except rustls) 78 | run: cargo check 79 | 80 | integration_test: 81 | name: integration tests 82 | runs-on: ubuntu-latest 83 | env: 84 | CARGO_INCREMENTAL: 0 85 | # Disable full debug symbol generation to speed up CI build and keep memory down 86 | # 87 | RUSTFLAGS: "-C debuginfo=line-tables-only" 88 | # https://github.com/rust-lang/cargo/issues/10280 89 | CARGO_NET_GIT_FETCH_WITH_CLI: "true" 90 | RUST_BACKTRACE: "1" 91 | 92 | steps: 93 | - uses: actions/checkout@v4 94 | with: 95 | submodules: "true" 96 | 97 | - name: install protoc 98 | uses: arduino/setup-protoc@v2 99 | with: 100 | version: 23.x 101 | 102 | - name: install minimal stable with clippy and rustfmt 103 | uses: actions-rs/toolchain@v1 104 | with: 105 | profile: default 106 | toolchain: stable 107 | override: true 108 | 109 | - uses: Swatinem/rust-cache@v2 110 | 111 | - name: Start emulated services 112 | run: docker compose up -d 113 | 114 | - name: Run tests 115 | run: cargo test -p spark-connect-rs --features polars,datafusion 116 | -------------------------------------------------------------------------------- /examples/src/databricks.rs: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 17 | 18 | // This example demonstrates connecting to a Databricks Cluster via a tls connection. 19 | // 20 | // This demo requires access to a Databricks Workspace, a personal access token, 21 | // and a cluster id. The cluster should be running a 13.3LTS runtime or greater. 
Populate 22 | // the remote URL string between the `<>` with the appropriate details. 23 | // 24 | // The Databricks workspace instance name is the same as the Server Hostname value for your cluster. 25 | // Get connection details for a Databricks compute resource via https://docs.databricks.com/en/integrations/compute-details.html 26 | // 27 | // To view the connected Spark Session, go to the cluster Spark UI and select the 'Connect' tab. 28 | 29 | use spark_connect_rs::functions::{avg, col}; 30 | use spark_connect_rs::{SparkSession, SparkSessionBuilder}; 31 | 32 | #[tokio::main] 33 | async fn main() -> Result<(), Box> { 34 | let conn_str = "sc://:443/;token=;x-databricks-cluster-id="; 35 | 36 | // connect the databricks cluster 37 | let spark: SparkSession = SparkSessionBuilder::remote(conn_str).build().await?; 38 | 39 | // read unity catalog table 40 | let df = spark.read().table("samples.nyctaxi.trips", None)?; 41 | 42 | // apply a filter 43 | let filter = "trip_distance BETWEEN 0 AND 10 AND fare_amount BETWEEN 0 AND 50"; 44 | let df = df.filter(filter); 45 | 46 | // groupby the pickup 47 | let df = df 48 | .select(["pickup_zip", "fare_amount"]) 49 | .group_by(Some(["pickup_zip"])); 50 | 51 | // average the fare amount and order by the top 10 zip codes 52 | let df = df 53 | .agg([avg(col("fare_amount")).alias("avg_fare_amount")]) 54 | .order_by([col("avg_fare_amount").desc()]); 55 | 56 | df.show(Some(10), None, None).await?; 57 | 58 | // +---------------------------------+ 59 | // | show_string | 60 | // +---------------------------------+ 61 | // | +----------+------------------+ | 62 | // | |pickup_zip|avg_fare_amount | | 63 | // | +----------+------------------+ | 64 | // | |7086 |40.0 | | 65 | // | |7030 |40.0 | | 66 | // | |11424 |34.25 | | 67 | // | |7087 |31.0 | | 68 | // | |10470 |28.0 | | 69 | // | |11371 |25.532619926199263| | 70 | // | |11375 |25.5 | | 71 | // | |11370 |22.452380952380953| | 72 | // | |11207 |20.5 | | 73 | // | |11218 |20.0 | | 74 | // | +----------+------------------+ | 75 | // | only showing top 10 rows | 76 | // | | 77 | // +---------------------------------+ 78 | 79 | Ok(()) 80 | } 81 | -------------------------------------------------------------------------------- /crates/connect/src/client/middleware.rs: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 17 | 18 | //! 
Middleware services implemented with tower.rs 19 | 20 | use std::collections::HashMap; 21 | use std::fmt::Debug; 22 | use std::str::FromStr; 23 | use std::task::{Context, Poll}; 24 | 25 | use futures_util::future::BoxFuture; 26 | use http_body::combinators::UnsyncBoxBody; 27 | 28 | use tonic::codegen::http::Request; 29 | use tonic::codegen::http::{HeaderName, HeaderValue}; 30 | 31 | use tower::Service; 32 | 33 | /// Headers to apply a gRPC request 34 | #[derive(Debug, Clone)] 35 | pub struct HeadersLayer { 36 | headers: HashMap, 37 | } 38 | 39 | impl HeadersLayer { 40 | pub fn new(headers: HashMap) -> Self { 41 | Self { headers } 42 | } 43 | } 44 | 45 | impl tower::Layer for HeadersLayer { 46 | type Service = HeadersMiddleware; 47 | 48 | fn layer(&self, inner: S) -> Self::Service { 49 | HeadersMiddleware::new(inner, self.headers.clone()) 50 | } 51 | } 52 | 53 | /// Middleware used to apply provided headers onto a gRPC request 54 | #[derive(Clone, Debug)] 55 | pub struct HeadersMiddleware { 56 | inner: S, 57 | headers: HashMap, 58 | } 59 | 60 | #[allow(dead_code)] 61 | impl HeadersMiddleware { 62 | pub fn new(inner: S, headers: HashMap) -> Self { 63 | Self { inner, headers } 64 | } 65 | } 66 | 67 | // TODO! as of now Request is not clone. So the retry logic does not work. 68 | // https://github.com/tower-rs/tower/pull/790 69 | impl Service>> for HeadersMiddleware 70 | where 71 | S: Service>> 72 | + Clone 73 | + Send 74 | + Sync 75 | + 'static, 76 | S::Future: Send + 'static, 77 | S::Response: Send + Debug + 'static, 78 | S::Error: Debug, 79 | { 80 | type Response = S::Response; 81 | type Error = S::Error; 82 | type Future = BoxFuture<'static, Result>; 83 | 84 | fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { 85 | self.inner.poll_ready(cx).map_err(Into::into) 86 | } 87 | 88 | fn call( 89 | &mut self, 90 | mut request: Request>, 91 | ) -> Self::Future { 92 | let clone = self.inner.clone(); 93 | let mut inner = std::mem::replace(&mut self.inner, clone); 94 | 95 | let headers = self.headers.clone(); 96 | 97 | Box::pin(async move { 98 | for (key, value) in &headers { 99 | let meta_key = HeaderName::from_str(key.as_str()).unwrap(); 100 | let meta_val = HeaderValue::from_str(value.as_str()).unwrap(); 101 | 102 | request.headers_mut().insert(meta_key, meta_val); 103 | } 104 | 105 | inner.call(request).await 106 | }) 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /crates/connect/src/client/config.rs: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 
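// `Config` mirrors the connection fields of `ChannelBuilder`; the `From<Config>`
// impl at the bottom of this file folds an optional bearer token into the request
// headers before handing the values off to the channel builder.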
17 | 18 | use std::collections::HashMap; 19 | use uuid::Uuid; 20 | 21 | use crate::client::builder::{Host, Port}; 22 | use crate::client::ChannelBuilder; 23 | 24 | /// Config handler to set custom SparkSessionBuilder options 25 | #[derive(Clone, Debug, Default)] 26 | pub struct Config { 27 | pub host: Host, 28 | pub port: Port, 29 | pub session_id: Uuid, 30 | pub token: Option<String>, 31 | pub user_id: Option<String>, 32 | pub user_agent: Option<String>, 33 | pub use_ssl: bool, 34 | pub headers: Option<HashMap<String, String>>, 35 | } 36 | 37 | impl Config { 38 | pub fn new() -> Self { 39 | Config { 40 | host: "localhost".to_string(), 41 | port: 15002, 42 | token: None, 43 | session_id: Uuid::new_v4(), 44 | user_id: ChannelBuilder::create_user_id(None), 45 | user_agent: ChannelBuilder::create_user_agent(None), 46 | use_ssl: false, 47 | headers: None, 48 | } 49 | } 50 | 51 | pub fn host(mut self, val: &str) -> Self { 52 | self.host = val.to_string(); 53 | self 54 | } 55 | 56 | pub fn port(mut self, val: Port) -> Self { 57 | self.port = val; 58 | self 59 | } 60 | 61 | pub fn token(mut self, val: &str) -> Self { 62 | self.token = Some(val.to_string()); 63 | self 64 | } 65 | 66 | pub fn session_id(mut self, val: Uuid) -> Self { 67 | self.session_id = val; 68 | self 69 | } 70 | 71 | pub fn user_id(mut self, val: &str) -> Self { 72 | self.user_id = Some(val.to_string()); 73 | self 74 | } 75 | 76 | pub fn user_agent(mut self, val: &str) -> Self { 77 | self.user_agent = Some(val.to_string()); 78 | self 79 | } 80 | 81 | pub fn use_ssl(mut self, val: bool) -> Self { 82 | self.use_ssl = val; 83 | self 84 | } 85 | 86 | pub fn headers(mut self, val: HashMap<String, String>) -> Self { 87 | self.headers = Some(val); 88 | self 89 | } 90 | } 91 | 92 | impl From<Config> for ChannelBuilder { 93 | fn from(config: Config) -> Self { 94 | // if there is a token, then it needs to be added to the headers 95 | // do not overwrite any existing authentication header 96 | 97 | let mut headers = config.headers.unwrap_or_default(); 98 | 99 | if let Some(token) = &config.token { 100 | headers 101 | .entry("authorization".to_string()) 102 | .or_insert_with(|| format!("Bearer {}", token)); 103 | } 104 | 105 | Self { 106 | host: config.host, 107 | port: config.port, 108 | session_id: config.session_id, 109 | token: config.token, 110 | user_id: config.user_id, 111 | user_agent: config.user_agent, 112 | use_ssl: config.use_ssl, 113 | headers: if headers.is_empty() { 114 | None 115 | } else { 116 | Some(headers) 117 | }, 118 | } 119 | } 120 | } 121 | -------------------------------------------------------------------------------- /crates/connect/protobuf/spark-3.5/spark/connect/types.proto: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | syntax = 'proto3'; 19 | 20 | package spark.connect; 21 | 22 | option java_multiple_files = true; 23 | option java_package = "org.apache.spark.connect.proto"; 24 | option go_package = "internal/generated"; 25 | 26 | // This message describes the logical [[DataType]] of something. It does not carry the value 27 | // itself but only describes it. 28 | message DataType { 29 | oneof kind { 30 | NULL null = 1; 31 | 32 | Binary binary = 2; 33 | 34 | Boolean boolean = 3; 35 | 36 | // Numeric types 37 | Byte byte = 4; 38 | Short short = 5; 39 | Integer integer = 6; 40 | Long long = 7; 41 | 42 | Float float = 8; 43 | Double double = 9; 44 | Decimal decimal = 10; 45 | 46 | // String types 47 | String string = 11; 48 | Char char = 12; 49 | VarChar var_char = 13; 50 | 51 | // Datatime types 52 | Date date = 14; 53 | Timestamp timestamp = 15; 54 | TimestampNTZ timestamp_ntz = 16; 55 | 56 | // Interval types 57 | CalendarInterval calendar_interval = 17; 58 | YearMonthInterval year_month_interval = 18; 59 | DayTimeInterval day_time_interval = 19; 60 | 61 | // Complex types 62 | Array array = 20; 63 | Struct struct = 21; 64 | Map map = 22; 65 | 66 | // UserDefinedType 67 | UDT udt = 23; 68 | 69 | // UnparsedDataType 70 | Unparsed unparsed = 24; 71 | } 72 | 73 | message Boolean { 74 | uint32 type_variation_reference = 1; 75 | } 76 | 77 | message Byte { 78 | uint32 type_variation_reference = 1; 79 | } 80 | 81 | message Short { 82 | uint32 type_variation_reference = 1; 83 | } 84 | 85 | message Integer { 86 | uint32 type_variation_reference = 1; 87 | } 88 | 89 | message Long { 90 | uint32 type_variation_reference = 1; 91 | } 92 | 93 | message Float { 94 | uint32 type_variation_reference = 1; 95 | } 96 | 97 | message Double { 98 | uint32 type_variation_reference = 1; 99 | } 100 | 101 | message String { 102 | uint32 type_variation_reference = 1; 103 | } 104 | 105 | message Binary { 106 | uint32 type_variation_reference = 1; 107 | } 108 | 109 | message NULL { 110 | uint32 type_variation_reference = 1; 111 | } 112 | 113 | message Timestamp { 114 | uint32 type_variation_reference = 1; 115 | } 116 | 117 | message Date { 118 | uint32 type_variation_reference = 1; 119 | } 120 | 121 | message TimestampNTZ { 122 | uint32 type_variation_reference = 1; 123 | } 124 | 125 | message CalendarInterval { 126 | uint32 type_variation_reference = 1; 127 | } 128 | 129 | message YearMonthInterval { 130 | optional int32 start_field = 1; 131 | optional int32 end_field = 2; 132 | uint32 type_variation_reference = 3; 133 | } 134 | 135 | message DayTimeInterval { 136 | optional int32 start_field = 1; 137 | optional int32 end_field = 2; 138 | uint32 type_variation_reference = 3; 139 | } 140 | 141 | // Start compound types. 
142 | message Char { 143 | int32 length = 1; 144 | uint32 type_variation_reference = 2; 145 | } 146 | 147 | message VarChar { 148 | int32 length = 1; 149 | uint32 type_variation_reference = 2; 150 | } 151 | 152 | message Decimal { 153 | optional int32 scale = 1; 154 | optional int32 precision = 2; 155 | uint32 type_variation_reference = 3; 156 | } 157 | 158 | message StructField { 159 | string name = 1; 160 | DataType data_type = 2; 161 | bool nullable = 3; 162 | optional string metadata = 4; 163 | } 164 | 165 | message Struct { 166 | repeated StructField fields = 1; 167 | uint32 type_variation_reference = 2; 168 | } 169 | 170 | message Array { 171 | DataType element_type = 1; 172 | bool contains_null = 2; 173 | uint32 type_variation_reference = 3; 174 | } 175 | 176 | message Map { 177 | DataType key_type = 1; 178 | DataType value_type = 2; 179 | bool value_contains_null = 3; 180 | uint32 type_variation_reference = 4; 181 | } 182 | 183 | message UDT { 184 | string type = 1; 185 | optional string jvm_class = 2; 186 | optional string python_class = 3; 187 | optional string serialized_python_class = 4; 188 | DataType sql_type = 5; 189 | } 190 | 191 | message Unparsed { 192 | // (Required) The unparsed data type string 193 | string data_type_string = 1; 194 | } 195 | } 196 | -------------------------------------------------------------------------------- /crates/connect/src/lib.rs: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 17 | 18 | //! Spark Connection Client for Rust 19 | //! 20 | //! Currently, the Spark Connect client for Rust is **highly experimental** and **should 21 | //! not be used in any production setting**. This is currently a "proof of concept" to identify the methods 22 | //! of interacting with Spark cluster from rust. 23 | //! 24 | //! # Quickstart 25 | //! 26 | //! Create a Spark Session and create a [DataFrame] from a [arrow::array::RecordBatch]. 27 | //! 28 | //! ```rust 29 | //! use spark_connect_rs::{SparkSession, SparkSessionBuilder}; 30 | //! use spark_connect_rs::functions::{col, lit} 31 | //! 32 | //! #[tokio::main] 33 | //! async fn main() -> Result<(), Box> { 34 | //! 35 | //! let spark: SparkSession = SparkSessionBuilder::remote("sc://127.0.0.1:15002/;user_id=example_rs") 36 | //! .build() 37 | //! .await?; 38 | //! 39 | //! let name: ArrayRef = Arc::new(StringArray::from(vec!["Tom", "Alice", "Bob"])); 40 | //! let age: ArrayRef = Arc::new(Int64Array::from(vec![14, 23, 16])); 41 | //! 42 | //! let data = RecordBatch::try_from_iter(vec![("name", name), ("age", age)])? 43 | //! 44 | //! let df = spark.create_dataframe(&data).await? 45 | //! 46 | //! // 2 records total 47 | //! 
let records = df.select(["*"]) 48 | //! .with_column("age_plus", col("age") + lit(4)) 49 | //! .filter(col("name").contains("o")) 50 | //! .count() 51 | //! .await?; 52 | //! 53 | //! Ok(()) 54 | //! }; 55 | //!``` 56 | //! 57 | //! Create a Spark Session and create a DataFrame from a SQL statement: 58 | //! 59 | //! ```rust 60 | //! use spark_connect_rs::{SparkSession, SparkSessionBuilder}; 61 | //! 62 | //! #[tokio::main] 63 | //! async fn main() -> Result<(), Box> { 64 | //! 65 | //! let spark: SparkSession = SparkSessionBuilder::remote("sc://127.0.0.1:15002/;user_id=example_rs") 66 | //! .build() 67 | //! .await?; 68 | //! 69 | //! let df = spark.sql("SELECT * FROM json.`/datasets/employees.json`").await?; 70 | //! 71 | //! // Show the first 5 records 72 | //! df.filter("salary > 3000").show(Some(5), None, None).await?; 73 | //! 74 | //! Ok(()) 75 | //! }; 76 | //!``` 77 | //! 78 | //! Create a Spark Session, read a CSV file into a DataFrame, apply function transformations, and write the results: 79 | //! 80 | //! ```rust 81 | //! use spark_connect_rs::{SparkSession, SparkSessionBuilder}; 82 | //! 83 | //! use spark_connect_rs::functions as F; 84 | //! 85 | //! #[tokio::main] 86 | //! async fn main() -> Result<(), Box> { 87 | //! 88 | //! let spark: SparkSession = SparkSessionBuilder::remote("sc://127.0.0.1:15002/;user_id=example_rs") 89 | //! .build() 90 | //! .await?; 91 | //! 92 | //! let paths = ["/datasets/people.csv"]; 93 | //! 94 | //! let df = spark 95 | //! .read() 96 | //! .format("csv") 97 | //! .option("header", "True") 98 | //! .option("delimiter", ";") 99 | //! .load(paths)?; 100 | //! 101 | //! let df = df 102 | //! .filter("age > 30") 103 | //! .select([ 104 | //! F::col("name"), 105 | //! F::col("age").cast("int") 106 | //! ]); 107 | //! 108 | //! df.write() 109 | //! .format("csv") 110 | //! .option("header", "true") 111 | //! .save("/opt/spark/examples/src/main/rust/people/") 112 | //! .await?; 113 | //! 114 | //! Ok(()) 115 | //! }; 116 | //!``` 117 | //! 118 | //! ## Databricks Connection 119 | //! 120 | //! Spark Connect is enabled for Databricks Runtime 13.3 LTS and above, and requires the feature 121 | //! flag `feature = "tls"`. The connection string for the remote session must contain the following 122 | //! values in the string; 123 | //! 124 | //! ```bash 125 | //! "sc://:443/;token=;x-databricks-cluster-id=" 126 | //! ``` 127 | //! 128 | //! 129 | 130 | /// Spark Connect gRPC protobuf translated using [tonic] 131 | pub mod spark { 132 | tonic::include_proto!("spark.connect"); 133 | } 134 | 135 | pub mod catalog; 136 | pub mod client; 137 | pub mod column; 138 | pub mod conf; 139 | pub mod dataframe; 140 | pub mod errors; 141 | pub mod expressions; 142 | pub mod functions; 143 | pub mod group; 144 | pub mod plan; 145 | pub mod readwriter; 146 | pub mod session; 147 | pub mod storage; 148 | pub mod streaming; 149 | pub mod types; 150 | pub mod window; 151 | 152 | pub use dataframe::{DataFrame, DataFrameReader, DataFrameWriter}; 153 | pub use session::{SparkSession, SparkSessionBuilder}; 154 | -------------------------------------------------------------------------------- /crates/connect/src/storage.rs: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. 
The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 17 | 18 | //! Enum for handling Spark Storage representations 19 | 20 | use crate::spark; 21 | 22 | #[derive(Clone, Copy, Debug)] 23 | pub enum StorageLevel { 24 | None, 25 | DiskOnly, 26 | DiskOnly2, 27 | DiskOnly3, 28 | MemoryOnly, 29 | MemoryOnly2, 30 | MemoryAndDisk, 31 | MemoryAndDisk2, 32 | OffHeap, 33 | MemoryAndDiskDeser, 34 | } 35 | 36 | impl From for StorageLevel { 37 | fn from(spark_level: spark::StorageLevel) -> Self { 38 | match ( 39 | spark_level.use_disk, 40 | spark_level.use_memory, 41 | spark_level.use_off_heap, 42 | spark_level.deserialized, 43 | spark_level.replication, 44 | ) { 45 | (false, false, false, false, _) => StorageLevel::None, 46 | (true, false, false, false, 1) => StorageLevel::DiskOnly, 47 | (true, false, false, false, 2) => StorageLevel::DiskOnly2, 48 | (true, false, false, false, 3) => StorageLevel::DiskOnly3, 49 | (false, true, false, false, 1) => StorageLevel::MemoryOnly, 50 | (false, true, false, false, 2) => StorageLevel::MemoryOnly2, 51 | (true, true, false, false, 1) => StorageLevel::MemoryAndDisk, 52 | (true, true, false, false, 2) => StorageLevel::MemoryAndDisk2, 53 | (true, true, true, false, 1) => StorageLevel::OffHeap, 54 | (true, true, false, true, 1) => StorageLevel::MemoryAndDiskDeser, 55 | _ => unimplemented!(), 56 | } 57 | } 58 | } 59 | 60 | impl From for spark::StorageLevel { 61 | fn from(storage: StorageLevel) -> spark::StorageLevel { 62 | match storage { 63 | StorageLevel::None => spark::StorageLevel { 64 | use_disk: false, 65 | use_memory: false, 66 | use_off_heap: false, 67 | deserialized: false, 68 | replication: 1, 69 | }, 70 | StorageLevel::DiskOnly => spark::StorageLevel { 71 | use_disk: true, 72 | use_memory: false, 73 | use_off_heap: false, 74 | deserialized: false, 75 | replication: 1, 76 | }, 77 | StorageLevel::DiskOnly2 => spark::StorageLevel { 78 | use_disk: true, 79 | use_memory: false, 80 | use_off_heap: false, 81 | deserialized: false, 82 | replication: 2, 83 | }, 84 | StorageLevel::DiskOnly3 => spark::StorageLevel { 85 | use_disk: true, 86 | use_memory: false, 87 | use_off_heap: false, 88 | deserialized: false, 89 | replication: 3, 90 | }, 91 | StorageLevel::MemoryOnly => spark::StorageLevel { 92 | use_disk: false, 93 | use_memory: true, 94 | use_off_heap: false, 95 | deserialized: false, 96 | replication: 1, 97 | }, 98 | StorageLevel::MemoryOnly2 => spark::StorageLevel { 99 | use_disk: false, 100 | use_memory: true, 101 | use_off_heap: false, 102 | deserialized: false, 103 | replication: 2, 104 | }, 105 | StorageLevel::MemoryAndDisk => spark::StorageLevel { 106 | use_disk: true, 107 | use_memory: true, 108 | use_off_heap: false, 109 | deserialized: false, 110 | replication: 1, 111 | }, 112 | StorageLevel::MemoryAndDisk2 => spark::StorageLevel { 113 | use_disk: true, 114 | use_memory: true, 115 | use_off_heap: false, 116 | deserialized: false, 117 | replication: 2, 118 | }, 119 | 
StorageLevel::OffHeap => spark::StorageLevel { 120 | use_disk: true, 121 | use_memory: true, 122 | use_off_heap: true, 123 | deserialized: false, 124 | replication: 1, 125 | }, 126 | StorageLevel::MemoryAndDiskDeser => spark::StorageLevel { 127 | use_disk: true, 128 | use_memory: true, 129 | use_off_heap: false, 130 | deserialized: true, 131 | replication: 1, 132 | }, 133 | } 134 | } 135 | } 136 | -------------------------------------------------------------------------------- /crates/connect/src/conf.rs: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 17 | 18 | //! Configuration for a Spark application. Used to set various Spark parameters as key-value pairs. 19 | 20 | use std::collections::HashMap; 21 | 22 | use crate::spark; 23 | 24 | use crate::client::SparkClient; 25 | use crate::errors::SparkError; 26 | 27 | /// User-facing configuration API, accessible through SparkSession.conf. 28 | pub struct RunTimeConfig { 29 | pub(crate) client: SparkClient, 30 | } 31 | 32 | /// User-facing configuration API, accessible through SparkSession.conf. 33 | /// 34 | /// Options set here are automatically propagated to the Hadoop configuration during I/O. 35 | /// 36 | /// # Example 37 | /// ```rust 38 | /// spark 39 | /// .conf() 40 | /// .set("spark.sql.shuffle.partitions", "42") 41 | /// .await?; 42 | /// ``` 43 | impl RunTimeConfig { 44 | pub fn new(client: &SparkClient) -> RunTimeConfig { 45 | RunTimeConfig { 46 | client: client.clone(), 47 | } 48 | } 49 | 50 | pub(crate) async fn set_configs( 51 | &mut self, 52 | map: &HashMap, 53 | ) -> Result<(), SparkError> { 54 | for (key, value) in map { 55 | self.set(key.as_str(), value.as_str()).await? 56 | } 57 | Ok(()) 58 | } 59 | 60 | /// Sets the given Spark runtime configuration property. 61 | pub async fn set(&mut self, key: &str, value: &str) -> Result<(), SparkError> { 62 | let op_type = spark::config_request::operation::OpType::Set(spark::config_request::Set { 63 | pairs: vec![spark::KeyValue { 64 | key: key.into(), 65 | value: Some(value.into()), 66 | }], 67 | }); 68 | let operation = spark::config_request::Operation { 69 | op_type: Some(op_type), 70 | }; 71 | 72 | let _ = self.client.config_request(operation).await?; 73 | 74 | Ok(()) 75 | } 76 | 77 | /// Resets the configuration property for the given key. 
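    ///
    /// # Example
    /// ```rust
    /// // a minimal sketch, assuming a built `spark` session as in the example above
    /// spark.conf().unset("spark.sql.shuffle.partitions").await?;
    /// ```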
78 | pub async fn unset(&mut self, key: &str) -> Result<(), SparkError> { 79 | let op_type = 80 | spark::config_request::operation::OpType::Unset(spark::config_request::Unset { 81 | keys: vec![key.to_string()], 82 | }); 83 | let operation = spark::config_request::Operation { 84 | op_type: Some(op_type), 85 | }; 86 | 87 | let _ = self.client.config_request(operation).await?; 88 | 89 | Ok(()) 90 | } 91 | 92 | /// Indicates whether the configuration property with the given key is modifiable in the current session. 93 | pub async fn get(&mut self, key: &str, default: Option<&str>) -> Result { 94 | let operation = match default { 95 | Some(default) => { 96 | let op_type = spark::config_request::operation::OpType::GetWithDefault( 97 | spark::config_request::GetWithDefault { 98 | pairs: vec![spark::KeyValue { 99 | key: key.into(), 100 | value: Some(default.into()), 101 | }], 102 | }, 103 | ); 104 | spark::config_request::Operation { 105 | op_type: Some(op_type), 106 | } 107 | } 108 | None => { 109 | let op_type = 110 | spark::config_request::operation::OpType::Get(spark::config_request::Get { 111 | keys: vec![key.to_string()], 112 | }); 113 | spark::config_request::Operation { 114 | op_type: Some(op_type), 115 | } 116 | } 117 | }; 118 | 119 | let resp = self.client.config_request(operation).await?; 120 | 121 | let val = resp.pairs.first().unwrap().value().to_string(); 122 | 123 | Ok(val) 124 | } 125 | 126 | /// Indicates whether the configuration property with the given key is modifiable in the current session. 127 | pub async fn is_modifable(&mut self, key: &str) -> Result { 128 | let op_type = spark::config_request::operation::OpType::IsModifiable( 129 | spark::config_request::IsModifiable { 130 | keys: vec![key.to_string()], 131 | }, 132 | ); 133 | let operation = spark::config_request::Operation { 134 | op_type: Some(op_type), 135 | }; 136 | 137 | let resp = self.client.config_request(operation).await?; 138 | 139 | let val = resp.pairs.first().unwrap().value(); 140 | 141 | match val { 142 | "true" => Ok(true), 143 | "false" => Ok(false), 144 | _ => Err(SparkError::AnalysisException( 145 | "Unexpected response value for boolean".to_string(), 146 | )), 147 | } 148 | } 149 | } 150 | -------------------------------------------------------------------------------- /crates/connect/protobuf/spark-3.5/spark/connect/catalog.proto: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | syntax = 'proto3'; 19 | 20 | package spark.connect; 21 | 22 | import "spark/connect/common.proto"; 23 | import "spark/connect/types.proto"; 24 | 25 | option java_multiple_files = true; 26 | option java_package = "org.apache.spark.connect.proto"; 27 | option go_package = "internal/generated"; 28 | 29 | // Catalog messages are marked as unstable. 30 | message Catalog { 31 | oneof cat_type { 32 | CurrentDatabase current_database = 1; 33 | SetCurrentDatabase set_current_database = 2; 34 | ListDatabases list_databases = 3; 35 | ListTables list_tables = 4; 36 | ListFunctions list_functions = 5; 37 | ListColumns list_columns = 6; 38 | GetDatabase get_database = 7; 39 | GetTable get_table = 8; 40 | GetFunction get_function = 9; 41 | DatabaseExists database_exists = 10; 42 | TableExists table_exists = 11; 43 | FunctionExists function_exists = 12; 44 | CreateExternalTable create_external_table = 13; 45 | CreateTable create_table = 14; 46 | DropTempView drop_temp_view = 15; 47 | DropGlobalTempView drop_global_temp_view = 16; 48 | RecoverPartitions recover_partitions = 17; 49 | IsCached is_cached = 18; 50 | CacheTable cache_table = 19; 51 | UncacheTable uncache_table = 20; 52 | ClearCache clear_cache = 21; 53 | RefreshTable refresh_table = 22; 54 | RefreshByPath refresh_by_path = 23; 55 | CurrentCatalog current_catalog = 24; 56 | SetCurrentCatalog set_current_catalog = 25; 57 | ListCatalogs list_catalogs = 26; 58 | } 59 | } 60 | 61 | // See `spark.catalog.currentDatabase` 62 | message CurrentDatabase { } 63 | 64 | // See `spark.catalog.setCurrentDatabase` 65 | message SetCurrentDatabase { 66 | // (Required) 67 | string db_name = 1; 68 | } 69 | 70 | // See `spark.catalog.listDatabases` 71 | message ListDatabases { 72 | // (Optional) The pattern that the database name needs to match 73 | optional string pattern = 1; 74 | } 75 | 76 | // See `spark.catalog.listTables` 77 | message ListTables { 78 | // (Optional) 79 | optional string db_name = 1; 80 | // (Optional) The pattern that the table name needs to match 81 | optional string pattern = 2; 82 | } 83 | 84 | // See `spark.catalog.listFunctions` 85 | message ListFunctions { 86 | // (Optional) 87 | optional string db_name = 1; 88 | // (Optional) The pattern that the function name needs to match 89 | optional string pattern = 2; 90 | } 91 | 92 | // See `spark.catalog.listColumns` 93 | message ListColumns { 94 | // (Required) 95 | string table_name = 1; 96 | // (Optional) 97 | optional string db_name = 2; 98 | } 99 | 100 | // See `spark.catalog.getDatabase` 101 | message GetDatabase { 102 | // (Required) 103 | string db_name = 1; 104 | } 105 | 106 | // See `spark.catalog.getTable` 107 | message GetTable { 108 | // (Required) 109 | string table_name = 1; 110 | // (Optional) 111 | optional string db_name = 2; 112 | } 113 | 114 | // See `spark.catalog.getFunction` 115 | message GetFunction { 116 | // (Required) 117 | string function_name = 1; 118 | // (Optional) 119 | optional string db_name = 2; 120 | } 121 | 122 | // See `spark.catalog.databaseExists` 123 | message DatabaseExists { 124 | // (Required) 125 | string db_name = 1; 126 | } 127 | 128 | // See `spark.catalog.tableExists` 129 | message TableExists { 130 | // (Required) 131 | string table_name = 1; 132 | // (Optional) 133 | optional string db_name = 2; 134 | } 135 | 136 | // See `spark.catalog.functionExists` 137 | message FunctionExists { 138 | // (Required) 139 | string function_name = 1; 140 | // (Optional) 141 | optional string db_name = 2; 142 | } 143 | 144 | // See 
`spark.catalog.createExternalTable` 145 | message CreateExternalTable { 146 | // (Required) 147 | string table_name = 1; 148 | // (Optional) 149 | optional string path = 2; 150 | // (Optional) 151 | optional string source = 3; 152 | // (Optional) 153 | optional DataType schema = 4; 154 | // Options could be empty for valid data source format. 155 | // The map key is case insensitive. 156 | map options = 5; 157 | } 158 | 159 | // See `spark.catalog.createTable` 160 | message CreateTable { 161 | // (Required) 162 | string table_name = 1; 163 | // (Optional) 164 | optional string path = 2; 165 | // (Optional) 166 | optional string source = 3; 167 | // (Optional) 168 | optional string description = 4; 169 | // (Optional) 170 | optional DataType schema = 5; 171 | // Options could be empty for valid data source format. 172 | // The map key is case insensitive. 173 | map options = 6; 174 | } 175 | 176 | // See `spark.catalog.dropTempView` 177 | message DropTempView { 178 | // (Required) 179 | string view_name = 1; 180 | } 181 | 182 | // See `spark.catalog.dropGlobalTempView` 183 | message DropGlobalTempView { 184 | // (Required) 185 | string view_name = 1; 186 | } 187 | 188 | // See `spark.catalog.recoverPartitions` 189 | message RecoverPartitions { 190 | // (Required) 191 | string table_name = 1; 192 | } 193 | 194 | // See `spark.catalog.isCached` 195 | message IsCached { 196 | // (Required) 197 | string table_name = 1; 198 | } 199 | 200 | // See `spark.catalog.cacheTable` 201 | message CacheTable { 202 | // (Required) 203 | string table_name = 1; 204 | 205 | // (Optional) 206 | optional StorageLevel storage_level = 2; 207 | } 208 | 209 | // See `spark.catalog.uncacheTable` 210 | message UncacheTable { 211 | // (Required) 212 | string table_name = 1; 213 | } 214 | 215 | // See `spark.catalog.clearCache` 216 | message ClearCache { } 217 | 218 | // See `spark.catalog.refreshTable` 219 | message RefreshTable { 220 | // (Required) 221 | string table_name = 1; 222 | } 223 | 224 | // See `spark.catalog.refreshByPath` 225 | message RefreshByPath { 226 | // (Required) 227 | string path = 1; 228 | } 229 | 230 | // See `spark.catalog.currentCatalog` 231 | message CurrentCatalog { } 232 | 233 | // See `spark.catalog.setCurrentCatalog` 234 | message SetCurrentCatalog { 235 | // (Required) 236 | string catalog_name = 1; 237 | } 238 | 239 | // See `spark.catalog.listCatalogs` 240 | message ListCatalogs { 241 | // (Optional) The pattern that the catalog name needs to match 242 | optional string pattern = 1; 243 | } 244 | -------------------------------------------------------------------------------- /crates/connect/src/errors.rs: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. 
See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 17 | 18 | //! Defines a [SparkError] for representing failures in various Spark operations. 19 | //! Most of these are wrappers for tonic or arrow error messages 20 | use std::error::Error; 21 | use std::fmt::Debug; 22 | use std::io::Write; 23 | 24 | use arrow::error::ArrowError; 25 | use thiserror::Error; 26 | 27 | use tonic::Code; 28 | 29 | #[cfg(feature = "datafusion")] 30 | use datafusion::error::DataFusionError; 31 | #[cfg(feature = "polars")] 32 | use polars::error::PolarsError; 33 | 34 | /// Different `Spark` Error types 35 | #[derive(Error, Debug)] 36 | pub enum SparkError { 37 | #[error("Aborted: {0}")] 38 | Aborted(String), 39 | 40 | #[error("Already Exists: {0}")] 41 | AlreadyExists(String), 42 | 43 | #[error("Analysis Exception: {0}")] 44 | AnalysisException(String), 45 | 46 | #[error("Apache Arrow Error: {0}")] 47 | ArrowError(#[from] ArrowError), 48 | 49 | #[error("Cancelled: {0}")] 50 | Cancelled(String), 51 | 52 | #[error("Data Loss Exception: {0}")] 53 | DataLoss(String), 54 | 55 | #[error("Deadline Exceeded: {0}")] 56 | DeadlineExceeded(String), 57 | 58 | #[error("External Error: {0}")] 59 | ExternalError(Box), 60 | 61 | #[error("Failed Precondition: {0}")] 62 | FailedPrecondition(String), 63 | 64 | #[error("Invalid Connection Url: {0}")] 65 | InvalidConnectionUrl(String), 66 | 67 | #[error("Invalid Argument: {0}")] 68 | InvalidArgument(String), 69 | 70 | #[error("Io Error: {0}")] 71 | IoError(String, std::io::Error), 72 | 73 | #[error("Not Found: {0}")] 74 | NotFound(String), 75 | 76 | #[error("Not Yet Implemented: {0}")] 77 | NotYetImplemented(String), 78 | 79 | #[error("Permission Denied: {0}")] 80 | PermissionDenied(String), 81 | 82 | #[error("Resource Exhausted: {0}")] 83 | ResourceExhausted(String), 84 | 85 | #[error("Spark Session ID is not the same: {0}")] 86 | SessionNotSameException(String), 87 | 88 | #[error("Unauthenticated: {0}")] 89 | Unauthenticated(String), 90 | 91 | #[error("Unavailable: {0}")] 92 | Unavailable(String), 93 | 94 | #[error("Unkown: {0}")] 95 | Unknown(String), 96 | 97 | #[error("Unimplemented; {0}")] 98 | Unimplemented(String), 99 | 100 | #[error("Invalid UUID")] 101 | Uuid(#[from] uuid::Error), 102 | 103 | #[error("Out of Range: {0}")] 104 | OutOfRange(String), 105 | } 106 | 107 | impl SparkError { 108 | /// Wraps an external error in an `SparkError`. 
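    ///
    /// # Example
    /// ```rust
    /// use spark_connect_rs::errors::SparkError;
    ///
    /// // a minimal sketch: any boxed error can be wrapped as SparkError::ExternalError
    /// let io_err = std::io::Error::new(std::io::ErrorKind::Other, "connection dropped");
    /// let err = SparkError::from_external_error(Box::new(io_err));
    /// ```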
109 | pub fn from_external_error(error: Box) -> Self { 110 | Self::ExternalError(error) 111 | } 112 | } 113 | 114 | impl From for SparkError { 115 | fn from(error: std::io::Error) -> Self { 116 | SparkError::IoError(error.to_string(), error) 117 | } 118 | } 119 | 120 | impl From for SparkError { 121 | fn from(error: std::str::Utf8Error) -> Self { 122 | SparkError::AnalysisException(error.to_string()) 123 | } 124 | } 125 | 126 | impl From for SparkError { 127 | fn from(error: std::string::FromUtf8Error) -> Self { 128 | SparkError::AnalysisException(error.to_string()) 129 | } 130 | } 131 | 132 | impl From for SparkError { 133 | fn from(status: tonic::Status) -> Self { 134 | match status.code() { 135 | Code::Ok => SparkError::AnalysisException(status.message().to_string()), 136 | Code::Unknown => SparkError::Unknown(status.message().to_string()), 137 | Code::Aborted => SparkError::Aborted(status.message().to_string()), 138 | Code::NotFound => SparkError::NotFound(status.message().to_string()), 139 | Code::Internal => SparkError::AnalysisException(status.message().to_string()), 140 | Code::DataLoss => SparkError::DataLoss(status.message().to_string()), 141 | Code::Cancelled => SparkError::Cancelled(status.message().to_string()), 142 | Code::OutOfRange => SparkError::OutOfRange(status.message().to_string()), 143 | Code::Unavailable => SparkError::Unavailable(status.message().to_string()), 144 | Code::AlreadyExists => SparkError::AnalysisException(status.message().to_string()), 145 | Code::InvalidArgument => SparkError::InvalidArgument(status.message().to_string()), 146 | Code::DeadlineExceeded => SparkError::DeadlineExceeded(status.message().to_string()), 147 | Code::Unimplemented => SparkError::Unimplemented(status.message().to_string()), 148 | Code::Unauthenticated => SparkError::Unauthenticated(status.message().to_string()), 149 | Code::PermissionDenied => SparkError::PermissionDenied(status.message().to_string()), 150 | Code::ResourceExhausted => SparkError::ResourceExhausted(status.message().to_string()), 151 | Code::FailedPrecondition => { 152 | SparkError::FailedPrecondition(status.message().to_string()) 153 | } 154 | } 155 | } 156 | } 157 | 158 | impl From for SparkError { 159 | fn from(value: serde_json::Error) -> Self { 160 | SparkError::AnalysisException(value.to_string()) 161 | } 162 | } 163 | 164 | #[cfg(feature = "datafusion")] 165 | impl From for SparkError { 166 | fn from(_value: DataFusionError) -> Self { 167 | SparkError::AnalysisException("Error converting to DataFusion DataFrame".to_string()) 168 | } 169 | } 170 | 171 | #[cfg(feature = "polars")] 172 | impl From for SparkError { 173 | fn from(_value: PolarsError) -> Self { 174 | SparkError::AnalysisException("Error converting to Polars DataFrame".to_string()) 175 | } 176 | } 177 | 178 | impl From for SparkError { 179 | fn from(value: tonic::codegen::http::uri::InvalidUri) -> Self { 180 | SparkError::InvalidConnectionUrl(value.to_string()) 181 | } 182 | } 183 | 184 | impl From for SparkError { 185 | fn from(value: tonic::transport::Error) -> Self { 186 | SparkError::InvalidConnectionUrl(value.to_string()) 187 | } 188 | } 189 | 190 | impl From> for SparkError { 191 | fn from(error: std::io::IntoInnerError) -> Self { 192 | SparkError::IoError(error.to_string(), error.into()) 193 | } 194 | } 195 | -------------------------------------------------------------------------------- /examples/src/deltalake.rs: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software 
Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 17 | 18 | // This example demonstrates creating a Spark DataFrame from a CSV with read options 19 | // and then adding transformations for 'select' & 'sort' 20 | // The resulting dataframe is saved in the `delta` format as a `managed` table 21 | // and `spark.sql` queries are run against the delta table 22 | // 23 | // The remote spark session must have the spark package `io.delta:delta-spark_2.12:{DELTA_VERSION}` enabled. 24 | // Where the `DELTA_VERSION` is the specified Delta Lake version. 25 | 26 | use spark_connect_rs::{SparkSession, SparkSessionBuilder}; 27 | 28 | use spark_connect_rs::dataframe::SaveMode; 29 | 30 | #[tokio::main] 31 | async fn main() -> Result<(), Box> { 32 | let spark: SparkSession = SparkSessionBuilder::remote("sc://127.0.0.1:15002/") 33 | .build() 34 | .await?; 35 | 36 | // path might vary based on where you started your spark cluster 37 | // the `/datasets/` folder of spark contains dummy data 38 | let paths = ["./datasets/people.csv"]; 39 | 40 | // Load a CSV file from the spark server 41 | let df = spark 42 | .read() 43 | .format("csv") 44 | .option("header", "True") 45 | .option("delimiter", ";") 46 | .option("inferSchema", "True") 47 | .load(paths)?; 48 | 49 | // write as a delta table and register it as a table 50 | df.write() 51 | .format("delta") 52 | .mode(SaveMode::Overwrite) 53 | .save_as_table("default.people_delta") 54 | .await?; 55 | 56 | // view the history of the table 57 | spark 58 | .sql("DESCRIBE HISTORY default.people_delta") 59 | .await? 60 | .show(Some(1), None, Some(true)) 61 | .await?; 62 | 63 | // create another dataframe 64 | let df = spark 65 | .sql("SELECT 'john' as name, 40 as age, 'engineer' as job") 66 | .await?; 67 | 68 | // append to the delta table 69 | df.write() 70 | .format("delta") 71 | .mode(SaveMode::Append) 72 | .save_as_table("default.people_delta") 73 | .await?; 74 | 75 | // view history 76 | spark 77 | .sql("DESCRIBE HISTORY default.people_delta") 78 | .await? 
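        // show(Some(2), None, Some(true)) appears to mirror PySpark's show(n, truncate, vertical):
        // print at most two history entries in the vertical -RECORD- layout shown below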
79 | .show(Some(2), None, Some(true)) 80 | .await?; 81 | 82 | // +-------------------------------------------------------------------------------------------------------+ 83 | // | show_string | 84 | // +-------------------------------------------------------------------------------------------------------+ 85 | // | -RECORD 0-------------------------------------------------------------------------------------------- | 86 | // | version | 1 | 87 | // | timestamp | 2024-05-17 14:27:34.462 | 88 | // | userId | NULL | 89 | // | userName | NULL | 90 | // | operation | WRITE | 91 | // | operationParameters | {mode -> Append, partitionBy -> []} | 92 | // | job | NULL | 93 | // | notebook | NULL | 94 | // | clusterId | NULL | 95 | // | readVersion | 0 | 96 | // | isolationLevel | Serializable | 97 | // | isBlindAppend | true | 98 | // | operationMetrics | {numFiles -> 1, numOutputRows -> 1, numOutputBytes -> 947} | 99 | // | userMetadata | NULL | 100 | // | engineInfo | Apache-Spark/3.5.1 Delta-Lake/3.0.0 | 101 | // | -RECORD 1-------------------------------------------------------------------------------------------- | 102 | // | version | 0 | 103 | // | timestamp | 2024-05-17 14:27:30.726 | 104 | // | userId | NULL | 105 | // | userName | NULL | 106 | // | operation | CREATE OR REPLACE TABLE AS SELECT | 107 | // | operationParameters | {isManaged -> true, description -> NULL, partitionBy -> [], properties -> {}} | 108 | // | job | NULL | 109 | // | notebook | NULL | 110 | // | clusterId | NULL | 111 | // | readVersion | NULL | 112 | // | isolationLevel | Serializable | 113 | // | isBlindAppend | false | 114 | // | operationMetrics | {numFiles -> 1, numOutputRows -> 2, numOutputBytes -> 988} | 115 | // | userMetadata | NULL | 116 | // | engineInfo | Apache-Spark/3.5.1 Delta-Lake/3.0.0 | 117 | // | | 118 | // +-------------------------------------------------------------------------------------------------------+ 119 | 120 | Ok(()) 121 | } 122 | -------------------------------------------------------------------------------- /crates/connect/src/group.rs: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 17 | 18 | //! A DataFrame created with an aggregate statement 19 | 20 | use crate::column::Column; 21 | use crate::dataframe::DataFrame; 22 | use crate::plan::LogicalPlanBuilder; 23 | 24 | use crate::functions::{invoke_func, lit}; 25 | 26 | use crate::spark; 27 | use crate::spark::aggregate::GroupType; 28 | 29 | /// A set of methods for aggregations on a [DataFrame], created by DataFrame.groupBy(). 
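///
/// # Example
/// ```rust
/// use spark_connect_rs::functions::col;
///
/// // a minimal sketch, assuming an existing DataFrame `df` with a `year` column
/// let counts = df
///     .group_by(Some([col("year")]))
///     .count()
///     .collect()
///     .await?;
/// ```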
30 | #[derive(Clone, Debug)] 31 | pub struct GroupedData { 32 | df: DataFrame, 33 | group_type: GroupType, 34 | grouping_cols: Vec, 35 | pivot_col: Option, 36 | pivot_vals: Option>, 37 | } 38 | 39 | impl GroupedData { 40 | pub fn new( 41 | df: DataFrame, 42 | group_type: GroupType, 43 | grouping_cols: Vec, 44 | pivot_col: Option, 45 | pivot_vals: Option>, 46 | ) -> GroupedData { 47 | Self { 48 | df, 49 | group_type, 50 | grouping_cols, 51 | pivot_col, 52 | pivot_vals, 53 | } 54 | } 55 | 56 | /// Compute aggregates and returns the result as a [DataFrame] 57 | pub fn agg(self, exprs: I) -> DataFrame 58 | where 59 | I: IntoIterator, 60 | S: Into, 61 | { 62 | let plan = LogicalPlanBuilder::aggregate( 63 | self.df.plan, 64 | self.group_type, 65 | self.grouping_cols, 66 | exprs, 67 | self.pivot_col, 68 | self.pivot_vals, 69 | ); 70 | 71 | DataFrame { 72 | spark_session: self.df.spark_session, 73 | plan, 74 | } 75 | } 76 | 77 | /// Computes average values for each numeric columns for each group. 78 | pub fn avg(self, cols: I) -> DataFrame 79 | where 80 | I: IntoIterator, 81 | S: Into, 82 | { 83 | self.agg([invoke_func("avg", cols)]) 84 | } 85 | 86 | /// Computes the min value for each numeric column for each group. 87 | pub fn min(self, cols: I) -> DataFrame 88 | where 89 | I: IntoIterator, 90 | S: Into, 91 | { 92 | self.agg([invoke_func("min", cols)]) 93 | } 94 | 95 | /// Computes the max value for each numeric columns for each group. 96 | pub fn max(self, cols: I) -> DataFrame 97 | where 98 | I: IntoIterator, 99 | S: Into, 100 | { 101 | self.agg([invoke_func("max", cols)]) 102 | } 103 | 104 | /// Computes the sum for each numeric columns for each group. 105 | pub fn sum(self, cols: I) -> DataFrame 106 | where 107 | I: IntoIterator, 108 | S: Into, 109 | { 110 | self.agg([invoke_func("sum", cols)]) 111 | } 112 | 113 | /// Counts the number of records for each group. 
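    ///
    /// The resulting column is labeled `count(1 AS count)` in the collected
    /// result (see the `test_group_count` unit test below).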
114 | pub fn count(self) -> DataFrame { 115 | self.agg([invoke_func("count", [lit(1).alias("count")])]) 116 | } 117 | 118 | /// Pivots a column of the current [DataFrame] and perform the specified aggregation 119 | pub fn pivot(self, col: &str, values: Option>) -> GroupedData { 120 | let pivot_vals = values.map(|vals| vals.iter().map(|val| val.to_string().into()).collect()); 121 | 122 | GroupedData::new( 123 | self.df, 124 | GroupType::Pivot, 125 | self.grouping_cols, 126 | Some(Column::from(col).into()), 127 | pivot_vals, 128 | ) 129 | } 130 | } 131 | 132 | #[cfg(test)] 133 | mod tests { 134 | 135 | use arrow::array::{ArrayRef, Int64Array, StringArray}; 136 | use arrow::datatypes::{DataType, Field, Schema}; 137 | use arrow::record_batch::RecordBatch; 138 | use std::sync::Arc; 139 | 140 | use crate::errors::SparkError; 141 | use crate::SparkSession; 142 | use crate::SparkSessionBuilder; 143 | 144 | use crate::functions::col; 145 | 146 | use crate::column::Column; 147 | 148 | async fn setup() -> SparkSession { 149 | println!("SparkSession Setup"); 150 | 151 | let connection = 152 | "sc://127.0.0.1:15002/;user_id=rust_group;session_id=02c25694-e875-4a25-9955-bc5bc56c4ade"; 153 | 154 | SparkSessionBuilder::remote(connection) 155 | .build() 156 | .await 157 | .unwrap() 158 | } 159 | 160 | #[tokio::test] 161 | async fn test_group_count() -> Result<(), SparkError> { 162 | let spark = setup().await; 163 | 164 | let df = spark.range(None, 100, 1, Some(8)); 165 | 166 | let res = df.group_by::>(None).count().collect().await?; 167 | 168 | let a: ArrayRef = Arc::new(Int64Array::from(vec![100])); 169 | 170 | let expected = RecordBatch::try_from_iter(vec![("count(1 AS count)", a)])?; 171 | 172 | assert_eq!(expected, res); 173 | Ok(()) 174 | } 175 | 176 | #[tokio::test] 177 | async fn test_group_pivot() -> Result<(), SparkError> { 178 | let spark = setup().await; 179 | 180 | let course: ArrayRef = Arc::new(StringArray::from(vec![ 181 | "dotNET", "Java", "dotNET", "dotNET", "Java", 182 | ])); 183 | let year: ArrayRef = Arc::new(Int64Array::from(vec![2012, 2012, 2012, 2013, 2013])); 184 | let earnings: ArrayRef = Arc::new(Int64Array::from(vec![10000, 20000, 5000, 48000, 30000])); 185 | 186 | let data = RecordBatch::try_from_iter(vec![ 187 | ("course", course), 188 | ("year", year), 189 | ("earnings", earnings), 190 | ])?; 191 | 192 | let df = spark.create_dataframe(&data)?; 193 | 194 | let res = df 195 | .clone() 196 | .group_by(Some([col("year")])) 197 | .pivot("course", Some(vec!["Java"])) 198 | .sum(["earnings"]) 199 | .collect() 200 | .await?; 201 | 202 | let year: ArrayRef = Arc::new(Int64Array::from(vec![2012, 2013])); 203 | let earnings: ArrayRef = Arc::new(Int64Array::from(vec![20000, 30000])); 204 | 205 | let schema = Schema::new(vec![ 206 | Field::new("year", DataType::Int64, false), 207 | Field::new("Java", DataType::Int64, true), 208 | ]); 209 | 210 | let expected = RecordBatch::try_new(Arc::new(schema), vec![year, earnings])?; 211 | 212 | assert_eq!(expected, res); 213 | 214 | let res = df 215 | .group_by(Some([col("year")])) 216 | .pivot("course", None) 217 | .sum(["earnings"]) 218 | .collect() 219 | .await?; 220 | 221 | let year: ArrayRef = Arc::new(Int64Array::from(vec![2012, 2013])); 222 | let java_earnings: ArrayRef = Arc::new(Int64Array::from(vec![20000, 30000])); 223 | let dnet_earnings: ArrayRef = Arc::new(Int64Array::from(vec![15000, 48000])); 224 | 225 | let schema = Schema::new(vec![ 226 | Field::new("year", DataType::Int64, false), 227 | Field::new("Java", DataType::Int64, true), 
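            // pivoting without explicit values yields one column per distinct course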
228 | Field::new("dotNET", DataType::Int64, true), 229 | ]); 230 | 231 | let expected = 232 | RecordBatch::try_new(Arc::new(schema), vec![year, java_earnings, dnet_earnings])?; 233 | 234 | assert_eq!(expected, res); 235 | 236 | Ok(()) 237 | } 238 | } 239 | -------------------------------------------------------------------------------- /datasets/kv1.txt: -------------------------------------------------------------------------------- 1 | 238val_238 2 | 86val_86 3 | 311val_311 4 | 27val_27 5 | 165val_165 6 | 409val_409 7 | 255val_255 8 | 278val_278 9 | 98val_98 10 | 484val_484 11 | 265val_265 12 | 193val_193 13 | 401val_401 14 | 150val_150 15 | 273val_273 16 | 224val_224 17 | 369val_369 18 | 66val_66 19 | 128val_128 20 | 213val_213 21 | 146val_146 22 | 406val_406 23 | 429val_429 24 | 374val_374 25 | 152val_152 26 | 469val_469 27 | 145val_145 28 | 495val_495 29 | 37val_37 30 | 327val_327 31 | 281val_281 32 | 277val_277 33 | 209val_209 34 | 15val_15 35 | 82val_82 36 | 403val_403 37 | 166val_166 38 | 417val_417 39 | 430val_430 40 | 252val_252 41 | 292val_292 42 | 219val_219 43 | 287val_287 44 | 153val_153 45 | 193val_193 46 | 338val_338 47 | 446val_446 48 | 459val_459 49 | 394val_394 50 | 237val_237 51 | 482val_482 52 | 174val_174 53 | 413val_413 54 | 494val_494 55 | 207val_207 56 | 199val_199 57 | 466val_466 58 | 208val_208 59 | 174val_174 60 | 399val_399 61 | 396val_396 62 | 247val_247 63 | 417val_417 64 | 489val_489 65 | 162val_162 66 | 377val_377 67 | 397val_397 68 | 309val_309 69 | 365val_365 70 | 266val_266 71 | 439val_439 72 | 342val_342 73 | 367val_367 74 | 325val_325 75 | 167val_167 76 | 195val_195 77 | 475val_475 78 | 17val_17 79 | 113val_113 80 | 155val_155 81 | 203val_203 82 | 339val_339 83 | 0val_0 84 | 455val_455 85 | 128val_128 86 | 311val_311 87 | 316val_316 88 | 57val_57 89 | 302val_302 90 | 205val_205 91 | 149val_149 92 | 438val_438 93 | 345val_345 94 | 129val_129 95 | 170val_170 96 | 20val_20 97 | 489val_489 98 | 157val_157 99 | 378val_378 100 | 221val_221 101 | 92val_92 102 | 111val_111 103 | 47val_47 104 | 72val_72 105 | 4val_4 106 | 280val_280 107 | 35val_35 108 | 427val_427 109 | 277val_277 110 | 208val_208 111 | 356val_356 112 | 399val_399 113 | 169val_169 114 | 382val_382 115 | 498val_498 116 | 125val_125 117 | 386val_386 118 | 437val_437 119 | 469val_469 120 | 192val_192 121 | 286val_286 122 | 187val_187 123 | 176val_176 124 | 54val_54 125 | 459val_459 126 | 51val_51 127 | 138val_138 128 | 103val_103 129 | 239val_239 130 | 213val_213 131 | 216val_216 132 | 430val_430 133 | 278val_278 134 | 176val_176 135 | 289val_289 136 | 221val_221 137 | 65val_65 138 | 318val_318 139 | 332val_332 140 | 311val_311 141 | 275val_275 142 | 137val_137 143 | 241val_241 144 | 83val_83 145 | 333val_333 146 | 180val_180 147 | 284val_284 148 | 12val_12 149 | 230val_230 150 | 181val_181 151 | 67val_67 152 | 260val_260 153 | 404val_404 154 | 384val_384 155 | 489val_489 156 | 353val_353 157 | 373val_373 158 | 272val_272 159 | 138val_138 160 | 217val_217 161 | 84val_84 162 | 348val_348 163 | 466val_466 164 | 58val_58 165 | 8val_8 166 | 411val_411 167 | 230val_230 168 | 208val_208 169 | 348val_348 170 | 24val_24 171 | 463val_463 172 | 431val_431 173 | 179val_179 174 | 172val_172 175 | 42val_42 176 | 129val_129 177 | 158val_158 178 | 119val_119 179 | 496val_496 180 | 0val_0 181 | 322val_322 182 | 197val_197 183 | 468val_468 184 | 393val_393 185 | 454val_454 186 | 100val_100 187 | 298val_298 188 | 199val_199 189 | 191val_191 190 | 418val_418 191 | 96val_96 192 | 26val_26 193 | 165val_165 194 | 
327val_327 195 | 230val_230 196 | 205val_205 197 | 120val_120 198 | 131val_131 199 | 51val_51 200 | 404val_404 201 | 43val_43 202 | 436val_436 203 | 156val_156 204 | 469val_469 205 | 468val_468 206 | 308val_308 207 | 95val_95 208 | 196val_196 209 | 288val_288 210 | 481val_481 211 | 457val_457 212 | 98val_98 213 | 282val_282 214 | 197val_197 215 | 187val_187 216 | 318val_318 217 | 318val_318 218 | 409val_409 219 | 470val_470 220 | 137val_137 221 | 369val_369 222 | 316val_316 223 | 169val_169 224 | 413val_413 225 | 85val_85 226 | 77val_77 227 | 0val_0 228 | 490val_490 229 | 87val_87 230 | 364val_364 231 | 179val_179 232 | 118val_118 233 | 134val_134 234 | 395val_395 235 | 282val_282 236 | 138val_138 237 | 238val_238 238 | 419val_419 239 | 15val_15 240 | 118val_118 241 | 72val_72 242 | 90val_90 243 | 307val_307 244 | 19val_19 245 | 435val_435 246 | 10val_10 247 | 277val_277 248 | 273val_273 249 | 306val_306 250 | 224val_224 251 | 309val_309 252 | 389val_389 253 | 327val_327 254 | 242val_242 255 | 369val_369 256 | 392val_392 257 | 272val_272 258 | 331val_331 259 | 401val_401 260 | 242val_242 261 | 452val_452 262 | 177val_177 263 | 226val_226 264 | 5val_5 265 | 497val_497 266 | 402val_402 267 | 396val_396 268 | 317val_317 269 | 395val_395 270 | 58val_58 271 | 35val_35 272 | 336val_336 273 | 95val_95 274 | 11val_11 275 | 168val_168 276 | 34val_34 277 | 229val_229 278 | 233val_233 279 | 143val_143 280 | 472val_472 281 | 322val_322 282 | 498val_498 283 | 160val_160 284 | 195val_195 285 | 42val_42 286 | 321val_321 287 | 430val_430 288 | 119val_119 289 | 489val_489 290 | 458val_458 291 | 78val_78 292 | 76val_76 293 | 41val_41 294 | 223val_223 295 | 492val_492 296 | 149val_149 297 | 449val_449 298 | 218val_218 299 | 228val_228 300 | 138val_138 301 | 453val_453 302 | 30val_30 303 | 209val_209 304 | 64val_64 305 | 468val_468 306 | 76val_76 307 | 74val_74 308 | 342val_342 309 | 69val_69 310 | 230val_230 311 | 33val_33 312 | 368val_368 313 | 103val_103 314 | 296val_296 315 | 113val_113 316 | 216val_216 317 | 367val_367 318 | 344val_344 319 | 167val_167 320 | 274val_274 321 | 219val_219 322 | 239val_239 323 | 485val_485 324 | 116val_116 325 | 223val_223 326 | 256val_256 327 | 263val_263 328 | 70val_70 329 | 487val_487 330 | 480val_480 331 | 401val_401 332 | 288val_288 333 | 191val_191 334 | 5val_5 335 | 244val_244 336 | 438val_438 337 | 128val_128 338 | 467val_467 339 | 432val_432 340 | 202val_202 341 | 316val_316 342 | 229val_229 343 | 469val_469 344 | 463val_463 345 | 280val_280 346 | 2val_2 347 | 35val_35 348 | 283val_283 349 | 331val_331 350 | 235val_235 351 | 80val_80 352 | 44val_44 353 | 193val_193 354 | 321val_321 355 | 335val_335 356 | 104val_104 357 | 466val_466 358 | 366val_366 359 | 175val_175 360 | 403val_403 361 | 483val_483 362 | 53val_53 363 | 105val_105 364 | 257val_257 365 | 406val_406 366 | 409val_409 367 | 190val_190 368 | 406val_406 369 | 401val_401 370 | 114val_114 371 | 258val_258 372 | 90val_90 373 | 203val_203 374 | 262val_262 375 | 348val_348 376 | 424val_424 377 | 12val_12 378 | 396val_396 379 | 201val_201 380 | 217val_217 381 | 164val_164 382 | 431val_431 383 | 454val_454 384 | 478val_478 385 | 298val_298 386 | 125val_125 387 | 431val_431 388 | 164val_164 389 | 424val_424 390 | 187val_187 391 | 382val_382 392 | 5val_5 393 | 70val_70 394 | 397val_397 395 | 480val_480 396 | 291val_291 397 | 24val_24 398 | 351val_351 399 | 255val_255 400 | 104val_104 401 | 70val_70 402 | 163val_163 403 | 438val_438 404 | 119val_119 405 | 414val_414 406 | 200val_200 407 | 491val_491 408 | 237val_237 
409 | 439val_439 410 | 360val_360 411 | 248val_248 412 | 479val_479 413 | 305val_305 414 | 417val_417 415 | 199val_199 416 | 444val_444 417 | 120val_120 418 | 429val_429 419 | 169val_169 420 | 443val_443 421 | 323val_323 422 | 325val_325 423 | 277val_277 424 | 230val_230 425 | 478val_478 426 | 178val_178 427 | 468val_468 428 | 310val_310 429 | 317val_317 430 | 333val_333 431 | 493val_493 432 | 460val_460 433 | 207val_207 434 | 249val_249 435 | 265val_265 436 | 480val_480 437 | 83val_83 438 | 136val_136 439 | 353val_353 440 | 172val_172 441 | 214val_214 442 | 462val_462 443 | 233val_233 444 | 406val_406 445 | 133val_133 446 | 175val_175 447 | 189val_189 448 | 454val_454 449 | 375val_375 450 | 401val_401 451 | 421val_421 452 | 407val_407 453 | 384val_384 454 | 256val_256 455 | 26val_26 456 | 134val_134 457 | 67val_67 458 | 384val_384 459 | 379val_379 460 | 18val_18 461 | 462val_462 462 | 492val_492 463 | 100val_100 464 | 298val_298 465 | 9val_9 466 | 341val_341 467 | 498val_498 468 | 146val_146 469 | 458val_458 470 | 362val_362 471 | 186val_186 472 | 285val_285 473 | 348val_348 474 | 167val_167 475 | 18val_18 476 | 273val_273 477 | 183val_183 478 | 281val_281 479 | 344val_344 480 | 97val_97 481 | 469val_469 482 | 315val_315 483 | 84val_84 484 | 28val_28 485 | 37val_37 486 | 448val_448 487 | 152val_152 488 | 348val_348 489 | 307val_307 490 | 194val_194 491 | 414val_414 492 | 477val_477 493 | 222val_222 494 | 126val_126 495 | 90val_90 496 | 169val_169 497 | 403val_403 498 | 400val_400 499 | 200val_200 500 | 97val_97 501 | -------------------------------------------------------------------------------- /crates/connect/src/expressions.rs: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 17 | 18 | //! Traits for converting Rust Types to Spark Connect Expression Types 19 | //! 20 | //! Spark Connect has a few different ways of creating expressions and different gRPC methods 21 | //! require expressions in different forms. These traits are used to either translate a value into 22 | //! a [spark::Expression] or into a [spark::expression::Literal]. 
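//!
//! A minimal sketch of the conversions this module provides:
//!
//! ```rust
//! use spark_connect_rs::spark;
//!
//! // primitive values convert directly into a literal
//! let num: spark::expression::Literal = 42i64.into();
//! let text: spark::expression::Literal = "hello".into();
//! // while a &str converts into a spark::Expression by way of Column
//! let expr: spark::Expression = "name".into();
//! ```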
23 | 24 | use chrono::NaiveDateTime; 25 | 26 | use crate::spark; 27 | 28 | use crate::column::Column; 29 | use crate::types::DataType; 30 | 31 | pub struct VecExpression { 32 | pub(super) expr: Vec, 33 | } 34 | 35 | impl FromIterator for VecExpression 36 | where 37 | T: Into, 38 | { 39 | fn from_iter>(iter: I) -> Self { 40 | let expr = iter 41 | .into_iter() 42 | .map(Into::into) 43 | .map(|col| col.expression) 44 | .collect(); 45 | 46 | VecExpression { expr } 47 | } 48 | } 49 | 50 | impl From for Vec { 51 | fn from(value: VecExpression) -> Self { 52 | value.expr 53 | } 54 | } 55 | 56 | impl<'a> From<&'a str> for VecExpression { 57 | fn from(value: &'a str) -> Self { 58 | VecExpression { 59 | expr: vec![Column::from_str(value).expression], 60 | } 61 | } 62 | } 63 | 64 | impl From for VecExpression { 65 | fn from(value: String) -> Self { 66 | VecExpression { 67 | expr: vec![Column::from_string(value).expression], 68 | } 69 | } 70 | } 71 | 72 | impl From for spark::Expression { 73 | fn from(value: String) -> Self { 74 | Column::from(value).expression 75 | } 76 | } 77 | 78 | impl<'a> From<&'a str> for spark::Expression { 79 | fn from(value: &'a str) -> Self { 80 | Column::from(value).expression 81 | } 82 | } 83 | 84 | impl From for spark::Expression { 85 | fn from(value: Column) -> Self { 86 | value.expression 87 | } 88 | } 89 | 90 | /// Create a filter expression 91 | pub trait ToFilterExpr { 92 | fn to_filter_expr(&self) -> Option; 93 | } 94 | 95 | impl ToFilterExpr for Column { 96 | fn to_filter_expr(&self) -> Option { 97 | Some(self.expression.clone()) 98 | } 99 | } 100 | 101 | impl ToFilterExpr for &str { 102 | fn to_filter_expr(&self) -> Option { 103 | let expr_type = Some(spark::expression::ExprType::ExpressionString( 104 | spark::expression::ExpressionString { 105 | expression: self.to_string(), 106 | }, 107 | )); 108 | 109 | Some(spark::Expression { expr_type }) 110 | } 111 | } 112 | 113 | /// Translate a rust value into a literal type 114 | pub trait ToLiteral { 115 | fn to_literal(&self) -> spark::expression::Literal; 116 | } 117 | 118 | macro_rules! 
impl_to_literal { 119 | ($type:ty, $inner_type:ident) => { 120 | impl From<$type> for spark::expression::Literal { 121 | fn from(value: $type) -> spark::expression::Literal { 122 | spark::expression::Literal { 123 | literal_type: Some(spark::expression::literal::LiteralType::$inner_type(value)), 124 | } 125 | } 126 | } 127 | }; 128 | } 129 | 130 | impl_to_literal!(bool, Boolean); 131 | impl_to_literal!(i32, Integer); 132 | impl_to_literal!(i64, Long); 133 | impl_to_literal!(f32, Float); 134 | impl_to_literal!(f64, Double); 135 | impl_to_literal!(String, String); 136 | 137 | impl From<&[u8]> for spark::expression::Literal { 138 | fn from(value: &[u8]) -> Self { 139 | spark::expression::Literal { 140 | literal_type: Some(spark::expression::literal::LiteralType::Binary(Vec::from( 141 | value, 142 | ))), 143 | } 144 | } 145 | } 146 | 147 | impl From for spark::expression::Literal { 148 | fn from(value: i16) -> Self { 149 | spark::expression::Literal { 150 | literal_type: Some(spark::expression::literal::LiteralType::Short(value as i32)), 151 | } 152 | } 153 | } 154 | 155 | impl<'a> From<&'a str> for spark::expression::Literal { 156 | fn from(value: &'a str) -> Self { 157 | spark::expression::Literal { 158 | literal_type: Some(spark::expression::literal::LiteralType::String( 159 | value.to_string(), 160 | )), 161 | } 162 | } 163 | } 164 | 165 | impl From> for spark::expression::Literal { 166 | fn from(value: chrono::DateTime) -> Self { 167 | // timestamps for spark have to be the microsends since 1/1/1970 168 | let timestamp = value.timestamp_micros(); 169 | 170 | spark::expression::Literal { 171 | literal_type: Some(spark::expression::literal::LiteralType::Timestamp( 172 | timestamp, 173 | )), 174 | } 175 | } 176 | } 177 | 178 | impl From for spark::expression::Literal { 179 | fn from(value: NaiveDateTime) -> Self { 180 | // timestamps for spark have to be the microsends since 1/1/1970 181 | let timestamp = value.and_utc().timestamp_micros(); 182 | 183 | spark::expression::Literal { 184 | literal_type: Some(spark::expression::literal::LiteralType::TimestampNtz( 185 | timestamp, 186 | )), 187 | } 188 | } 189 | } 190 | 191 | impl From for spark::expression::Literal { 192 | fn from(value: chrono::NaiveDate) -> Self { 193 | // Spark works based on unix time. I.e. 
seconds since 1/1/1970
194 |         // to get dates to work you have to do this math
195 |         let days_since_unix_epoch =
196 |             value.signed_duration_since(chrono::NaiveDate::from_ymd_opt(1970, 1, 1).unwrap());
197 | 
198 |         spark::expression::Literal {
199 |             literal_type: Some(spark::expression::literal::LiteralType::Date(
200 |                 days_since_unix_epoch.num_days() as i32,
201 |             )),
202 |         }
203 |     }
204 | }
205 | 
206 | impl<T> From<Vec<T>> for spark::expression::Literal
207 | where
208 |     T: Into<spark::expression::Literal> + Clone,
209 |     spark::DataType: From<T>,
210 | {
211 |     fn from(value: Vec<T>) -> Self {
212 |         let element_type = Some(spark::DataType::from(
213 |             value.first().expect("Array can not be empty").clone(),
214 |         ));
215 | 
216 |         let elements = value.iter().map(|val| val.clone().into()).collect();
217 | 
218 |         let array_type = spark::expression::literal::Array {
219 |             element_type,
220 |             elements,
221 |         };
222 | 
223 |         spark::expression::Literal {
224 |             literal_type: Some(spark::expression::literal::LiteralType::Array(array_type)),
225 |         }
226 |     }
227 | }
228 | 
229 | impl<T, const N: usize> From<[T; N]> for spark::expression::Literal
230 | where
231 |     T: Into<spark::expression::Literal> + Clone,
232 |     spark::DataType: From<T>,
233 | {
234 |     fn from(value: [T; N]) -> Self {
235 |         let element_type = Some(spark::DataType::from(
236 |             value.first().expect("Array can not be empty").clone(),
237 |         ));
238 | 
239 |         let elements = value.iter().map(|val| val.clone().into()).collect();
240 | 
241 |         let array_type = spark::expression::literal::Array {
242 |             element_type,
243 |             elements,
244 |         };
245 | 
246 |         spark::expression::Literal {
247 |             literal_type: Some(spark::expression::literal::LiteralType::Array(array_type)),
248 |         }
249 |     }
250 | }
251 | 
252 | impl From<&str> for spark::expression::cast::CastToType {
253 |     fn from(value: &str) -> Self {
254 |         spark::expression::cast::CastToType::TypeStr(value.to_string())
255 |     }
256 | }
257 | 
258 | impl From<String> for spark::expression::cast::CastToType {
259 |     fn from(value: String) -> Self {
260 |         spark::expression::cast::CastToType::TypeStr(value)
261 |     }
262 | }
263 | 
264 | impl From<DataType> for spark::expression::cast::CastToType {
265 |     fn from(value: DataType) -> spark::expression::cast::CastToType {
266 |         spark::expression::cast::CastToType::Type(value.into())
267 |     }
268 | }
269 | 
--------------------------------------------------------------------------------
/crates/connect/src/client/builder.rs:
--------------------------------------------------------------------------------
1 | // Licensed to the Apache Software Foundation (ASF) under one
2 | // or more contributor license agreements. See the NOTICE file
3 | // distributed with this work for additional information
4 | // regarding copyright ownership. The ASF licenses this file
5 | // to you under the Apache License, Version 2.0 (the
6 | // "License"); you may not use this file except in compliance
7 | // with the License. You may obtain a copy of the License at
8 | //
9 | //   http://www.apache.org/licenses/LICENSE-2.0
10 | //
11 | // Unless required by applicable law or agreed to in writing,
12 | // software distributed under the License is distributed on an
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 | // KIND, either express or implied. See the License for the
15 | // specific language governing permissions and limitations
16 | // under the License.
17 | 
18 | //! Implementation of ChannelBuilder
19 | 
20 | use std::collections::HashMap;
21 | use std::env;
22 | use std::str::FromStr;
23 | 
24 | use crate::errors::SparkError;
25 | 
26 | use url::Url;
27 | 
28 | use uuid::Uuid;
29 | 
30 | pub(crate) type Host = String;
31 | pub(crate) type Port = u16;
32 | pub(crate) type UrlParse = (Host, Port, Option<HashMap<String, String>>);
33 | 
34 | /// ChannelBuilder validates a connection string
35 | /// based on the requirements from [Spark Documentation](https://github.com/apache/spark/blob/master/connector/connect/docs/client-connection-string.md)
36 | #[derive(Clone, Debug)]
37 | pub struct ChannelBuilder {
38 |     pub(super) host: Host,
39 |     pub(super) port: Port,
40 |     pub(super) session_id: Uuid,
41 |     pub(super) token: Option<String>,
42 |     pub(super) user_id: Option<String>,
43 |     pub(super) user_agent: Option<String>,
44 |     pub(super) use_ssl: bool,
45 |     pub(super) headers: Option<HashMap<String, String>>,
46 | }
47 | 
48 | impl Default for ChannelBuilder {
49 |     fn default() -> Self {
50 |         let connection = match env::var("SPARK_REMOTE") {
51 |             Ok(conn) => conn.to_string(),
52 |             Err(_) => "sc://localhost:15002".to_string(),
53 |         };
54 | 
55 |         ChannelBuilder::create(&connection).unwrap()
56 |     }
57 | }
58 | 
59 | impl ChannelBuilder {
60 |     pub fn new() -> Self {
61 |         ChannelBuilder::default()
62 |     }
63 | 
64 |     pub(crate) fn endpoint(&self) -> String {
65 |         let scheme = if cfg!(feature = "tls") {
66 |             "https"
67 |         } else {
68 |             "http"
69 |         };
70 | 
71 |         format!("{}://{}:{}", scheme, self.host, self.port)
72 |     }
73 | 
74 |     pub(crate) fn headers(&self) -> Option<HashMap<String, String>> {
75 |         self.headers.to_owned()
76 |     }
77 | 
78 |     pub(crate) fn create_user_agent(user_agent: Option<&str>) -> Option<String> {
79 |         let user_agent = user_agent.unwrap_or("_SPARK_CONNECT_RUST");
80 |         let pkg_version = env!("CARGO_PKG_VERSION");
81 |         let os = env::consts::OS.to_lowercase();
82 | 
83 |         Some(format!(
84 |             "{} os/{} spark_connect_rs/{}",
85 |             user_agent, os, pkg_version
86 |         ))
87 |     }
88 | 
89 |     pub(crate) fn create_user_id(user_id: Option<&str>) -> Option<String> {
90 |         match user_id {
91 |             Some(user_id) => Some(user_id.to_string()),
92 |             None => env::var("USER").ok(),
93 |         }
94 |     }
95 | 
96 |     pub(crate) fn parse_connection_string(connection: &str) -> Result<UrlParse, SparkError> {
97 |         let url = Url::parse(connection).map_err(|_| {
98 |             SparkError::InvalidConnectionUrl("Failed to parse the connection URL".to_string())
99 |         })?;
100 | 
101 |         if url.scheme() != "sc" {
102 |             return Err(SparkError::InvalidConnectionUrl(
103 |                 "The URL must start with 'sc://'. Please update the URL to follow the correct format, e.g., 'sc://hostname:port'".to_string(),
104 |             ));
105 |         };
106 | 
107 |         let host = url
108 |             .host_str()
109 |             .ok_or_else(|| {
110 |                 SparkError::InvalidConnectionUrl(
111 |                     "The hostname must not be empty. Please update
112 |                      the URL to follow the correct format, e.g., 'sc://hostname:port'."
113 |                         .to_string(),
114 |                 )
115 |             })?
116 |             .to_string();
117 | 
118 |         let port = url.port().ok_or_else(|| {
119 |             SparkError::InvalidConnectionUrl(
120 |                 "The port must not be empty. Please update
121 |                  the URL to follow the correct format, e.g., 'sc://hostname:port'."
122 |                     .to_string(),
123 |             )
124 |         })?;
125 | 
126 |         let headers = ChannelBuilder::parse_headers(url);
127 | 
128 |         Ok((host, port, headers))
129 |     }
130 | 
131 |     pub(crate) fn parse_headers(url: Url) -> Option<HashMap<String, String>> {
132 |         let path: Vec<&str> = url
133 |             .path()
134 |             .split(';')
135 |             .filter(|&pair| (pair != "/") & (!pair.is_empty()))
136 |             .collect();
137 | 
138 |         if path.is_empty() || (path.len() == 1 && (path[0].is_empty() || path[0] == "/")) {
139 |             return None;
140 |         }
141 | 
142 |         let headers: HashMap<String, String> = path
143 |             .iter()
144 |             .copied()
145 |             .map(|pair| {
146 |                 let mut parts = pair.splitn(2, '=');
147 |                 (
148 |                     parts.next().unwrap_or("").to_string(),
149 |                     parts.next().unwrap_or("").to_string(),
150 |                 )
151 |             })
152 |             .collect();
153 | 
154 |         if headers.is_empty() {
155 |             return None;
156 |         }
157 | 
158 |         Some(headers)
159 |     }
160 | 
161 |     /// Create and validate a connection string
162 |     #[allow(unreachable_code)]
163 |     pub fn create(connection: &str) -> Result<ChannelBuilder, SparkError> {
164 |         let (host, port, headers) = ChannelBuilder::parse_connection_string(connection)?;
165 | 
166 |         let mut channel_builder = ChannelBuilder {
167 |             host,
168 |             port,
169 |             session_id: Uuid::new_v4(),
170 |             token: None,
171 |             user_id: ChannelBuilder::create_user_id(None),
172 |             user_agent: ChannelBuilder::create_user_agent(None),
173 |             use_ssl: false,
174 |             headers: None,
175 |         };
176 | 
177 |         if let Some(mut headers) = headers {
178 |             channel_builder.user_id = headers
179 |                 .remove("user_id")
180 |                 .map(|user_id| ChannelBuilder::create_user_id(Some(&user_id)))
181 |                 .unwrap_or_else(|| ChannelBuilder::create_user_id(None));
182 | 
183 |             channel_builder.user_agent = headers
184 |                 .remove("user_agent")
185 |                 .map(|user_agent| ChannelBuilder::create_user_agent(Some(&user_agent)))
186 |                 .unwrap_or_else(|| ChannelBuilder::create_user_agent(None));
187 | 
188 |             if let Some(token) = headers.remove("token") {
189 |                 let token = format!("Bearer {token}");
190 |                 channel_builder.token = Some(token.clone());
191 |                 headers.insert("authorization".to_string(), token);
192 |             }
193 | 
194 |             if let Some(session_id) = headers.remove("session_id") {
195 |                 channel_builder.session_id = Uuid::from_str(&session_id)?
196 |             }
197 | 
198 |             if let Some(use_ssl) = headers.remove("use_ssl") {
199 |                 if use_ssl.to_lowercase() == "true" {
200 |                     #[cfg(not(feature = "tls"))]
201 |                     {
202 |                         panic!(
203 |                             "The 'use_ssl' option requires the 'tls' feature, but it's not enabled!"
204 | ); 205 | }; 206 | channel_builder.use_ssl = true 207 | } 208 | }; 209 | 210 | if !headers.is_empty() { 211 | channel_builder.headers = Some(headers); 212 | } 213 | } 214 | 215 | Ok(channel_builder) 216 | } 217 | } 218 | 219 | #[cfg(test)] 220 | mod tests { 221 | use super::*; 222 | 223 | #[test] 224 | fn test_channel_builder_default() { 225 | let expected_url = "http://localhost:15002".to_string(); 226 | 227 | let cb = ChannelBuilder::default(); 228 | 229 | assert_eq!(expected_url, cb.endpoint()) 230 | } 231 | 232 | #[test] 233 | fn test_panic_incorrect_url_scheme() { 234 | let connection = "http://127.0.0.1:15002"; 235 | 236 | assert!(ChannelBuilder::create(connection).is_err()) 237 | } 238 | 239 | #[test] 240 | fn test_panic_missing_url_host() { 241 | let connection = "sc://:15002"; 242 | 243 | assert!(ChannelBuilder::create(connection).is_err()) 244 | } 245 | 246 | #[test] 247 | fn test_panic_missing_url_port() { 248 | let connection = "sc://127.0.0.1"; 249 | 250 | assert!(ChannelBuilder::create(connection).is_err()) 251 | } 252 | 253 | #[test] 254 | fn test_settings_builder() { 255 | let connection = "sc://myhost.com:443/;token=ABCDEFG;user_agent=some_agent;user_id=user123"; 256 | 257 | let builder = ChannelBuilder::create(connection).unwrap(); 258 | 259 | assert_eq!("http://myhost.com:443".to_string(), builder.endpoint()); 260 | assert_eq!("Bearer ABCDEFG".to_string(), builder.token.unwrap()); 261 | assert_eq!("user123".to_string(), builder.user_id.unwrap()); 262 | } 263 | 264 | #[test] 265 | #[should_panic( 266 | expected = "The 'use_ssl' option requires the 'tls' feature, but it's not enabled!" 267 | )] 268 | fn test_panic_ssl() { 269 | let connection = "sc://127.0.0.1:443/;use_ssl=true"; 270 | 271 | ChannelBuilder::create(connection).unwrap(); 272 | } 273 | } 274 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /crates/connect/protobuf/spark-3.5/spark/connect/expressions.proto: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 
5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | syntax = 'proto3'; 19 | 20 | import "google/protobuf/any.proto"; 21 | import "spark/connect/types.proto"; 22 | 23 | package spark.connect; 24 | 25 | option java_multiple_files = true; 26 | option java_package = "org.apache.spark.connect.proto"; 27 | option go_package = "internal/generated"; 28 | 29 | // Expression used to refer to fields, functions and similar. This can be used everywhere 30 | // expressions in SQL appear. 31 | message Expression { 32 | 33 | oneof expr_type { 34 | Literal literal = 1; 35 | UnresolvedAttribute unresolved_attribute = 2; 36 | UnresolvedFunction unresolved_function = 3; 37 | ExpressionString expression_string = 4; 38 | UnresolvedStar unresolved_star = 5; 39 | Alias alias = 6; 40 | Cast cast = 7; 41 | UnresolvedRegex unresolved_regex = 8; 42 | SortOrder sort_order = 9; 43 | LambdaFunction lambda_function = 10; 44 | Window window = 11; 45 | UnresolvedExtractValue unresolved_extract_value = 12; 46 | UpdateFields update_fields = 13; 47 | UnresolvedNamedLambdaVariable unresolved_named_lambda_variable = 14; 48 | CommonInlineUserDefinedFunction common_inline_user_defined_function = 15; 49 | CallFunction call_function = 16; 50 | 51 | // This field is used to mark extensions to the protocol. When plugins generate arbitrary 52 | // relations they can add them here. During the planning the correct resolution is done. 53 | google.protobuf.Any extension = 999; 54 | } 55 | 56 | 57 | // Expression for the OVER clause or WINDOW clause. 58 | message Window { 59 | 60 | // (Required) The window function. 61 | Expression window_function = 1; 62 | 63 | // (Optional) The way that input rows are partitioned. 64 | repeated Expression partition_spec = 2; 65 | 66 | // (Optional) Ordering of rows in a partition. 67 | repeated SortOrder order_spec = 3; 68 | 69 | // (Optional) Window frame in a partition. 70 | // 71 | // If not set, it will be treated as 'UnspecifiedFrame'. 72 | WindowFrame frame_spec = 4; 73 | 74 | // The window frame 75 | message WindowFrame { 76 | 77 | // (Required) The type of the frame. 78 | FrameType frame_type = 1; 79 | 80 | // (Required) The lower bound of the frame. 81 | FrameBoundary lower = 2; 82 | 83 | // (Required) The upper bound of the frame. 84 | FrameBoundary upper = 3; 85 | 86 | enum FrameType { 87 | FRAME_TYPE_UNDEFINED = 0; 88 | 89 | // RowFrame treats rows in a partition individually. 90 | FRAME_TYPE_ROW = 1; 91 | 92 | // RangeFrame treats rows in a partition as groups of peers. 93 | // All rows having the same 'ORDER BY' ordering are considered as peers. 94 | FRAME_TYPE_RANGE = 2; 95 | } 96 | 97 | message FrameBoundary { 98 | oneof boundary { 99 | // CURRENT ROW boundary 100 | bool current_row = 1; 101 | 102 | // UNBOUNDED boundary. 103 | // For lower bound, it will be converted to 'UnboundedPreceding'. 104 | // for upper bound, it will be converted to 'UnboundedFollowing'. 
105 | bool unbounded = 2; 106 | 107 | // This is an expression for future proofing. We are expecting literals on the server side. 108 | Expression value = 3; 109 | } 110 | } 111 | } 112 | } 113 | 114 | // SortOrder is used to specify the data ordering, it is normally used in Sort and Window. 115 | // It is an unevaluable expression and cannot be evaluated, so can not be used in Projection. 116 | message SortOrder { 117 | // (Required) The expression to be sorted. 118 | Expression child = 1; 119 | 120 | // (Required) The sort direction, should be ASCENDING or DESCENDING. 121 | SortDirection direction = 2; 122 | 123 | // (Required) How to deal with NULLs, should be NULLS_FIRST or NULLS_LAST. 124 | NullOrdering null_ordering = 3; 125 | 126 | enum SortDirection { 127 | SORT_DIRECTION_UNSPECIFIED = 0; 128 | SORT_DIRECTION_ASCENDING = 1; 129 | SORT_DIRECTION_DESCENDING = 2; 130 | } 131 | 132 | enum NullOrdering { 133 | SORT_NULLS_UNSPECIFIED = 0; 134 | SORT_NULLS_FIRST = 1; 135 | SORT_NULLS_LAST = 2; 136 | } 137 | } 138 | 139 | message Cast { 140 | // (Required) the expression to be casted. 141 | Expression expr = 1; 142 | 143 | // (Required) the data type that the expr to be casted to. 144 | oneof cast_to_type { 145 | DataType type = 2; 146 | // If this is set, Server will use Catalyst parser to parse this string to DataType. 147 | string type_str = 3; 148 | } 149 | } 150 | 151 | message Literal { 152 | oneof literal_type { 153 | DataType null = 1; 154 | bytes binary = 2; 155 | bool boolean = 3; 156 | 157 | int32 byte = 4; 158 | int32 short = 5; 159 | int32 integer = 6; 160 | int64 long = 7; 161 | float float = 10; 162 | double double = 11; 163 | Decimal decimal = 12; 164 | 165 | string string = 13; 166 | 167 | // Date in units of days since the UNIX epoch. 168 | int32 date = 16; 169 | // Timestamp in units of microseconds since the UNIX epoch. 170 | int64 timestamp = 17; 171 | // Timestamp in units of microseconds since the UNIX epoch (without timezone information). 172 | int64 timestamp_ntz = 18; 173 | 174 | CalendarInterval calendar_interval = 19; 175 | int32 year_month_interval = 20; 176 | int64 day_time_interval = 21; 177 | Array array = 22; 178 | Map map = 23; 179 | Struct struct = 24; 180 | } 181 | 182 | message Decimal { 183 | // the string representation. 184 | string value = 1; 185 | // The maximum number of digits allowed in the value. 186 | // the maximum precision is 38. 187 | optional int32 precision = 2; 188 | // declared scale of decimal literal 189 | optional int32 scale = 3; 190 | } 191 | 192 | message CalendarInterval { 193 | int32 months = 1; 194 | int32 days = 2; 195 | int64 microseconds = 3; 196 | } 197 | 198 | message Array { 199 | DataType element_type = 1; 200 | repeated Literal elements = 2; 201 | } 202 | 203 | message Map { 204 | DataType key_type = 1; 205 | DataType value_type = 2; 206 | repeated Literal keys = 3; 207 | repeated Literal values = 4; 208 | } 209 | 210 | message Struct { 211 | DataType struct_type = 1; 212 | repeated Literal elements = 2; 213 | } 214 | } 215 | 216 | // An unresolved attribute that is not explicitly bound to a specific column, but the column 217 | // is resolved during analysis by name. 218 | message UnresolvedAttribute { 219 | // (Required) An identifier that will be parsed by Catalyst parser. This should follow the 220 | // Spark SQL identifier syntax. 221 | string unparsed_identifier = 1; 222 | 223 | // (Optional) The id of corresponding connect plan. 
224 | optional int64 plan_id = 2; 225 | } 226 | 227 | // An unresolved function is not explicitly bound to one explicit function, but the function 228 | // is resolved during analysis following Sparks name resolution rules. 229 | message UnresolvedFunction { 230 | // (Required) name (or unparsed name for user defined function) for the unresolved function. 231 | string function_name = 1; 232 | 233 | // (Optional) Function arguments. Empty arguments are allowed. 234 | repeated Expression arguments = 2; 235 | 236 | // (Required) Indicate if this function should be applied on distinct values. 237 | bool is_distinct = 3; 238 | 239 | // (Required) Indicate if this is a user defined function. 240 | // 241 | // When it is not a user defined function, Connect will use the function name directly. 242 | // When it is a user defined function, Connect will parse the function name first. 243 | bool is_user_defined_function = 4; 244 | } 245 | 246 | // Expression as string. 247 | message ExpressionString { 248 | // (Required) A SQL expression that will be parsed by Catalyst parser. 249 | string expression = 1; 250 | } 251 | 252 | // UnresolvedStar is used to expand all the fields of a relation or struct. 253 | message UnresolvedStar { 254 | 255 | // (Optional) The target of the expansion. 256 | // 257 | // If set, it should end with '.*' and will be parsed by 'parseAttributeName' 258 | // in the server side. 259 | optional string unparsed_target = 1; 260 | } 261 | 262 | // Represents all of the input attributes to a given relational operator, for example in 263 | // "SELECT `(id)?+.+` FROM ...". 264 | message UnresolvedRegex { 265 | // (Required) The column name used to extract column with regex. 266 | string col_name = 1; 267 | 268 | // (Optional) The id of corresponding connect plan. 269 | optional int64 plan_id = 2; 270 | } 271 | 272 | // Extracts a value or values from an Expression 273 | message UnresolvedExtractValue { 274 | // (Required) The expression to extract value from, can be 275 | // Map, Array, Struct or array of Structs. 276 | Expression child = 1; 277 | 278 | // (Required) The expression to describe the extraction, can be 279 | // key of Map, index of Array, field name of Struct. 280 | Expression extraction = 2; 281 | } 282 | 283 | // Add, replace or drop a field of `StructType` expression by name. 284 | message UpdateFields { 285 | // (Required) The struct expression. 286 | Expression struct_expression = 1; 287 | 288 | // (Required) The field name. 289 | string field_name = 2; 290 | 291 | // (Optional) The expression to add or replace. 292 | // 293 | // When not set, it means this field will be dropped. 294 | Expression value_expression = 3; 295 | } 296 | 297 | message Alias { 298 | // (Required) The expression that alias will be added on. 299 | Expression expr = 1; 300 | 301 | // (Required) a list of name parts for the alias. 302 | // 303 | // Scalar columns only has one name that presents. 304 | repeated string name = 2; 305 | 306 | // (Optional) Alias metadata expressed as a JSON map. 307 | optional string metadata = 3; 308 | } 309 | 310 | message LambdaFunction { 311 | // (Required) The lambda function. 312 | // 313 | // The function body should use 'UnresolvedAttribute' as arguments, the sever side will 314 | // replace 'UnresolvedAttribute' with 'UnresolvedNamedLambdaVariable'. 315 | Expression function = 1; 316 | 317 | // (Required) Function variables. Must contains 1 ~ 3 variables. 
318 | repeated Expression.UnresolvedNamedLambdaVariable arguments = 2; 319 | } 320 | 321 | message UnresolvedNamedLambdaVariable { 322 | 323 | // (Required) a list of name parts for the variable. Must not be empty. 324 | repeated string name_parts = 1; 325 | } 326 | } 327 | 328 | message CommonInlineUserDefinedFunction { 329 | // (Required) Name of the user-defined function. 330 | string function_name = 1; 331 | // (Optional) Indicate if the user-defined function is deterministic. 332 | bool deterministic = 2; 333 | // (Optional) Function arguments. Empty arguments are allowed. 334 | repeated Expression arguments = 3; 335 | // (Required) Indicate the function type of the user-defined function. 336 | oneof function { 337 | PythonUDF python_udf = 4; 338 | ScalarScalaUDF scalar_scala_udf = 5; 339 | JavaUDF java_udf = 6; 340 | } 341 | } 342 | 343 | message PythonUDF { 344 | // (Required) Output type of the Python UDF 345 | DataType output_type = 1; 346 | // (Required) EvalType of the Python UDF 347 | int32 eval_type = 2; 348 | // (Required) The encoded commands of the Python UDF 349 | bytes command = 3; 350 | // (Required) Python version being used in the client. 351 | string python_ver = 4; 352 | } 353 | 354 | message ScalarScalaUDF { 355 | // (Required) Serialized JVM object containing UDF definition, input encoders and output encoder 356 | bytes payload = 1; 357 | // (Optional) Input type(s) of the UDF 358 | repeated DataType inputTypes = 2; 359 | // (Required) Output type of the UDF 360 | DataType outputType = 3; 361 | // (Required) True if the UDF can return null value 362 | bool nullable = 4; 363 | } 364 | 365 | message JavaUDF { 366 | // (Required) Fully qualified name of Java class 367 | string class_name = 1; 368 | 369 | // (Optional) Output type of the Java UDF 370 | optional DataType output_type = 2; 371 | 372 | // (Required) Indicate if the Java user-defined function is an aggregate function 373 | bool aggregate = 3; 374 | } 375 | 376 | message CallFunction { 377 | // (Required) Unparsed name of the SQL function. 378 | string function_name = 1; 379 | 380 | // (Optional) Function arguments. Empty arguments are allowed. 381 | repeated Expression arguments = 2; 382 | } 383 | -------------------------------------------------------------------------------- /crates/connect/protobuf/spark-3.5/spark/connect/commands.proto: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | syntax = 'proto3'; 19 | 20 | import "google/protobuf/any.proto"; 21 | import "spark/connect/common.proto"; 22 | import "spark/connect/expressions.proto"; 23 | import "spark/connect/relations.proto"; 24 | 25 | package spark.connect; 26 | 27 | option java_multiple_files = true; 28 | option java_package = "org.apache.spark.connect.proto"; 29 | option go_package = "internal/generated"; 30 | 31 | // A [[Command]] is an operation that is executed by the server that does not directly consume or 32 | // produce a relational result. 33 | message Command { 34 | oneof command_type { 35 | CommonInlineUserDefinedFunction register_function = 1; 36 | WriteOperation write_operation = 2; 37 | CreateDataFrameViewCommand create_dataframe_view = 3; 38 | WriteOperationV2 write_operation_v2 = 4; 39 | SqlCommand sql_command = 5; 40 | WriteStreamOperationStart write_stream_operation_start = 6; 41 | StreamingQueryCommand streaming_query_command = 7; 42 | GetResourcesCommand get_resources_command = 8; 43 | StreamingQueryManagerCommand streaming_query_manager_command = 9; 44 | CommonInlineUserDefinedTableFunction register_table_function = 10; 45 | 46 | // This field is used to mark extensions to the protocol. When plugins generate arbitrary 47 | // Commands they can add them here. During the planning the correct resolution is done. 48 | google.protobuf.Any extension = 999; 49 | 50 | } 51 | } 52 | 53 | // A SQL Command is used to trigger the eager evaluation of SQL commands in Spark. 54 | // 55 | // When the SQL provide as part of the message is a command it will be immediately evaluated 56 | // and the result will be collected and returned as part of a LocalRelation. If the result is 57 | // not a command, the operation will simply return a SQL Relation. This allows the client to be 58 | // almost oblivious to the server-side behavior. 59 | message SqlCommand { 60 | // (Required) SQL Query. 61 | string sql = 1; 62 | 63 | // (Optional) A map of parameter names to literal expressions. 64 | map args = 2; 65 | 66 | // (Optional) A sequence of literal expressions for positional parameters in the SQL query text. 67 | repeated Expression.Literal pos_args = 3; 68 | } 69 | 70 | // A command that can create DataFrame global temp view or local temp view. 71 | message CreateDataFrameViewCommand { 72 | // (Required) The relation that this view will be built on. 73 | Relation input = 1; 74 | 75 | // (Required) View name. 76 | string name = 2; 77 | 78 | // (Required) Whether this is global temp view or local temp view. 79 | bool is_global = 3; 80 | 81 | // (Required) 82 | // 83 | // If true, and if the view already exists, updates it; if false, and if the view 84 | // already exists, throws exception. 85 | bool replace = 4; 86 | } 87 | 88 | // As writes are not directly handled during analysis and planning, they are modeled as commands. 89 | message WriteOperation { 90 | // (Required) The output of the `input` relation will be persisted according to the options. 91 | Relation input = 1; 92 | 93 | // (Optional) Format value according to the Spark documentation. Examples are: text, parquet, delta. 94 | optional string source = 2; 95 | 96 | // (Optional) 97 | // 98 | // The destination of the write operation can be either a path or a table. 99 | // If the destination is neither a path nor a table, such as jdbc and noop, 100 | // the `save_type` should not be set. 101 | oneof save_type { 102 | string path = 3; 103 | SaveTable table = 4; 104 | } 105 | 106 | // (Required) the save mode. 
107 | SaveMode mode = 5; 108 | 109 | // (Optional) List of columns to sort the output by. 110 | repeated string sort_column_names = 6; 111 | 112 | // (Optional) List of columns for partitioning. 113 | repeated string partitioning_columns = 7; 114 | 115 | // (Optional) Bucketing specification. Bucketing must set the number of buckets and the columns 116 | // to bucket by. 117 | BucketBy bucket_by = 8; 118 | 119 | // (Optional) A list of configuration options. 120 | map options = 9; 121 | 122 | message SaveTable { 123 | // (Required) The table name. 124 | string table_name = 1; 125 | // (Required) The method to be called to write to the table. 126 | TableSaveMethod save_method = 2; 127 | 128 | enum TableSaveMethod { 129 | TABLE_SAVE_METHOD_UNSPECIFIED = 0; 130 | TABLE_SAVE_METHOD_SAVE_AS_TABLE = 1; 131 | TABLE_SAVE_METHOD_INSERT_INTO = 2; 132 | } 133 | } 134 | 135 | message BucketBy { 136 | repeated string bucket_column_names = 1; 137 | int32 num_buckets = 2; 138 | } 139 | 140 | enum SaveMode { 141 | SAVE_MODE_UNSPECIFIED = 0; 142 | SAVE_MODE_APPEND = 1; 143 | SAVE_MODE_OVERWRITE = 2; 144 | SAVE_MODE_ERROR_IF_EXISTS = 3; 145 | SAVE_MODE_IGNORE = 4; 146 | } 147 | } 148 | 149 | // As writes are not directly handled during analysis and planning, they are modeled as commands. 150 | message WriteOperationV2 { 151 | // (Required) The output of the `input` relation will be persisted according to the options. 152 | Relation input = 1; 153 | 154 | // (Required) The destination of the write operation must be either a path or a table. 155 | string table_name = 2; 156 | 157 | // (Optional) A provider for the underlying output data source. Spark's default catalog supports 158 | // "parquet", "json", etc. 159 | optional string provider = 3; 160 | 161 | // (Optional) List of columns for partitioning for output table created by `create`, 162 | // `createOrReplace`, or `replace` 163 | repeated Expression partitioning_columns = 4; 164 | 165 | // (Optional) A list of configuration options. 166 | map options = 5; 167 | 168 | // (Optional) A list of table properties. 169 | map table_properties = 6; 170 | 171 | // (Required) Write mode. 172 | Mode mode = 7; 173 | 174 | enum Mode { 175 | MODE_UNSPECIFIED = 0; 176 | MODE_CREATE = 1; 177 | MODE_OVERWRITE = 2; 178 | MODE_OVERWRITE_PARTITIONS = 3; 179 | MODE_APPEND = 4; 180 | MODE_REPLACE = 5; 181 | MODE_CREATE_OR_REPLACE = 6; 182 | } 183 | 184 | // (Optional) A condition for overwrite saving mode 185 | Expression overwrite_condition = 8; 186 | } 187 | 188 | // Starts write stream operation as streaming query. Query ID and Run ID of the streaming 189 | // query are returned. 190 | message WriteStreamOperationStart { 191 | 192 | // (Required) The output of the `input` streaming relation will be written. 193 | Relation input = 1; 194 | 195 | // The following fields directly map to API for DataStreamWriter(). 196 | // Consult API documentation unless explicitly documented here. 197 | 198 | string format = 2; 199 | map options = 3; 200 | repeated string partitioning_column_names = 4; 201 | 202 | oneof trigger { 203 | string processing_time_interval = 5; 204 | bool available_now = 6; 205 | bool once = 7; 206 | string continuous_checkpoint_interval = 8; 207 | } 208 | 209 | string output_mode = 9; 210 | string query_name = 10; 211 | 212 | // The destination is optional. When set, it can be a path or a table name. 
213 | oneof sink_destination { 214 | string path = 11; 215 | string table_name = 12; 216 | } 217 | 218 | StreamingForeachFunction foreach_writer = 13; 219 | StreamingForeachFunction foreach_batch = 14; 220 | } 221 | 222 | message StreamingForeachFunction { 223 | oneof function { 224 | PythonUDF python_function = 1; 225 | ScalarScalaUDF scala_function = 2; 226 | } 227 | } 228 | 229 | message WriteStreamOperationStartResult { 230 | 231 | // (Required) Query instance. See `StreamingQueryInstanceId`. 232 | StreamingQueryInstanceId query_id = 1; 233 | 234 | // An optional query name. 235 | string name = 2; 236 | 237 | // TODO: How do we indicate errors? 238 | // TODO: Consider adding status, last progress etc here. 239 | } 240 | 241 | // A tuple that uniquely identifies an instance of streaming query run. It consists of `id` that 242 | // persists across the streaming runs and `run_id` that changes between each run of the 243 | // streaming query that resumes from the checkpoint. 244 | message StreamingQueryInstanceId { 245 | 246 | // (Required) The unique id of this query that persists across restarts from checkpoint data. 247 | // That is, this id is generated when a query is started for the first time, and 248 | // will be the same every time it is restarted from checkpoint data. 249 | string id = 1; 250 | 251 | // (Required) The unique id of this run of the query. That is, every start/restart of a query 252 | // will generate a unique run_id. Therefore, every time a query is restarted from 253 | // checkpoint, it will have the same `id` but different `run_id`s. 254 | string run_id = 2; 255 | } 256 | 257 | // Commands for a streaming query. 258 | message StreamingQueryCommand { 259 | 260 | // (Required) Query instance. See `StreamingQueryInstanceId`. 261 | StreamingQueryInstanceId query_id = 1; 262 | 263 | // See documentation for the corresponding API method in StreamingQuery. 264 | oneof command { 265 | // status() API. 266 | bool status = 2; 267 | // lastProgress() API. 268 | bool last_progress = 3; 269 | // recentProgress() API. 270 | bool recent_progress = 4; 271 | // stop() API. Stops the query. 272 | bool stop = 5; 273 | // processAllAvailable() API. Waits till all the available data is processed 274 | bool process_all_available = 6; 275 | // explain() API. Returns logical and physical plans. 276 | ExplainCommand explain = 7; 277 | // exception() API. Returns the exception in the query if any. 278 | bool exception = 8; 279 | // awaitTermination() API. Waits for the termination of the query. 280 | AwaitTerminationCommand await_termination = 9; 281 | } 282 | 283 | message ExplainCommand { 284 | // TODO: Consider reusing Explain from AnalyzePlanRequest message. 285 | // We can not do this right now since it base.proto imports this file. 286 | bool extended = 1; 287 | } 288 | 289 | message AwaitTerminationCommand { 290 | optional int64 timeout_ms = 2; 291 | } 292 | } 293 | 294 | // Response for commands on a streaming query. 295 | message StreamingQueryCommandResult { 296 | // (Required) Query instance id. See `StreamingQueryInstanceId`. 
297 | StreamingQueryInstanceId query_id = 1; 298 | 299 | oneof result_type { 300 | StatusResult status = 2; 301 | RecentProgressResult recent_progress = 3; 302 | ExplainResult explain = 4; 303 | ExceptionResult exception = 5; 304 | AwaitTerminationResult await_termination = 6; 305 | } 306 | 307 | message StatusResult { 308 | // See documentation for these Scala 'StreamingQueryStatus' struct 309 | string status_message = 1; 310 | bool is_data_available = 2; 311 | bool is_trigger_active = 3; 312 | bool is_active = 4; 313 | } 314 | 315 | message RecentProgressResult { 316 | // Progress reports as an array of json strings. 317 | repeated string recent_progress_json = 5; 318 | } 319 | 320 | message ExplainResult { 321 | // Logical and physical plans as string 322 | string result = 1; 323 | } 324 | 325 | message ExceptionResult { 326 | // (Optional) Exception message as string, maps to the return value of original 327 | // StreamingQueryException's toString method 328 | optional string exception_message = 1; 329 | // (Optional) Exception error class as string 330 | optional string error_class = 2; 331 | // (Optional) Exception stack trace as string 332 | optional string stack_trace = 3; 333 | } 334 | 335 | message AwaitTerminationResult { 336 | bool terminated = 1; 337 | } 338 | } 339 | 340 | // Commands for the streaming query manager. 341 | message StreamingQueryManagerCommand { 342 | 343 | // See documentation for the corresponding API method in StreamingQueryManager. 344 | oneof command { 345 | // active() API, returns a list of active queries. 346 | bool active = 1; 347 | // get() API, returns the StreamingQuery identified by id. 348 | string get_query = 2; 349 | // awaitAnyTermination() API, wait until any query terminates or timeout. 350 | AwaitAnyTerminationCommand await_any_termination = 3; 351 | // resetTerminated() API. 352 | bool reset_terminated = 4; 353 | // addListener API. 354 | StreamingQueryListenerCommand add_listener = 5; 355 | // removeListener API. 356 | StreamingQueryListenerCommand remove_listener = 6; 357 | // listListeners() API, returns a list of streaming query listeners. 358 | bool list_listeners = 7; 359 | } 360 | 361 | message AwaitAnyTerminationCommand { 362 | // (Optional) The waiting time in milliseconds to wait for any query to terminate. 363 | optional int64 timeout_ms = 1; 364 | } 365 | 366 | message StreamingQueryListenerCommand { 367 | bytes listener_payload = 1; 368 | optional PythonUDF python_listener_payload = 2; 369 | string id = 3; 370 | } 371 | } 372 | 373 | // Response for commands on the streaming query manager. 374 | message StreamingQueryManagerCommandResult { 375 | oneof result_type { 376 | ActiveResult active = 1; 377 | StreamingQueryInstance query = 2; 378 | AwaitAnyTerminationResult await_any_termination = 3; 379 | bool reset_terminated = 4; 380 | bool add_listener = 5; 381 | bool remove_listener = 6; 382 | ListStreamingQueryListenerResult list_listeners = 7; 383 | } 384 | 385 | message ActiveResult { 386 | repeated StreamingQueryInstance active_queries = 1; 387 | } 388 | 389 | message StreamingQueryInstance { 390 | // (Required) The id and runId of this query. 391 | StreamingQueryInstanceId id = 1; 392 | // (Optional) The name of this query. 
393 | optional string name = 2; 394 | } 395 | 396 | message AwaitAnyTerminationResult { 397 | bool terminated = 1; 398 | } 399 | 400 | message StreamingQueryListenerInstance { 401 | bytes listener_payload = 1; 402 | } 403 | 404 | message ListStreamingQueryListenerResult { 405 | // (Required) Reference IDs of listener instances. 406 | repeated string listener_ids = 1; 407 | } 408 | } 409 | 410 | // Command to get the output of 'SparkContext.resources' 411 | message GetResourcesCommand { } 412 | 413 | // Response for command 'GetResourcesCommand'. 414 | message GetResourcesCommandResult { 415 | map resources = 1; 416 | } 417 | -------------------------------------------------------------------------------- /crates/connect/src/window.rs: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 17 | 18 | //! Utility structs for defining a window over a DataFrame 19 | 20 | use crate::column::Column; 21 | use crate::expressions::VecExpression; 22 | use crate::functions::lit; 23 | use crate::plan::sort_order; 24 | 25 | use crate::spark; 26 | use crate::spark::expression::window; 27 | 28 | /// A window specification that defines the partitioning, ordering, and frame boundaries. 
29 | ///
30 | /// **Recommended to create a WindowSpec using [Window] and not directly**
31 | #[derive(Debug, Default, Clone)]
32 | pub struct WindowSpec {
33 |     pub partition_spec: Vec<spark::Expression>,
34 |     pub order_spec: Vec<spark::expression::SortOrder>,
35 |     pub frame_spec: Option<Box<window::WindowFrame>>,
36 | }
37 | 
38 | impl WindowSpec {
39 |     pub fn new(
40 |         partition_spec: Vec<spark::Expression>,
41 |         order_spec: Vec<spark::expression::SortOrder>,
42 |         frame_spec: Option<Box<window::WindowFrame>>,
43 |     ) -> WindowSpec {
44 |         WindowSpec {
45 |             partition_spec,
46 |             order_spec,
47 |             frame_spec,
48 |         }
49 |     }
50 | 
51 |     pub fn partition_by<I, S>(self, cols: I) -> WindowSpec
52 |     where
53 |         I: IntoIterator<Item = S>,
54 |         S: Into<Column>,
55 |     {
56 |         WindowSpec::new(
57 |             VecExpression::from_iter(cols).into(),
58 |             self.order_spec,
59 |             self.frame_spec,
60 |         )
61 |     }
62 | 
63 |     pub fn order_by<I, S>(self, cols: I) -> WindowSpec
64 |     where
65 |         I: IntoIterator<Item = S>,
66 |         S: Into<Column>,
67 |     {
68 |         let order_spec = sort_order(cols);
69 | 
70 |         WindowSpec::new(self.partition_spec, order_spec, self.frame_spec)
71 |     }
72 | 
73 |     pub fn rows_between(self, start: i64, end: i64) -> WindowSpec {
74 |         let frame_spec = WindowSpec::window_frame(true, start, end);
75 | 
76 |         WindowSpec::new(self.partition_spec, self.order_spec, frame_spec)
77 |     }
78 | 
79 |     pub fn range_between(self, start: i64, end: i64) -> WindowSpec {
80 |         let frame_spec = WindowSpec::window_frame(false, start, end);
81 | 
82 |         WindowSpec::new(self.partition_spec, self.order_spec, frame_spec)
83 |     }
84 | 
85 |     fn frame_boundary(value: i64) -> Option<Box<window::window_frame::FrameBoundary>> {
86 |         match value {
87 |             0 => {
88 |                 let boundary = Some(window::window_frame::frame_boundary::Boundary::CurrentRow(
89 |                     true,
90 |                 ));
91 | 
92 |                 Some(Box::new(window::window_frame::FrameBoundary { boundary }))
93 |             }
94 |             i64::MIN => {
95 |                 let boundary = Some(window::window_frame::frame_boundary::Boundary::Unbounded(
96 |                     true,
97 |                 ));
98 | 
99 |                 Some(Box::new(window::window_frame::FrameBoundary { boundary }))
100 |             }
101 |             _ => {
102 |                 // !TODO - I don't like casting this to i32
103 |                 // however, the window boundary is expecting an INT and not a BIGINT
104 |                 // i64 is a BIGINT (i.e. Long)
105 |                 let expr = lit(value as i32);
106 | 
107 |                 let boundary = Some(window::window_frame::frame_boundary::Boundary::Value(
108 |                     Box::new(expr.into()),
109 |                 ));
110 | 
111 |                 Some(Box::new(window::window_frame::FrameBoundary { boundary }))
112 |             }
113 |         }
114 |     }
115 | 
116 |     fn window_frame(row_frame: bool, start: i64, end: i64) -> Option<Box<window::WindowFrame>> {
117 |         let frame_type = match row_frame {
118 |             true => 1,
119 |             false => 2,
120 |         };
121 | 
122 |         let lower = WindowSpec::frame_boundary(start);
123 |         let upper = WindowSpec::frame_boundary(end);
124 | 
125 |         Some(Box::new(window::WindowFrame {
126 |             frame_type,
127 |             lower,
128 |             upper,
129 |         }))
130 |     }
131 | }
132 | 
133 | /// Primary utility struct for defining window in DataFrames
134 | #[derive(Debug, Default, Clone)]
135 | pub struct Window {
136 |     spec: WindowSpec,
137 | }
138 | 
139 | impl Window {
140 |     /// Creates a new empty [WindowSpec]
141 |     pub fn new() -> Self {
142 |         Window {
143 |             spec: WindowSpec::default(),
144 |         }
145 |     }
146 | 
147 |     /// Returns 0
148 |     pub fn current_row() -> i64 {
149 |         0
150 |     }
151 | 
152 |     /// Returns [i64::MAX]
153 |     pub fn unbounded_following() -> i64 {
154 |         i64::MAX
155 |     }
156 | 
157 |     /// Returns [i64::MIN]
158 |     pub fn unbounded_preceding() -> i64 {
159 |         i64::MIN
160 |     }
161 | 
162 |     /// Creates a [WindowSpec] with the partitioning defined
163 |     pub fn partition_by<I, S>(mut self, cols: I) -> WindowSpec
164 |     where
165 |         I: IntoIterator<Item = S>,
166 |         S: Into<Column>,
167 |     {
168 |         self.spec = self.spec.partition_by(cols);
169 | 
170 |         self.spec
171 |     }
172 | 
173 |     /// Creates a [WindowSpec] with the ordering defined
174 |     pub fn order_by<I, S>(mut self, cols: I) -> WindowSpec
175 |     where
176 |         I: IntoIterator<Item = S>,
177 |         S: Into<Column>,
178 |     {
179 |         self.spec = self.spec.order_by(cols);
180 | 
181 |         self.spec
182 |     }
183 | 
184 |     /// Creates a [WindowSpec] with the frame boundaries defined, from start (inclusive) to end (inclusive).
185 |     ///
186 |     /// Both start and end are relative from the current row. For example, “0” means “current row”,
187 |     /// while “-1” means one off before the current row, and “5” means the five off after the current row.
188 |     ///
189 |     /// Recommended to use [Window::unbounded_preceding], [Window::unbounded_following], and [Window::current_row]
190 |     /// to specify special boundary values, rather than using integral values directly.
191 |     ///
192 |     /// # Example
193 |     ///
194 |     /// ```
195 |     /// let window = Window::new()
196 |     ///     .partition_by(col("name"))
197 |     ///     .order_by([col("age")])
198 |     ///     .range_between(Window::unbounded_preceding(), Window::current_row());
199 |     ///
200 |     /// let df = df.with_column("rank", rank().over(window.clone()))
201 |     ///     .with_column("min", min("age").over(window));
202 |     /// ```
203 |     pub fn range_between(mut self, start: i64, end: i64) -> WindowSpec {
204 |         self.spec = self.spec.range_between(start, end);
205 | 
206 |         self.spec
207 |     }
208 | 
209 |     /// Creates a [WindowSpec] with the frame boundaries defined, from start (inclusive) to end (inclusive).
210 |     ///
211 |     /// Both start and end are relative from the current row. For example, “0” means “current row”,
212 |     /// while “-1” means one off before the current row, and “5” means the five off after the current row.
213 |     ///
214 |     /// Recommended to use [Window::unbounded_preceding], [Window::unbounded_following], and [Window::current_row]
215 |     /// to specify special boundary values, rather than using integral values directly.
216 | /// 217 | /// # Example 218 | /// 219 | /// ``` 220 | /// let window = Window::new() 221 | /// .partition_by(col("name")) 222 | /// .order_by([col("age")]) 223 | /// .rows_between(Window::unbounded_preceding(), Window::current_row()); 224 | /// 225 | /// let df = df.with_column("rank", rank().over(window.clone())) 226 | /// .with_column("min", min("age").over(window)); 227 | /// ``` 228 | pub fn rows_between(mut self, start: i64, end: i64) -> WindowSpec { 229 | self.spec = self.spec.rows_between(start, end); 230 | 231 | self.spec 232 | } 233 | } 234 | 235 | #[cfg(test)] 236 | mod tests { 237 | 238 | use arrow::{ 239 | array::{ArrayRef, Int32Array, Int64Array, StringArray}, 240 | datatypes::{DataType, Field, Schema}, 241 | record_batch::RecordBatch, 242 | }; 243 | 244 | use std::sync::Arc; 245 | 246 | use super::*; 247 | 248 | use crate::errors::SparkError; 249 | use crate::functions::*; 250 | use crate::SparkSession; 251 | use crate::SparkSessionBuilder; 252 | 253 | async fn setup() -> SparkSession { 254 | println!("SparkSession Setup"); 255 | 256 | let connection = "sc://127.0.0.1:15002/;user_id=rust_window"; 257 | 258 | SparkSessionBuilder::remote(connection) 259 | .build() 260 | .await 261 | .unwrap() 262 | } 263 | 264 | fn mock_data() -> RecordBatch { 265 | let id: ArrayRef = Arc::new(Int64Array::from(vec![1, 1, 2, 1, 2, 3])); 266 | let category: ArrayRef = Arc::new(StringArray::from(vec!["a", "a", "a", "b", "b", "b"])); 267 | 268 | RecordBatch::try_from_iter(vec![("id", id), ("category", category)]).unwrap() 269 | } 270 | 271 | #[tokio::test] 272 | async fn test_window_over() -> Result<(), SparkError> { 273 | let spark = setup().await; 274 | 275 | let name: ArrayRef = Arc::new(StringArray::from(vec!["Alice", "Bob"])); 276 | let age: ArrayRef = Arc::new(Int64Array::from(vec![2, 5])); 277 | 278 | let data = RecordBatch::try_from_iter(vec![("name", name), ("age", age)])?; 279 | 280 | let df = spark.create_dataframe(&data)?; 281 | 282 | let window = Window::new() 283 | .partition_by([col("name")]) 284 | .order_by([col("age")]) 285 | .rows_between(Window::unbounded_preceding(), Window::current_row()); 286 | 287 | let res = df 288 | .with_column("rank", rank().over(window.clone())) 289 | .with_column("min", min("age").over(window)) 290 | .collect() 291 | .await?; 292 | 293 | let name: ArrayRef = Arc::new(StringArray::from(vec!["Alice", "Bob"])); 294 | let age: ArrayRef = Arc::new(Int64Array::from(vec![2, 5])); 295 | let rank: ArrayRef = Arc::new(Int32Array::from(vec![1, 1])); 296 | let min = age.clone(); 297 | 298 | let schema = Schema::new(vec![ 299 | Field::new("name", DataType::Utf8, false), 300 | Field::new("age", DataType::Int64, false), 301 | Field::new("rank", DataType::Int32, false), 302 | Field::new("min", DataType::Int64, true), 303 | ]); 304 | 305 | let expected = RecordBatch::try_new(Arc::new(schema), vec![name, age, rank, min])?; 306 | 307 | assert_eq!(expected, res); 308 | 309 | Ok(()) 310 | } 311 | 312 | #[tokio::test] 313 | async fn test_window_orderby() -> Result<(), SparkError> { 314 | let spark = setup().await; 315 | 316 | let data = mock_data(); 317 | 318 | let df = spark.create_dataframe(&data)?; 319 | 320 | let window = Window::new() 321 | .partition_by([col("id")]) 322 | .order_by([col("category")]); 323 | 324 | let res = df 325 | .with_column("row_number", row_number().over(window)) 326 | .collect() 327 | .await?; 328 | 329 | let id: ArrayRef = Arc::new(Int64Array::from(vec![1, 1, 1, 2, 2, 3])); 330 | let category: ArrayRef = 
Arc::new(StringArray::from(vec!["a", "a", "b", "a", "b", "b"])); 331 | let row_number: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 1, 2, 1])); 332 | 333 | let expected = RecordBatch::try_from_iter(vec![ 334 | ("id", id), 335 | ("category", category), 336 | ("row_number", row_number), 337 | ])?; 338 | 339 | assert_eq!(expected, res); 340 | 341 | Ok(()) 342 | } 343 | 344 | #[tokio::test] 345 | async fn test_window_partitionby() -> Result<(), SparkError> { 346 | let spark = setup().await; 347 | 348 | let data = mock_data(); 349 | 350 | let df = spark.create_dataframe(&data)?; 351 | 352 | let window = Window::new() 353 | .partition_by([col("category")]) 354 | .order_by([col("id")]); 355 | 356 | let res = df 357 | .with_column("row_number", row_number().over(window)) 358 | .collect() 359 | .await?; 360 | 361 | let id: ArrayRef = Arc::new(Int64Array::from(vec![1, 1, 2, 1, 2, 3])); 362 | let category: ArrayRef = Arc::new(StringArray::from(vec!["a", "a", "a", "b", "b", "b"])); 363 | let row_number: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 1, 2, 3])); 364 | 365 | let expected = RecordBatch::try_from_iter(vec![ 366 | ("id", id), 367 | ("category", category), 368 | ("row_number", row_number), 369 | ])?; 370 | 371 | assert_eq!(expected, res); 372 | 373 | Ok(()) 374 | } 375 | 376 | #[tokio::test] 377 | async fn test_window_rangebetween() -> Result<(), SparkError> { 378 | let spark = setup().await; 379 | 380 | let data = mock_data(); 381 | 382 | let df = spark.create_dataframe(&data)?; 383 | 384 | let window = Window::new() 385 | .partition_by([col("category")]) 386 | .order_by([col("id")]) 387 | .range_between(Window::current_row(), 1); 388 | 389 | let res = df 390 | .with_column("sum", sum("id").over(window)) 391 | .sort([col("id"), col("category")]) 392 | .collect() 393 | .await?; 394 | 395 | let id: ArrayRef = Arc::new(Int64Array::from(vec![1, 1, 1, 2, 2, 3])); 396 | let category: ArrayRef = Arc::new(StringArray::from(vec!["a", "a", "b", "a", "b", "b"])); 397 | let sum: ArrayRef = Arc::new(Int64Array::from(vec![4, 4, 3, 2, 5, 3])); 398 | 399 | let expected = RecordBatch::try_from_iter_with_nullable(vec![ 400 | ("id", id, false), 401 | ("category", category, false), 402 | ("sum", sum, true), 403 | ])?; 404 | 405 | assert_eq!(expected, res); 406 | 407 | Ok(()) 408 | } 409 | 410 | #[tokio::test] 411 | async fn test_window_rowsbetween() -> Result<(), SparkError> { 412 | let spark = setup().await; 413 | 414 | let data = mock_data(); 415 | 416 | let df = spark.create_dataframe(&data)?; 417 | 418 | let window = Window::new() 419 | .partition_by([col("category")]) 420 | .order_by([col("id")]) 421 | .rows_between(Window::current_row(), 1); 422 | 423 | let res = df 424 | .with_column("sum", sum("id").over(window)) 425 | .sort([col("id"), col("category"), col("sum")]) 426 | .collect() 427 | .await?; 428 | 429 | let id: ArrayRef = Arc::new(Int64Array::from(vec![1, 1, 1, 2, 2, 3])); 430 | let category: ArrayRef = Arc::new(StringArray::from(vec!["a", "a", "b", "a", "b", "b"])); 431 | let sum: ArrayRef = Arc::new(Int64Array::from(vec![2, 3, 3, 2, 5, 3])); 432 | 433 | let expected = RecordBatch::try_from_iter_with_nullable(vec![ 434 | ("id", id, false), 435 | ("category", category, false), 436 | ("sum", sum, true), 437 | ])?; 438 | 439 | assert_eq!(expected, res); 440 | 441 | Ok(()) 442 | } 443 | } 444 | -------------------------------------------------------------------------------- /crates/connect/src/session.rs: -------------------------------------------------------------------------------- 1 | 
// Licensed to the Apache Software Foundation (ASF) under one
2 | // or more contributor license agreements. See the NOTICE file
3 | // distributed with this work for additional information
4 | // regarding copyright ownership. The ASF licenses this file
5 | // to you under the Apache License, Version 2.0 (the
6 | // "License"); you may not use this file except in compliance
7 | // with the License. You may obtain a copy of the License at
8 | //
9 | // http://www.apache.org/licenses/LICENSE-2.0
10 | //
11 | // Unless required by applicable law or agreed to in writing,
12 | // software distributed under the License is distributed on an
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 | // KIND, either express or implied. See the License for the
15 | // specific language governing permissions and limitations
16 | // under the License.
17 | 
18 | //! Spark Session containing the remote gRPC client
19 | 
20 | use std::collections::HashMap;
21 | use std::sync::Arc;
22 | 
23 | use crate::client::{ChannelBuilder, Config, HeadersLayer, SparkClient, SparkConnectClient};
24 | 
25 | use crate::catalog::Catalog;
26 | use crate::conf::RunTimeConfig;
27 | use crate::dataframe::{DataFrame, DataFrameReader};
28 | use crate::errors::SparkError;
29 | use crate::plan::LogicalPlanBuilder;
30 | use crate::streaming::{DataStreamReader, StreamingQueryManager};
31 | 
32 | use crate::spark;
33 | use spark::spark_connect_service_client::SparkConnectServiceClient;
34 | 
35 | use arrow::record_batch::RecordBatch;
36 | 
37 | use tokio::sync::RwLock;
38 | 
39 | use tower::ServiceBuilder;
40 | 
41 | use tonic::transport::Channel;
42 | 
43 | /// SparkSessionBuilder creates a remote Spark Session from a connection string.
44 | ///
45 | /// The connection string is defined based on the requirements from the [Spark Documentation](https://github.com/apache/spark/blob/master/connector/connect/docs/client-connection-string.md)
46 | #[derive(Clone, Debug)]
47 | pub struct SparkSessionBuilder {
48 |     pub channel_builder: ChannelBuilder,
49 |     configs: HashMap<String, String>,
50 | }
51 | 
52 | /// The default connects to a Spark cluster running at `sc://127.0.0.1:15002/`
53 | impl Default for SparkSessionBuilder {
54 |     fn default() -> Self {
55 |         let channel_builder = ChannelBuilder::default();
56 | 
57 |         Self {
58 |             channel_builder,
59 |             configs: HashMap::new(),
60 |         }
61 |     }
62 | }
63 | 
64 | impl SparkSessionBuilder {
65 |     fn new(connection: &str) -> Self {
66 |         let channel_builder = ChannelBuilder::create(connection).unwrap();
67 | 
68 |         Self {
69 |             channel_builder,
70 |             configs: HashMap::new(),
71 |         }
72 |     }
73 | 
74 |     /// Create a new Spark Session from a [Config] object
75 |     pub fn from_config(config: Config) -> Self {
76 |         Self {
77 |             channel_builder: config.into(),
78 |             configs: HashMap::new(),
79 |         }
80 |     }
81 | 
82 |     /// Validate a connection string for a remote Spark Session
83 |     ///
84 |     /// String must conform to the [Spark Documentation](https://github.com/apache/spark/blob/master/connector/connect/docs/client-connection-string.md)
85 |     pub fn remote(connection: &str) -> Self {
86 |         Self::new(connection)
87 |     }
88 | 
89 |     /// Sets a config option.
90 |     pub fn config(mut self, key: &str, value: &str) -> Self {
91 |         self.configs.insert(key.into(), value.into());
92 |         self
93 |     }
94 | 
95 |     /// Sets a name for the application, which will be shown in the Spark web UI.
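    ///
    /// # Example
    ///
    /// Illustrative sketch only; it assumes a Spark Connect server is reachable at
    /// `sc://localhost:15002` and the config key shown is an arbitrary choice.
    ///
    /// ```rust
    /// use spark_connect_rs::{SparkSession, SparkSessionBuilder};
    ///
    /// // build a session with an application name and one extra config entry
    /// let spark: SparkSession = SparkSessionBuilder::remote("sc://localhost:15002/")
    ///     .app_name("my-rust-app")
    ///     .config("spark.sql.shuffle.partitions", "8")
    ///     .build()
    ///     .await?;
    /// ```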
96 |     pub fn app_name(mut self, name: &str) -> Self {
97 |         self.configs
98 |             .insert("spark.app.name".to_string(), name.into());
99 |         self
100 |     }
101 | 
102 |     async fn create_client(&self) -> Result<SparkSession, SparkError> {
103 |         let channel = Channel::from_shared(self.channel_builder.endpoint())?
104 |             .connect()
105 |             .await?;
106 | 
107 |         let channel = ServiceBuilder::new()
108 |             .layer(HeadersLayer::new(
109 |                 self.channel_builder.headers().unwrap_or_default(),
110 |             ))
111 |             .service(channel);
112 | 
113 |         let client = SparkConnectServiceClient::new(channel);
114 | 
115 |         let spark_connect_client =
116 |             SparkConnectClient::new(Arc::new(RwLock::new(client)), self.channel_builder.clone());
117 | 
118 |         let mut rt_config = RunTimeConfig::new(&spark_connect_client);
119 | 
120 |         rt_config.set_configs(&self.configs).await?;
121 | 
122 |         Ok(SparkSession::new(spark_connect_client))
123 |     }
124 | 
125 |     /// Attempt to connect to a remote Spark Session
126 |     ///
127 |     /// and return a [SparkSession]
128 |     pub async fn build(&self) -> Result<SparkSession, SparkError> {
129 |         self.create_client().await
130 |     }
131 | }
132 | 
133 | /// The entry point to connecting to a Spark Cluster
134 | /// using the Spark Connect gRPC protocol.
135 | #[derive(Clone, Debug)]
136 | pub struct SparkSession {
137 |     client: SparkClient,
138 |     session_id: String,
139 | }
140 | 
141 | impl SparkSession {
142 |     pub fn new(client: SparkClient) -> Self {
143 |         Self {
144 |             session_id: client.session_id(),
145 |             client,
146 |         }
147 |     }
148 | 
149 |     pub fn session(&self) -> SparkSession {
150 |         self.clone()
151 |     }
152 | 
153 |     /// Create a [DataFrame] with a single column named `id`,
154 |     /// containing elements in a range from `start` (default 0) to
155 |     /// `end` (exclusive) with a step value `step`, and control the number
156 |     /// of partitions with `num_partitions`
157 |     pub fn range(
158 |         &self,
159 |         start: Option<i64>,
160 |         end: i64,
161 |         step: i64,
162 |         num_partitions: Option<i32>,
163 |     ) -> DataFrame {
164 |         let range_relation = spark::relation::RelType::Range(spark::Range {
165 |             start,
166 |             end,
167 |             step,
168 |             num_partitions,
169 |         });
170 | 
171 |         DataFrame::new(self.session(), LogicalPlanBuilder::from(range_relation))
172 |     }
173 | 
174 |     /// Returns a [DataFrameReader] that can be used to read data in as a [DataFrame]
175 |     pub fn read(&self) -> DataFrameReader {
176 |         DataFrameReader::new(self.session())
177 |     }
178 | 
179 |     /// Returns a [DataStreamReader] that can be used to read streaming data in as a [DataFrame]
180 |     pub fn read_stream(&self) -> DataStreamReader {
181 |         DataStreamReader::new(self.session())
182 |     }
183 | 
184 |     pub fn table(&self, name: &str) -> Result<DataFrame, SparkError> {
185 |         DataFrameReader::new(self.session()).table(name, None)
186 |     }
187 | 
188 |     /// Interface through which the user may create, drop, alter or query underlying databases,
189 |     /// tables, functions, etc.
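    ///
    /// Illustrative sketch only; it just shows obtaining the catalog handle from an
    /// existing `spark` session (the available [Catalog] methods live in the `catalog` module):
    ///
    /// ```rust
    /// let catalog = spark.catalog();
    /// // e.g. (hypothetical call, mirroring the PySpark catalog API):
    /// // let dbs = catalog.list_databases(None).await?;
    /// ```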
190 |     pub fn catalog(&self) -> Catalog {
191 |         Catalog::new(self.session())
192 |     }
193 | 
194 |     /// Returns a [DataFrame] representing the result of the given query
195 |     pub async fn sql(&self, sql_query: &str) -> Result<DataFrame, SparkError> {
196 |         let sql_cmd = spark::command::CommandType::SqlCommand(spark::SqlCommand {
197 |             sql: sql_query.to_string(),
198 |             args: HashMap::default(),
199 |             pos_args: vec![],
200 |         });
201 | 
202 |         let plan = LogicalPlanBuilder::plan_cmd(sql_cmd);
203 | 
204 |         let resp = self
205 |             .clone()
206 |             .client()
207 |             .execute_command_and_fetch(plan)
208 |             .await?;
209 | 
210 |         let relation = resp.sql_command_result.to_owned().unwrap().relation;
211 | 
212 |         let logical_plan = LogicalPlanBuilder::new(relation.unwrap());
213 | 
214 |         Ok(DataFrame::new(self.session(), logical_plan))
215 |     }
216 | 
217 |     pub fn create_dataframe(&self, data: &RecordBatch) -> Result<DataFrame, SparkError> {
218 |         let logical_plan = LogicalPlanBuilder::local_relation(data)?;
219 | 
220 |         Ok(DataFrame::new(self.session(), logical_plan))
221 |     }
222 | 
223 |     /// Return the session ID
224 |     pub fn session_id(&self) -> &str {
225 |         &self.session_id
226 |     }
227 | 
228 |     /// Spark Connect gRPC client interface
229 |     pub fn client(self) -> SparkClient {
230 |         self.client
231 |     }
232 | 
233 |     /// Interrupt all operations of this session currently running on the connected server.
234 |     pub async fn interrupt_all(&self) -> Result<Vec<String>, SparkError> {
235 |         let resp = self
236 |             .client
237 |             .interrupt_request(spark::interrupt_request::InterruptType::All, None)
238 |             .await?;
239 | 
240 |         Ok(resp.interrupted_ids)
241 |     }
242 | 
243 |     /// Interrupt all operations of this session with the given operation tag.
244 |     pub async fn interrupt_tag(&self, tag: &str) -> Result<Vec<String>, SparkError> {
245 |         let resp = self
246 |             .client
247 |             .interrupt_request(
248 |                 spark::interrupt_request::InterruptType::Tag,
249 |                 Some(tag.to_string()),
250 |             )
251 |             .await?;
252 | 
253 |         Ok(resp.interrupted_ids)
254 |     }
255 | 
256 |     /// Interrupt an operation of this session with the given operationId.
257 |     pub async fn interrupt_operation(&self, op_id: &str) -> Result<Vec<String>, SparkError> {
258 |         let resp = self
259 |             .client
260 |             .interrupt_request(
261 |                 spark::interrupt_request::InterruptType::OperationId,
262 |                 Some(op_id.to_string()),
263 |             )
264 |             .await?;
265 | 
266 |         Ok(resp.interrupted_ids)
267 |     }
268 | 
269 |     /// Add a tag to be assigned to all the operations started by this thread in this session.
270 |     pub fn add_tag(&mut self, tag: &str) -> Result<(), SparkError> {
271 |         self.client.add_tag(tag)
272 |     }
273 | 
274 |     /// Remove a tag previously added to be assigned to all the operations started by this thread in this session.
275 |     pub fn remove_tag(&mut self, tag: &str) -> Result<(), SparkError> {
276 |         self.client.remove_tag(tag)
277 |     }
278 | 
279 |     /// Get the tags that are currently set to be assigned to all the operations started by this thread.
280 |     pub fn get_tags(&mut self) -> &Vec<String> {
281 |         self.client.get_tags()
282 |     }
283 | 
284 |     /// Clear the current thread’s operation tags.
285 |     pub fn clear_tags(&mut self) {
286 |         self.client.clear_tags()
287 |     }
288 | 
289 |     /// The version of Spark on which this application is running.
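    ///
    /// # Example
    ///
    /// Illustrative sketch only; the exact version string depends on the server
    /// this session is connected to (e.g. `"3.5.1"`).
    ///
    /// ```rust
    /// let version: String = spark.version().await?;
    /// println!("connected to Spark {version}");
    /// ```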
290 |     pub async fn version(&self) -> Result<String, SparkError> {
291 |         let version = spark::analyze_plan_request::Analyze::SparkVersion(
292 |             spark::analyze_plan_request::SparkVersion {},
293 |         );
294 | 
295 |         let mut client = self.client.clone();
296 | 
297 |         client.analyze(version).await?.spark_version()
298 |     }
299 | 
300 |     /// [RunTimeConfig] configuration interface for Spark.
301 |     pub fn conf(&self) -> RunTimeConfig {
302 |         RunTimeConfig::new(&self.client)
303 |     }
304 | 
305 |     /// Returns a [StreamingQueryManager] that allows managing all the StreamingQuery instances active on this context.
306 |     pub fn streams(&self) -> StreamingQueryManager {
307 |         StreamingQueryManager::new(self)
308 |     }
309 | }
310 | 
311 | #[cfg(test)]
312 | mod tests {
313 |     use super::*;
314 | 
315 |     use arrow::{
316 |         array::{ArrayRef, StringArray},
317 |         record_batch::RecordBatch,
318 |     };
319 | 
320 |     use regex::Regex;
321 | 
322 |     async fn setup() -> SparkSession {
323 |         println!("SparkSession Setup");
324 | 
325 |         let connection = "sc://127.0.0.1:15002/;user_id=rust_test;session_id=0d2af2a9-cc3c-4d4b-bf27-e2fefeaca233";
326 | 
327 |         SparkSessionBuilder::remote(connection)
328 |             .build()
329 |             .await
330 |             .unwrap()
331 |     }
332 | 
333 |     #[tokio::test]
334 |     async fn test_spark_range() -> Result<(), SparkError> {
335 |         let spark = setup().await;
336 | 
337 |         let df = spark.range(None, 100, 1, Some(8));
338 | 
339 |         let records = df.collect().await?;
340 | 
341 |         assert_eq!(records.num_rows(), 100);
342 |         Ok(())
343 |     }
344 | 
345 |     #[tokio::test]
346 |     async fn test_spark_create_dataframe() -> Result<(), SparkError> {
347 |         let spark = setup().await;
348 | 
349 |         let a: ArrayRef = Arc::new(StringArray::from(vec!["hello", "world"]));
350 | 
351 |         let record_batch = RecordBatch::try_from_iter(vec![("a", a)])?;
352 | 
353 |         let df = spark.create_dataframe(&record_batch)?;
354 | 
355 |         let rows = df.collect().await?;
356 | 
357 |         assert_eq!(record_batch, rows);
358 |         Ok(())
359 |     }
360 | 
361 |     #[tokio::test]
362 |     async fn test_spark_session_create() {
363 |         let connection =
364 |             "sc://localhost:15002/;token=ABCDEFG;user_agent=some_agent;user_id=user123";
365 | 
366 |         let spark = SparkSessionBuilder::remote(connection).build().await;
367 | 
368 |         assert!(spark.is_ok());
369 |     }
370 | 
371 |     #[tokio::test]
372 |     async fn test_session_tags() -> Result<(), SparkError> {
373 |         let mut spark = SparkSessionBuilder::default().build().await?;
374 | 
375 |         spark.add_tag("hello-tag")?;
376 | 
377 |         spark.add_tag("hello-tag-2")?;
378 | 
379 |         let expected = vec!["hello-tag".to_string(), "hello-tag-2".to_string()];
380 | 
381 |         let res = spark.get_tags();
382 | 
383 |         assert_eq!(&expected, res);
384 | 
385 |         spark.clear_tags();
386 |         let res = spark.get_tags();
387 | 
388 |         let expected: Vec<String> = vec![];
389 | 
390 |         assert_eq!(&expected, res);
391 | 
392 |         Ok(())
393 |     }
394 | 
395 |     #[tokio::test]
396 |     async fn test_session_tags_panic() -> Result<(), SparkError> {
397 |         let mut spark = SparkSessionBuilder::default().build().await?;
398 | 
399 |         assert!(spark.add_tag("bad,tag").is_err());
400 |         assert!(spark.add_tag("").is_err());
401 | 
402 |         assert!(spark.remove_tag("bad,tag").is_err());
403 |         assert!(spark.remove_tag("").is_err());
404 | 
405 |         Ok(())
406 |     }
407 | 
408 |     #[tokio::test]
409 |     async fn test_session_version() -> Result<(), SparkError> {
410 |         let spark = SparkSessionBuilder::default().build().await?;
411 | 
412 |         let version = spark.version().await?;
413 | 
414 |         let version_pattern = Regex::new(r"^\d+\.\d+\.\d+$").unwrap();
415 |         assert!(
416 |             version_pattern.is_match(&version),
417 |             "Version {} does not match X.X.X format",
418 |             version
419 |         );
420 | 
421 |         Ok(())
422 |     }
423 | 
424 |     #[tokio::test]
425 |     async fn test_session_config() -> Result<(), SparkError> {
426 |         let value = "rust-test-app";
427 | 
428 |         let spark = SparkSessionBuilder::default()
429 |             .app_name("rust-test-app")
430 |             .build()
431 |             .await?;
432 | 
433 |         let name = spark.conf().get("spark.app.name", None).await?;
434 | 
435 |         assert_eq!(value, &name);
436 | 
437 |         // validate set
438 |         spark
439 |             .conf()
440 |             .set("spark.sql.shuffle.partitions", "42")
441 |             .await?;
442 | 
443 |         // validate get
444 |         let val = spark
445 |             .conf()
446 |             .get("spark.sql.shuffle.partitions", None)
447 |             .await?;
448 | 
449 |         assert_eq!("42", &val);
450 | 
451 |         // validate unset
452 |         spark.conf().unset("spark.sql.shuffle.partitions").await?;
453 | 
454 |         let val = spark
455 |             .conf()
456 |             .get("spark.sql.shuffle.partitions", None)
457 |             .await?;
458 | 
459 |         assert_eq!("200", &val);
460 | 
461 |         // not a modifiable setting
462 |         let val = spark
463 |             .conf()
464 |             .is_modifable("spark.executor.instances")
465 |             .await?;
466 |         assert!(!val);
467 | 
468 |         // a modifiable setting
469 |         let val = spark
470 |             .conf()
471 |             .is_modifable("spark.sql.shuffle.partitions")
472 |             .await?;
473 |         assert!(val);
474 | 
475 |         Ok(())
476 |     }
477 | }
478 | 
--------------------------------------------------------------------------------
/crates/connect/src/column.rs:
--------------------------------------------------------------------------------
1 | // Licensed to the Apache Software Foundation (ASF) under one
2 | // or more contributor license agreements. See the NOTICE file
3 | // distributed with this work for additional information
4 | // regarding copyright ownership. The ASF licenses this file
5 | // to you under the Apache License, Version 2.0 (the
6 | // "License"); you may not use this file except in compliance
7 | // with the License. You may obtain a copy of the License at
8 | //
9 | // http://www.apache.org/licenses/LICENSE-2.0
10 | //
11 | // Unless required by applicable law or agreed to in writing,
12 | // software distributed under the License is distributed on an
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 | // KIND, either express or implied. See the License for the
15 | // specific language governing permissions and limitations
16 | // under the License.
17 | 
18 | //! [Column] represents a column in a DataFrame that holds a [spark::Expression]
19 | use std::convert::From;
20 | use std::ops::{Add, BitAnd, BitOr, BitXor, Div, Mul, Neg, Not, Rem, Sub};
21 | 
22 | use crate::spark;
23 | 
24 | use crate::functions::invoke_func;
25 | use crate::window::WindowSpec;
26 | 
27 | use spark::expression::cast::CastToType;
28 | 
29 | /// # A column in a DataFrame.
30 | ///
31 | /// A column holds a specific [spark::Expression] which will be resolved once an action is called.
32 | /// The columns are resolved by the Spark Connect server of the remote session.
33 | ///
34 | /// A column instance can be created in a similar way as in the Spark API. A column created
35 | /// with `col("*")` or `col("name.*")` is created as an unresolved star attribute which will select
36 | /// all columns or references in the specified column.
37 | ///
38 | /// ```rust
39 | /// use spark_connect_rs::{SparkSession, SparkSessionBuilder};
40 | ///
41 | /// let spark: SparkSession = SparkSessionBuilder::remote("sc://127.0.0.1:15002/;user_id=example_rs")
42 | ///     .build()
43 | ///     .await?;
44 | ///
45 | /// // As a &str representing an unresolved column in the dataframe
46 | /// spark.range(None, 1, 1, Some(1)).select(["id"]);
47 | ///
48 | /// // By using the `col` function
49 | /// spark.range(None, 1, 1, Some(1)).select([col("id")]);
50 | ///
51 | /// // By using the `lit` function to return a literal value
52 | /// spark.range(None, 1, 1, Some(1)).select([lit(4.0).alias("num_col")]);
53 | /// ```
54 | #[derive(Clone, Debug)]
55 | pub struct Column {
56 |     /// a [spark::Expression] containing any unresolved value to be leveraged in a [spark::Plan]
57 |     pub expression: spark::Expression,
58 | }
59 | 
60 | impl Column {
61 |     #[allow(clippy::should_implement_trait)]
62 |     pub fn from_str(s: &str) -> Self {
63 |         Self::from(s)
64 |     }
65 | 
66 |     pub fn from_string(s: String) -> Self {
67 |         Self::from(s.as_str())
68 |     }
69 | 
70 |     /// Returns the column with a new name
71 |     ///
72 |     /// # Example:
73 |     /// ```rust
74 |     /// let cols = [
75 |     ///     col("name").alias("new_name"),
76 |     ///     col("age").alias("new_age")
77 |     /// ];
78 |     ///
79 |     /// df.select(cols);
80 |     /// ```
81 |     pub fn alias(self, value: &str) -> Column {
82 |         let alias = spark::expression::Alias {
83 |             expr: Some(Box::new(self.expression)),
84 |             name: vec![value.to_string()],
85 |             metadata: None,
86 |         };
87 | 
88 |         let expression = spark::Expression {
89 |             expr_type: Some(spark::expression::ExprType::Alias(Box::new(alias))),
90 |         };
91 | 
92 |         Column::from(expression)
93 |     }
94 | 
95 |     /// An alias for the function `alias`
96 |     pub fn name(self, value: &str) -> Column {
97 |         self.alias(value)
98 |     }
99 | 
100 |     /// Returns a sorted expression based on the ascending order of the column
101 |     ///
102 |     /// # Example:
103 |     /// ```rust
104 |     /// let df: DataFrame = df.sort([col("id").asc()]);
105 |     ///
106 |     /// let df: DataFrame = df.sort([asc(col("id"))]);
107 |     /// ```
108 |     pub fn asc(self) -> Column {
109 |         self.asc_nulls_first()
110 |     }
111 | 
112 |     pub fn asc_nulls_first(self) -> Column {
113 |         let asc = spark::expression::SortOrder {
114 |             child: Some(Box::new(self.expression)),
115 |             direction: 1,
116 |             null_ordering: 1,
117 |         };
118 | 
119 |         let expression = spark::Expression {
120 |             expr_type: Some(spark::expression::ExprType::SortOrder(Box::new(asc))),
121 |         };
122 | 
123 |         Column::from(expression)
124 |     }
125 | 
126 |     pub fn asc_nulls_last(self) -> Column {
127 |         let asc = spark::expression::SortOrder {
128 |             child: Some(Box::new(self.expression)),
129 |             direction: 1,
130 |             null_ordering: 2,
131 |         };
132 | 
133 |         let expression = spark::Expression {
134 |             expr_type: Some(spark::expression::ExprType::SortOrder(Box::new(asc))),
135 |         };
136 | 
137 |         Column::from(expression)
138 |     }
139 | 
140 |     /// Returns a sorted expression based on the descending order of the column
141 |     ///
142 |     /// # Example:
143 |     /// ```rust
144 |     /// let df: DataFrame = df.sort(col("id").desc());
145 |     ///
146 |     /// let df: DataFrame = df.sort(desc(col("id")));
147 |     /// ```
148 |     pub fn desc(self) -> Column {
149 |         self.desc_nulls_first()
150 |     }
151 | 
152 |     pub fn desc_nulls_first(self) -> Column {
153 |         let asc = spark::expression::SortOrder {
154 |             child: Some(Box::new(self.expression)),
155 |             direction: 2,
156 |             null_ordering: 1,
157 |         };
158 | 
159 |         let expression = spark::Expression {
160 |             expr_type: Some(spark::expression::ExprType::SortOrder(Box::new(asc))),
161 |         };
162 | 
163 |         Column::from(expression)
164 |     }
165 | 
166 |     pub fn desc_nulls_last(self) -> Column {
167 |         let asc = spark::expression::SortOrder {
168 |             child: Some(Box::new(self.expression)),
169 |             direction: 2,
170 |             null_ordering: 2,
171 |         };
172 | 
173 |         let expression = spark::Expression {
174 |             expr_type: Some(spark::expression::ExprType::SortOrder(Box::new(asc))),
175 |         };
176 | 
177 |         Column::from(expression)
178 |     }
179 | 
180 |     pub fn drop_fields<I>(self, field_names: I) -> Column
181 |     where
182 |         I: IntoIterator, I::Item: AsRef<str>,
183 |     {
184 |         let mut parent_col = self.expression;
185 | 
186 |         for field in field_names {
187 |             parent_col = spark::Expression {
188 |                 expr_type: Some(spark::expression::ExprType::UpdateFields(Box::new(
189 |                     spark::expression::UpdateFields {
190 |                         struct_expression: Some(Box::new(parent_col)),
191 |                         field_name: field.as_ref().to_string(),
192 |                         value_expression: None,
193 |                     },
194 |                 ))),
195 |             };
196 |         }
197 | 
198 |         Column::from(parent_col)
199 |     }
200 | 
201 |     pub fn with_field(self, field_name: &str, col: impl Into<Column>) -> Column {
202 |         let update_field = spark::Expression {
203 |             expr_type: Some(spark::expression::ExprType::UpdateFields(Box::new(
204 |                 spark::expression::UpdateFields {
205 |                     struct_expression: Some(Box::new(self.expression)),
206 |                     field_name: field_name.to_string(),
207 |                     value_expression: Some(Box::new(col.into().expression)),
208 |                 },
209 |             ))),
210 |         };
211 | 
212 |         Column::from(update_field)
213 |     }
214 | 
215 |     pub fn substr(self, start_pos: impl Into<Column>, length: impl Into<Column>) -> Column {
216 |         invoke_func("substr", vec![self, start_pos.into(), length.into()])
217 |     }
218 | 
219 |     /// Casts the column into the Spark DataType
220 |     ///
221 |     /// # Arguments:
222 |     ///
223 |     /// * `to_type` is a string or [crate::types::DataType] of the target type
224 |     ///
225 |     /// # Example:
226 |     /// ```rust
227 |     /// use crate::types::DataType;
228 |     ///
229 |     /// let df = df.select([
230 |     ///     col("age").cast("int"),
231 |     ///     col("name").cast("string")
232 |     /// ]);
233 |     ///
234 |     /// // Using DataTypes
235 |     /// let df = df.select([
236 |     ///     col("age").cast(DataType::Integer),
237 |     ///     col("name").cast(DataType::String)
238 |     /// ]);
239 |     /// ```
240 |     pub fn cast(self, to_type: impl Into<CastToType>) -> Column {
241 |         let cast = spark::expression::Cast {
242 |             expr: Some(Box::new(self.expression)),
243 |             cast_to_type: Some(to_type.into()),
244 |         };
245 | 
246 |         let expression = spark::Expression {
247 |             expr_type: Some(spark::expression::ExprType::Cast(Box::new(cast))),
248 |         };
249 | 
250 |         Column::from(expression)
251 |     }
252 | 
253 |     /// A boolean expression that is evaluated to `true` if the value of the expression is
254 |     /// contained by the evaluated values of the arguments
255 |     ///
256 |     /// # Arguments:
257 |     ///
258 |     /// * `cols`: a vector of Columns
259 |     ///
260 |     /// # Example:
261 |     /// ```rust
262 |     /// df.filter(col("name").isin(vec![lit("Jorge"), lit("Bob")]));
263 |     /// ```
264 |     pub fn isin(self, cols: Vec<Column>) -> Column {
265 |         let mut val = cols.clone();
266 | 
267 |         val.insert(0, self);
268 | 
269 |         invoke_func("in", val)
270 |     }
271 | 
272 |     /// A boolean expression that is evaluated to `true` if the given value is contained in the Column
273 |     ///
274 |     /// # Arguments:
275 |     ///
276 |     /// * `other`: a value that is translated into a [spark::Expression]
277 |     ///
278 |     /// # Example:
279 |     /// ```rust
280 |     /// df.filter(col("name").contains("ge"));
281 |     /// ```
282 |     pub fn contains(self, other: impl Into<Column>) -> Column {
283 |         invoke_func("contains", vec![self, other.into()])
284 |     }
285 | 
286 |     /// A filter expression that evaluates if the column starts with a string literal
287 |     pub fn startswith(self, other: impl Into<Column>) -> Column {
288 |         invoke_func("startswith", vec![self, other.into()])
289 |     }
290 | 
291 |     /// A filter expression that evaluates if the column ends with a string literal
292 |     pub fn endswith(self, other: impl Into<Column>) -> Column {
293 |         invoke_func("endswith", vec![self, other.into()])
294 |     }
295 | 
296 |     /// A SQL LIKE filter expression that evaluates the column based on a case sensitive match
297 |     pub fn like(self, other: impl Into<Column>) -> Column {
298 |         invoke_func("like", vec![self, other.into()])
299 |     }
300 | 
301 |     /// A SQL ILIKE filter expression that evaluates the column based on a case insensitive match
302 |     pub fn ilike(self, other: impl Into<Column>) -> Column {
303 |         invoke_func("ilike", vec![self, other.into()])
304 |     }
305 | 
306 |     /// A SQL RLIKE filter expression that evaluates the column based on a regex match
307 |     pub fn rlike(self, other: impl Into<Column>) -> Column {
308 |         invoke_func("rlike", vec![self, other.into()])
309 |     }
310 | 
311 |     /// Equality comparison. Cannot overload the '==' operator and return something other
312 |     /// than a bool
313 |     pub fn eq(self, other: impl Into<Column>) -> Column {
314 |         invoke_func("==", vec![self, other.into()])
315 |     }
316 | 
317 |     /// Logical AND comparison. Cannot overload the '&&' operator and return something other
318 |     /// than a bool
319 |     pub fn and(self, other: impl Into<Column>) -> Column {
320 |         invoke_func("and", vec![self, other.into()])
321 |     }
322 | 
323 |     /// Logical OR comparison.
324 |     pub fn or(self, other: impl Into<Column>) -> Column {
325 |         invoke_func("or", vec![self, other.into()])
326 |     }
327 | 
328 |     /// A filter expression that evaluates to true if the expression is null
329 |     pub fn is_null(self) -> Column {
330 |         invoke_func("isnull", vec![self])
331 |     }
332 | 
333 |     /// A filter expression that evaluates to true if the expression is NOT null
334 |     pub fn is_not_null(self) -> Column {
335 |         invoke_func("isnotnull", vec![self])
336 |     }
337 | 
338 |     pub fn is_nan(self) -> Column {
339 |         invoke_func("isNaN", vec![self])
340 |     }
341 | 
342 |     /// Defines a windowing column
343 |     /// # Arguments:
344 |     ///
345 |     /// * `window`: a [WindowSpec]
346 |     ///
347 |     /// # Example
348 |     ///
349 |     /// ```
350 |     /// let window = Window::new()
351 |     ///     .partition_by([col("name")])
352 |     ///     .order_by([col("age")])
353 |     ///     .range_between(Window::unbounded_preceding(), Window::current_row());
354 |     ///
355 |     /// let df = df.with_column("rank", rank().over(window.clone()))
356 |     ///     .with_column("min", min("age").over(window));
357 |     /// ```
358 |     pub fn over(self, window: WindowSpec) -> Column {
359 |         let window_expr = spark::expression::Window {
360 |             window_function: Some(Box::new(self.expression)),
361 |             partition_spec: window.partition_spec,
362 |             order_spec: window.order_spec,
363 |             frame_spec: window.frame_spec,
364 |         };
365 | 
366 |         let expression = spark::Expression {
367 |             expr_type: Some(spark::expression::ExprType::Window(Box::new(window_expr))),
368 |         };
369 | 
370 |         Column::from(expression)
371 |     }
372 | }
373 | 
374 | impl From<spark::Expression> for Column {
375 |     /// Used for creating columns from a [spark::Expression]
376 |     fn from(expression: spark::Expression) -> Self {
377 |         Self { expression }
378 |     }
379 | }
380 | 
381 | impl From<spark::expression::Literal> for Column {
382 |     /// Used for creating columns from a [spark::expression::Literal]
383 |     fn from(expression: spark::expression::Literal) -> Self {
384 |         Self::from(spark::Expression {
385 |             expr_type: Some(spark::expression::ExprType::Literal(expression)),
386 |         })
387 |     }
388 | }
389 | 
390 | impl From<String> for Column {
391 |     fn from(value: String) -> Self {
392 |         Column::from_string(value)
393 |     }
394 | }
395 | 
396 | impl From<&String> for Column {
397 |     fn from(value: &String) -> Self {
398 |         Column::from_str(value.as_str())
399 |     }
400 | }
401 | 
402 | impl From<&str> for Column {
403 |     /// `&str` values containing a `*` will be created as an unresolved star expression
404 |     /// Otherwise, the value is created as an unresolved attribute
405 |     fn from(value: &str) -> Self {
406 |         let expression = match value {
407 |             "*" => spark::Expression {
408 |                 expr_type: Some(spark::expression::ExprType::UnresolvedStar(
409 |                     spark::expression::UnresolvedStar {
410 |                         unparsed_target: None,
411 |                     },
412 |                 )),
413 |             },
414 |             value if value.ends_with(".*") => spark::Expression {
415 |                 expr_type: Some(spark::expression::ExprType::UnresolvedStar(
416 |                     spark::expression::UnresolvedStar {
417 |                         unparsed_target: Some(value.to_string()),
418 |                     },
419 |                 )),
420 |             },
421 |             _ => spark::Expression {
422 |                 expr_type: Some(spark::expression::ExprType::UnresolvedAttribute(
423 |                     spark::expression::UnresolvedAttribute {
424 |                         unparsed_identifier: value.to_string(),
425 |                         plan_id: None,
426 |                     },
427 |                 )),
428 |             },
429 |         };
430 | 
431 |         Column::from(expression)
432 |     }
433 | }
434 | 
435 | impl Add for Column {
436 |     type Output = Self;
437 | 
438 |     fn add(self, other: Self) -> Self {
439 |         invoke_func("+", vec![self, other])
440 |     }
441 | }
442 | 
443 | impl Neg for Column {
444 |     type Output = Self;
445 | 
446 |     fn neg(self) -> Self {
447 |         invoke_func("negative", vec![self])
448 |     }
449 | }
450 | 
451 | impl Sub for Column {
452 |     type Output = Self;
453 | 
454 |     fn sub(self, other: Self) -> Self {
455 |         invoke_func("-", vec![self, other])
456 |     }
457 | }
458 | 
459 | impl Mul for Column {
460 |     type Output = Self;
461 | 
462 |     fn mul(self, other: Self) -> Self {
463 |         invoke_func("*", vec![self, other])
464 |     }
465 | }
466 | 
467 | impl Div for Column {
468 |     type Output = Self;
469 | 
470 |     fn div(self, other: Self) -> Self {
471 |         invoke_func("/", vec![self, other])
472 |     }
473 | }
474 | 
475 | impl Rem for Column {
476 |     type Output = Self;
477 | 
478 |     fn rem(self, other: Self) -> Self {
479 |         invoke_func("%", vec![self, other])
480 |     }
481 | }
482 | 
483 | impl BitOr for Column {
484 |     type Output = Self;
485 | 
486 |     fn bitor(self, other: Self) -> Self {
487 |         invoke_func("|", vec![self, other])
488 |     }
489 | }
490 | 
491 | impl BitAnd for Column {
492 |     type Output = Self;
493 | 
494 |     fn bitand(self, other: Self) -> Self {
495 |         invoke_func("&", vec![self, other])
496 |     }
497 | }
498 | 
499 | impl BitXor for Column {
500 |     type Output = Self;
501 | 
502 |     fn bitxor(self, other: Self) -> Self {
503 |         invoke_func("^", vec![self, other])
504 |     }
505 | }
506 | 
507 | impl Not for Column {
508 |     type Output = Self;
509 | 
510 |     fn not(self) -> Self::Output {
511 |         invoke_func("not", vec![self])
512 |     }
513 | }
514 | 
--------------------------------------------------------------------------------