├── .github └── workflows │ └── ci.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .rustfmt.toml ├── Cargo.toml ├── LICENSE ├── README.md ├── benches └── main.rs ├── src ├── common.rs ├── common_macros.rs ├── common_union.rs ├── json_as_text.rs ├── json_contains.rs ├── json_get.rs ├── json_get_bool.rs ├── json_get_float.rs ├── json_get_int.rs ├── json_get_json.rs ├── json_get_str.rs ├── json_length.rs ├── json_object_keys.rs ├── lib.rs └── rewrite.rs └── tests ├── main.rs └── utils └── mod.rs /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | tags: 8 | - '**' 9 | pull_request: {} 10 | 11 | jobs: 12 | lint: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v3 16 | 17 | - uses: dtolnay/rust-toolchain@stable 18 | with: 19 | components: rustfmt, clippy 20 | 21 | - id: cache-rust 22 | uses: Swatinem/rust-cache@v2 23 | 24 | - uses: pre-commit/action@v3.0.0 25 | with: 26 | extra_args: --all-files --verbose 27 | env: 28 | PRE_COMMIT_COLOR: always 29 | SKIP: test 30 | 31 | resolve: 32 | runs-on: ubuntu-latest 33 | outputs: 34 | MSRV: ${{ steps.resolve-msrv.outputs.MSRV }} 35 | steps: 36 | - uses: actions/checkout@v4 37 | 38 | - name: set up python 39 | uses: actions/setup-python@v5 40 | with: 41 | python-version: '3.12' 42 | 43 | - name: resolve MSRV 44 | id: resolve-msrv 45 | run: 46 | echo MSRV=`python -c 'import tomllib; print(tomllib.load(open("Cargo.toml", "rb"))["package"]["rust-version"])'` >> $GITHUB_OUTPUT 47 | 48 | test: 49 | needs: [resolve] 50 | name: test rust-${{ matrix.rust-version }} 51 | strategy: 52 | fail-fast: false 53 | matrix: 54 | rust-version: [stable, nightly] 55 | include: 56 | - rust-version: ${{ needs.resolve.outputs.MSRV }} 57 | 58 | runs-on: ubuntu-latest 59 | 60 | env: 61 | RUST_VERSION: ${{ matrix.rust-version }} 62 | 63 | steps: 64 | - uses: actions/checkout@v3 65 | 66 | - uses: dtolnay/rust-toolchain@master 67 | with: 68 | toolchain: ${{ matrix.rust-version }} 69 | 70 | - id: cache-rust 71 | uses: Swatinem/rust-cache@v2 72 | 73 | - uses: taiki-e/install-action@cargo-llvm-cov 74 | 75 | - run: cargo llvm-cov --all-features --codecov --output-path codecov.json 76 | 77 | - uses: codecov/codecov-action@v3 78 | with: 79 | token: ${{ secrets.CODECOV_TOKEN }} 80 | files: codecov.json 81 | env_vars: RUST_VERSION 82 | 83 | # https://github.com/marketplace/actions/alls-green#why used for branch protection checks 84 | check: 85 | if: always() 86 | needs: [test, lint] 87 | runs-on: ubuntu-latest 88 | steps: 89 | - name: Decide whether the needed jobs succeeded or failed 90 | uses: re-actors/alls-green@release/v1 91 | with: 92 | jobs: ${{ toJSON(needs) }} 93 | 94 | release: 95 | needs: [check] 96 | if: "success() && startsWith(github.ref, 'refs/tags/')" 97 | runs-on: ubuntu-latest 98 | environment: release 99 | 100 | steps: 101 | - uses: actions/checkout@v2 102 | 103 | - name: install rust stable 104 | uses: dtolnay/rust-toolchain@stable 105 | 106 | - uses: Swatinem/rust-cache@v2 107 | 108 | - run: cargo publish 109 | env: 110 | CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }} 111 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | Cargo.lock 3 | .idea 4 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | fail_fast: true 2 | 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v4.0.1 6 | hooks: 7 | - id: check-yaml 8 | - id: check-toml 9 | - id: end-of-file-fixer 10 | - id: trailing-whitespace 11 | - id: check-added-large-files 12 | 13 | - repo: local 14 | hooks: 15 | - id: format 16 | name: Format 17 | entry: cargo fmt 18 | types: [rust] 19 | language: system 20 | pass_filenames: false 21 | - id: clippy 22 | name: Clippy 23 | entry: cargo clippy -- -D warnings 24 | types: [rust] 25 | language: system 26 | pass_filenames: false 27 | - id: test 28 | name: Test 29 | entry: cargo test 30 | types: [rust] 31 | language: system 32 | pass_filenames: false 33 | -------------------------------------------------------------------------------- /.rustfmt.toml: -------------------------------------------------------------------------------- 1 | max_width = 120 2 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "datafusion-functions-json" 3 | version = "0.47.0" 4 | edition = "2021" 5 | description = "JSON functions for DataFusion" 6 | readme = "README.md" 7 | license = "Apache-2.0" 8 | keywords = ["datafusion", "JSON", "SQL"] 9 | categories = ["database-implementations", "parsing"] 10 | repository = "https://github.com/datafusion-contrib/datafusion-functions-json/" 11 | rust-version = "1.82.0" 12 | 13 | [dependencies] 14 | datafusion = { version = "47", default-features = false } 15 | jiter = "0.9" 16 | paste = "1" 17 | log = "0.4" 18 | 19 | [dev-dependencies] 20 | datafusion = { version = "47", default-features = false, features = ["nested_expressions"] } 21 | codspeed-criterion-compat = "2.6" 22 | criterion = "0.5.1" 23 | clap = "4" 24 | tokio = { version = "1.43", features = ["full"] } 25 | 26 | [lints.clippy] 27 | dbg_macro = "deny" 28 | print_stdout = "deny" 29 | 30 | # in general, we lint against the pedantic group, but we will whitelist 31 | # certain lints which we don't want to enforce (for now) 32 | pedantic = { level = "deny", priority = -1 } 33 | 34 | [[bench]] 35 | name = "main" 36 | harness = false 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # datafusion-functions-json 2 | 3 | [![CI](https://github.com/datafusion-contrib/datafusion-functions-json/actions/workflows/ci.yml/badge.svg?event=push)](https://github.com/datafusion-contrib/datafusion-functions-json/actions/workflows/ci.yml?query=branch%3Amain) 4 | [![Crates.io](https://img.shields.io/crates/v/datafusion-functions-json?color=green)](https://crates.io/crates/datafusion-functions-json) 5 | 6 | **Note:** This is not an official Apache Software Foundation release, see [datafusion-contrib/datafusion-functions-json#5](https://github.com/datafusion-contrib/datafusion-functions-json/issues/5). 7 | 8 | This crate provides a set of functions for querying JSON strings in DataFusion. The functions are implemented as scalar functions that can be used in SQL queries. 9 | 10 | To use these functions, you'll just need to call: 11 | 12 | ```rust 13 | datafusion_functions_json::register_all(&mut ctx)?; 14 | ``` 15 | To register the below JSON functions in your `SessionContext`. 16 | 17 | # Examples 18 | 19 | ```sql 20 | -- Create a table with a JSON column stored as a string 21 | CREATE TABLE test_table (id INT, json_col VARCHAR) AS VALUES 22 | (1, '{}'), 23 | (2, '{ "a": 1 }'), 24 | (3, '{ "a": 2 }'), 25 | (4, '{ "a": 1, "b": 2 }'), 26 | (5, '{ "a": 1, "b": 2, "c": 3 }'); 27 | 28 | -- Check if each document contains the key 'b' 29 | SELECT id, json_contains(json_col, 'b') as json_contains FROM test_table; 30 | -- Results in 31 | -- +----+---------------+ 32 | -- | id | json_contains | 33 | -- +----+---------------+ 34 | -- | 1 | false | 35 | -- | 2 | false | 36 | -- | 3 | false | 37 | -- | 4 | true | 38 | -- | 5 | true | 39 | -- +----+---------------+ 40 | 41 | -- Get the value of the key 'a' from each document 42 | SELECT id, json_col->'a' as json_col_a FROM test_table 43 | 44 | -- +----+------------+ 45 | -- | id | json_col_a | 46 | -- +----+------------+ 47 | -- | 1 | {null=} | 48 | -- | 2 | {int=1} | 49 | -- | 3 | {int=2} | 50 | -- | 4 | {int=1} | 51 | -- | 5 | {int=1} | 52 | -- +----+------------+ 53 | ``` 54 | 55 | 56 | ## Done 57 | 58 | * [x] `json_contains(json: str, *keys: str | int) -> bool` - true if a JSON string has a specific key (used for the `?` operator) 59 | * [x] `json_get(json: str, *keys: str | int) -> JsonUnion` - Get a value from a JSON string by its "path" 60 | * [x] `json_get_str(json: str, *keys: str | int) -> str` - Get a string value from a JSON string by its "path" 61 | * [x] `json_get_int(json: str, *keys: str | int) -> int` - Get an integer value from a JSON string by its "path" 62 | * [x] `json_get_float(json: str, *keys: str | int) -> float` - Get a float value from a JSON string by its "path" 63 | * [x] `json_get_bool(json: str, *keys: str | int) -> bool` - Get a boolean value from a JSON string by its "path" 64 | * [x] `json_get_json(json: str, *keys: str | int) -> str` - Get a nested raw JSON string from a JSON string by its "path" 65 | * [x] `json_as_text(json: str, *keys: str | int) -> str` - Get any value from a JSON string by its "path", represented as a string (used for the `->>` operator) 66 | * [x] `json_length(json: str, *keys: str | int) -> int` - get the length of a JSON string or array 67 | 68 | - [x] `->` operator - alias for `json_get` 69 | - [x] `->>` operator - alias for `json_as_text` 70 | - [x] `?` operator - alias for `json_contains` 71 | 72 | ### Notes 73 | Cast expressions with `json_get` are rewritten to the appropriate method, e.g. 74 | 75 | ```sql 76 | select * from foo where json_get(attributes, 'bar')::string='ham' 77 | ``` 78 | Will be rewritten to: 79 | ```sql 80 | select * from foo where json_get_str(attributes, 'bar')='ham' 81 | ``` 82 | 83 | ## TODO (maybe, if they're actually useful) 84 | 85 | * [ ] `json_keys(json: str, *keys: str | int) -> list[str]` - get the keys of a JSON string 86 | * [ ] `json_is_obj(json: str, *keys: str | int) -> bool` - true if the JSON is an object 87 | * [ ] `json_is_array(json: str, *keys: str | int) -> bool` - true if the JSON is an array 88 | * [ ] `json_valid(json: str) -> bool` - true if the JSON is valid 89 | -------------------------------------------------------------------------------- /benches/main.rs: -------------------------------------------------------------------------------- 1 | use codspeed_criterion_compat::{criterion_group, criterion_main, Bencher, Criterion}; 2 | 3 | use datafusion::arrow::datatypes::DataType; 4 | use datafusion::logical_expr::ColumnarValue; 5 | use datafusion::{common::ScalarValue, logical_expr::ScalarFunctionArgs}; 6 | use datafusion_functions_json::udfs::{json_contains_udf, json_get_str_udf}; 7 | 8 | fn bench_json_contains(b: &mut Bencher) { 9 | let json_contains = json_contains_udf(); 10 | let args = vec![ 11 | ColumnarValue::Scalar(ScalarValue::Utf8(Some( 12 | r#"{"a": {"aa": "x", "ab: "y"}, "b": []}"#.to_string(), 13 | ))), 14 | ColumnarValue::Scalar(ScalarValue::Utf8(Some("a".to_string()))), 15 | ColumnarValue::Scalar(ScalarValue::Utf8(Some("aa".to_string()))), 16 | ]; 17 | 18 | b.iter(|| { 19 | json_contains 20 | .invoke_with_args(ScalarFunctionArgs { 21 | args: args.clone(), 22 | number_rows: 1, 23 | return_type: &DataType::Boolean, 24 | }) 25 | .unwrap() 26 | }); 27 | } 28 | 29 | fn bench_json_get_str(b: &mut Bencher) { 30 | let json_get_str = json_get_str_udf(); 31 | let args = &[ 32 | ColumnarValue::Scalar(ScalarValue::Utf8(Some( 33 | r#"{"a": {"aa": "x", "ab: "y"}, "b": []}"#.to_string(), 34 | ))), 35 | ColumnarValue::Scalar(ScalarValue::Utf8(Some("a".to_string()))), 36 | ColumnarValue::Scalar(ScalarValue::Utf8(Some("aa".to_string()))), 37 | ]; 38 | 39 | b.iter(|| { 40 | json_get_str 41 | .invoke_with_args(ScalarFunctionArgs { 42 | args: args.to_vec(), 43 | number_rows: 1, 44 | return_type: &DataType::Utf8, 45 | }) 46 | .unwrap() 47 | }); 48 | } 49 | 50 | fn criterion_benchmark(c: &mut Criterion) { 51 | c.bench_function("json_contains", bench_json_contains); 52 | c.bench_function("json_get_str", bench_json_get_str); 53 | } 54 | 55 | criterion_group!(benches, criterion_benchmark); 56 | criterion_main!(benches); 57 | -------------------------------------------------------------------------------- /src/common.rs: -------------------------------------------------------------------------------- 1 | use std::str::Utf8Error; 2 | use std::sync::Arc; 3 | 4 | use datafusion::arrow::array::{ 5 | downcast_array, AnyDictionaryArray, Array, ArrayAccessor, ArrayRef, AsArray, DictionaryArray, LargeStringArray, 6 | PrimitiveArray, PrimitiveBuilder, RunArray, StringArray, StringViewArray, 7 | }; 8 | use datafusion::arrow::compute::kernels::cast; 9 | use datafusion::arrow::compute::take; 10 | use datafusion::arrow::datatypes::{ArrowNativeType, DataType, Int64Type, UInt64Type}; 11 | use datafusion::common::{exec_err, plan_err, Result as DataFusionResult, ScalarValue}; 12 | use datafusion::logical_expr::ColumnarValue; 13 | use jiter::{Jiter, JiterError, Peek}; 14 | 15 | use crate::common_union::{ 16 | is_json_union, json_from_union_scalar, nested_json_array, nested_json_array_ref, TYPE_ID_NULL, 17 | }; 18 | 19 | /// General implementation of `ScalarUDFImpl::return_type`. 20 | /// 21 | /// # Arguments 22 | /// 23 | /// * `args` - The arguments to the function 24 | /// * `fn_name` - The name of the function 25 | /// * `value_type` - The general return type of the function, might be wrapped in a dictionary depending 26 | /// on the first argument 27 | pub fn return_type_check(args: &[DataType], fn_name: &str, value_type: DataType) -> DataFusionResult { 28 | let Some(first) = args.first() else { 29 | return plan_err!("The '{fn_name}' function requires one or more arguments."); 30 | }; 31 | let first_dict_key_type = dict_key_type(first); 32 | if !(is_str(first) || is_json_union(first) || first_dict_key_type.is_some()) { 33 | // if !matches!(first, DataType::Utf8 | DataType::LargeUtf8) { 34 | return plan_err!("Unexpected argument type to '{fn_name}' at position 1, expected a string, got {first:?}."); 35 | } 36 | args.iter().skip(1).enumerate().try_for_each(|(index, arg)| { 37 | if is_str(arg) || is_int(arg) || dict_key_type(arg).is_some() { 38 | Ok(()) 39 | } else { 40 | plan_err!( 41 | "Unexpected argument type to '{fn_name}' at position {}, expected string or int, got {arg:?}.", 42 | index + 2 43 | ) 44 | } 45 | })?; 46 | if first_dict_key_type.is_some() && !value_type.is_primitive() { 47 | Ok(DataType::Dictionary(Box::new(DataType::Int64), Box::new(value_type))) 48 | } else { 49 | Ok(value_type) 50 | } 51 | } 52 | 53 | fn is_str(d: &DataType) -> bool { 54 | matches!(d, DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View) 55 | } 56 | 57 | fn is_int(d: &DataType) -> bool { 58 | // TODO we should support more types of int, but that's a longer task 59 | matches!(d, DataType::UInt64 | DataType::Int64) 60 | } 61 | 62 | fn dict_key_type(d: &DataType) -> Option { 63 | if let DataType::Dictionary(key, value) = d { 64 | if is_str(value) || is_json_union(value) { 65 | return Some(*key.clone()); 66 | } 67 | } 68 | None 69 | } 70 | 71 | #[derive(Debug)] 72 | pub enum JsonPath<'s> { 73 | Key(&'s str), 74 | Index(usize), 75 | None, 76 | } 77 | 78 | impl<'a> From<&'a str> for JsonPath<'a> { 79 | fn from(key: &'a str) -> Self { 80 | JsonPath::Key(key) 81 | } 82 | } 83 | 84 | impl From for JsonPath<'_> { 85 | fn from(index: u64) -> Self { 86 | JsonPath::Index(usize::try_from(index).unwrap()) 87 | } 88 | } 89 | 90 | impl From for JsonPath<'_> { 91 | fn from(index: i64) -> Self { 92 | match usize::try_from(index) { 93 | Ok(i) => Self::Index(i), 94 | Err(_) => Self::None, 95 | } 96 | } 97 | } 98 | 99 | #[derive(Debug)] 100 | enum JsonPathArgs<'a> { 101 | Array(&'a ArrayRef), 102 | Scalars(Vec>), 103 | } 104 | 105 | impl<'s> JsonPathArgs<'s> { 106 | fn extract_path(path_args: &'s [ColumnarValue]) -> DataFusionResult { 107 | // If there is a single argument as an array, we know how to handle it 108 | if let Some((ColumnarValue::Array(array), &[])) = path_args.split_first() { 109 | return Ok(Self::Array(array)); 110 | } 111 | 112 | path_args 113 | .iter() 114 | .enumerate() 115 | .map(|(pos, arg)| match arg { 116 | ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)) | ScalarValue::LargeUtf8(Some(s))) => { 117 | Ok(JsonPath::Key(s)) 118 | } 119 | ColumnarValue::Scalar(ScalarValue::UInt64(Some(i))) => Ok((*i).into()), 120 | ColumnarValue::Scalar(ScalarValue::Int64(Some(i))) => Ok((*i).into()), 121 | ColumnarValue::Scalar( 122 | ScalarValue::Null 123 | | ScalarValue::Utf8(None) 124 | | ScalarValue::LargeUtf8(None) 125 | | ScalarValue::UInt64(None) 126 | | ScalarValue::Int64(None), 127 | ) => Ok(JsonPath::None), 128 | ColumnarValue::Array(_) => { 129 | // if there was a single arg, which is an array, handled above in the 130 | // split_first case. So this is multiple args of which one is an array 131 | exec_err!("More than 1 path element is not supported when querying JSON using an array.") 132 | } 133 | ColumnarValue::Scalar(arg) => exec_err!( 134 | "Unexpected argument type at position {}, expected string or int, got {arg:?}.", 135 | pos + 1 136 | ), 137 | }) 138 | .collect::>() 139 | .map(JsonPathArgs::Scalars) 140 | } 141 | } 142 | 143 | pub trait InvokeResult { 144 | type Item; 145 | type Builder; 146 | 147 | // Whether the return type should is allowed to be a dictionary 148 | const ACCEPT_DICT_RETURN: bool; 149 | 150 | fn builder(capacity: usize) -> Self::Builder; 151 | fn append_value(builder: &mut Self::Builder, value: Option); 152 | fn finish(builder: Self::Builder) -> DataFusionResult; 153 | 154 | /// Convert a single value to a scalar 155 | fn scalar(value: Option) -> ScalarValue; 156 | } 157 | 158 | pub fn invoke( 159 | args: &[ColumnarValue], 160 | jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result, 161 | ) -> DataFusionResult { 162 | let Some((json_arg, path_args)) = args.split_first() else { 163 | return exec_err!("expected at least one argument"); 164 | }; 165 | 166 | let path = JsonPathArgs::extract_path(path_args)?; 167 | match (json_arg, path) { 168 | (ColumnarValue::Array(json_array), JsonPathArgs::Array(path_array)) => { 169 | invoke_array_array::(json_array, path_array, jiter_find).map(ColumnarValue::Array) 170 | } 171 | (ColumnarValue::Array(json_array), JsonPathArgs::Scalars(path)) => { 172 | invoke_array_scalars::(json_array, &path, jiter_find).map(ColumnarValue::Array) 173 | } 174 | (ColumnarValue::Scalar(s), JsonPathArgs::Array(path_array)) => { 175 | invoke_scalar_array::(s, path_array, jiter_find) 176 | } 177 | (ColumnarValue::Scalar(s), JsonPathArgs::Scalars(path)) => { 178 | invoke_scalar_scalars(s, &path, jiter_find, R::scalar) 179 | } 180 | } 181 | } 182 | 183 | fn invoke_array_array( 184 | json_array: &ArrayRef, 185 | path_array: &ArrayRef, 186 | jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result, 187 | ) -> DataFusionResult { 188 | match json_array.data_type() { 189 | // for string dictionaries, cast dictionary keys to larger types to avoid generic explosion 190 | DataType::Dictionary(_, value_type) if value_type.as_ref() == &DataType::Utf8 => { 191 | let json_array = cast_to_large_dictionary(json_array.as_any_dictionary())?; 192 | let output = zip_apply::( 193 | json_array.downcast_dict::().unwrap(), 194 | path_array, 195 | jiter_find, 196 | )?; 197 | if R::ACCEPT_DICT_RETURN { 198 | // ensure return is a dictionary to satisfy the declaration above in return_type_check 199 | Ok(Arc::new(wrap_as_large_dictionary(&json_array, output))) 200 | } else { 201 | Ok(output) 202 | } 203 | } 204 | DataType::Dictionary(_, value_type) if value_type.as_ref() == &DataType::LargeUtf8 => { 205 | let json_array = cast_to_large_dictionary(json_array.as_any_dictionary())?; 206 | let output = zip_apply::( 207 | json_array.downcast_dict::().unwrap(), 208 | path_array, 209 | jiter_find, 210 | )?; 211 | if R::ACCEPT_DICT_RETURN { 212 | // ensure return is a dictionary to satisfy the declaration above in return_type_check 213 | Ok(Arc::new(wrap_as_large_dictionary(&json_array, output))) 214 | } else { 215 | Ok(output) 216 | } 217 | } 218 | other_dict_type @ DataType::Dictionary(_, _) => { 219 | // Horrible case: dict containing union as input with array for paths, figure 220 | // out from the path type which union members we should access, repack the 221 | // dictionary and then recurse. 222 | if let Some(child_array) = nested_json_array_ref( 223 | json_array.as_any_dictionary().values(), 224 | is_object_lookup_array(path_array.data_type()), 225 | ) { 226 | invoke_array_array::( 227 | &(Arc::new(json_array.as_any_dictionary().with_values(child_array.clone())) as _), 228 | path_array, 229 | jiter_find, 230 | ) 231 | } else { 232 | exec_err!("unexpected json array type {:?}", other_dict_type) 233 | } 234 | } 235 | DataType::Utf8 => zip_apply::(json_array.as_string::(), path_array, jiter_find), 236 | DataType::LargeUtf8 => zip_apply::(json_array.as_string::(), path_array, jiter_find), 237 | DataType::Utf8View => zip_apply::(json_array.as_string_view(), path_array, jiter_find), 238 | other => { 239 | if let Some(string_array) = nested_json_array(json_array, is_object_lookup_array(path_array.data_type())) { 240 | zip_apply::(string_array, path_array, jiter_find) 241 | } else { 242 | exec_err!("unexpected json array type {:?}", other) 243 | } 244 | } 245 | } 246 | } 247 | 248 | /// Transform keys that may be pointing to values with nulls to nulls themselves. 249 | /// keys = `[0, 1, 2, 3]`, values = `[null, "a", null, "b"]` 250 | /// into 251 | /// keys = `[null, 0, null, 1]`, values = `["a", "b"]` 252 | /// 253 | /// Arrow / `DataFusion` assumes that dictionary values do not contain nulls, nulls are encoded by the keys. 254 | /// Not following this invariant causes invalid dictionary arrays to be built later on inside of `DataFusion` 255 | /// when arrays are concacted and such. 256 | fn remap_dictionary_key_nulls(keys: PrimitiveArray, values: ArrayRef) -> DictionaryArray { 257 | // fast path: no nulls in values 258 | if values.null_count() == 0 { 259 | return DictionaryArray::new(keys, values); 260 | } 261 | 262 | let mut new_keys_builder = PrimitiveBuilder::::new(); 263 | 264 | for key in &keys { 265 | match key { 266 | Some(k) if values.is_null(k.as_usize()) => new_keys_builder.append_null(), 267 | Some(k) => new_keys_builder.append_value(k), 268 | None => new_keys_builder.append_null(), 269 | } 270 | } 271 | 272 | let new_keys = new_keys_builder.finish(); 273 | DictionaryArray::new(new_keys, values) 274 | } 275 | 276 | fn invoke_array_scalars( 277 | json_array: &ArrayRef, 278 | path: &[JsonPath], 279 | jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result, 280 | ) -> DataFusionResult { 281 | #[allow(clippy::needless_pass_by_value)] // ArrayAccessor is implemented on references 282 | fn inner<'j, R: InvokeResult>( 283 | json_array: impl ArrayAccessor, 284 | path: &[JsonPath], 285 | jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result, 286 | ) -> DataFusionResult { 287 | let mut builder = R::builder(json_array.len()); 288 | for i in 0..json_array.len() { 289 | let opt_json = if json_array.is_null(i) { 290 | None 291 | } else { 292 | Some(json_array.value(i)) 293 | }; 294 | let opt_value = jiter_find(opt_json, path).ok(); 295 | R::append_value(&mut builder, opt_value); 296 | } 297 | R::finish(builder) 298 | } 299 | 300 | match json_array.data_type() { 301 | DataType::Dictionary(_, _) => { 302 | let json_array = json_array.as_any_dictionary(); 303 | let values = invoke_array_scalars::(json_array.values(), path, jiter_find)?; 304 | return if R::ACCEPT_DICT_RETURN { 305 | // make the keys into i64 to avoid generic bloat here 306 | let mut keys: PrimitiveArray = downcast_array(&cast(json_array.keys(), &DataType::Int64)?); 307 | if is_json_union(values.data_type()) { 308 | // JSON union: post-process the array to set keys to null where the union member is null 309 | let type_ids = values.as_union().type_ids(); 310 | keys = mask_dictionary_keys(&keys, type_ids); 311 | } 312 | Ok(Arc::new(remap_dictionary_key_nulls(keys, values))) 313 | } else { 314 | // this is what cast would do under the hood to unpack a dictionary into an array of its values 315 | Ok(take(&values, json_array.keys(), None)?) 316 | }; 317 | } 318 | DataType::Utf8 => inner::(json_array.as_string::(), path, jiter_find), 319 | DataType::LargeUtf8 => inner::(json_array.as_string::(), path, jiter_find), 320 | DataType::Utf8View => inner::(json_array.as_string_view(), path, jiter_find), 321 | other => { 322 | if let Some(string_array) = nested_json_array(json_array, is_object_lookup(path)) { 323 | inner::(string_array, path, jiter_find) 324 | } else { 325 | exec_err!("unexpected json array type {:?}", other) 326 | } 327 | } 328 | } 329 | } 330 | 331 | fn invoke_scalar_array( 332 | scalar: &ScalarValue, 333 | path_array: &ArrayRef, 334 | jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result, 335 | ) -> DataFusionResult { 336 | let s = extract_json_scalar(scalar)?; 337 | let arr = s.map_or_else(|| StringArray::new_null(1), |s| StringArray::new_scalar(s).into_inner()); 338 | 339 | // TODO: possible optimization here if path_array is a dictionary; can apply against the 340 | // dictionary values directly for less work 341 | zip_apply::( 342 | RunArray::try_new( 343 | &PrimitiveArray::::new_scalar(i64::try_from(path_array.len()).expect("len out of i64 range")) 344 | .into_inner(), 345 | &arr, 346 | )? 347 | .downcast::() 348 | .expect("type known"), 349 | path_array, 350 | jiter_find, 351 | ) 352 | // FIXME edge cases where scalar is wrapped in a dictionary, should return a dictionary? 353 | .map(ColumnarValue::Array) 354 | } 355 | 356 | fn invoke_scalar_scalars( 357 | scalar: &ScalarValue, 358 | path: &[JsonPath], 359 | jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result, 360 | to_scalar: impl Fn(Option) -> ScalarValue, 361 | ) -> DataFusionResult { 362 | let s = extract_json_scalar(scalar)?; 363 | let v = jiter_find(s, path).ok(); 364 | // FIXME edge cases where scalar is wrapped in a dictionary, should return a dictionary? 365 | Ok(ColumnarValue::Scalar(to_scalar(v))) 366 | } 367 | 368 | fn zip_apply<'a, R: InvokeResult>( 369 | json_array: impl ArrayAccessor, 370 | path_array: &ArrayRef, 371 | jiter_find: impl Fn(Option<&'a str>, &[JsonPath]) -> Result, 372 | ) -> DataFusionResult { 373 | fn get_array_values<'j, 'p, P: Into>>( 374 | j: &impl ArrayAccessor, 375 | p: &impl ArrayAccessor, 376 | index: usize, 377 | ) -> Option<(Option<&'j str>, JsonPath<'p>)> { 378 | let path = if p.is_null(index) { 379 | return None; 380 | } else { 381 | p.value(index).into() 382 | }; 383 | 384 | let json = if j.is_null(index) { None } else { Some(j.value(index)) }; 385 | 386 | Some((json, path)) 387 | } 388 | 389 | #[allow(clippy::needless_pass_by_value)] // ArrayAccessor is implemented on references 390 | fn inner<'a, 'p, P: Into>, R: InvokeResult>( 391 | json_array: impl ArrayAccessor, 392 | path_array: impl ArrayAccessor, 393 | jiter_find: impl Fn(Option<&'a str>, &[JsonPath]) -> Result, 394 | ) -> DataFusionResult { 395 | let mut builder = R::builder(json_array.len()); 396 | for i in 0..json_array.len() { 397 | if let Some((opt_json, path)) = get_array_values(&json_array, &path_array, i) { 398 | let value = jiter_find(opt_json, &[path]).ok(); 399 | R::append_value(&mut builder, value); 400 | } else { 401 | R::append_value(&mut builder, None); 402 | } 403 | } 404 | R::finish(builder) 405 | } 406 | 407 | match path_array.data_type() { 408 | // for string dictionaries, cast dictionary keys to larger types to avoid generic explosion 409 | DataType::Dictionary(_, value_type) if value_type.as_ref() == &DataType::Utf8 => { 410 | let path_array = cast_to_large_dictionary(path_array.as_any_dictionary())?; 411 | inner::<_, R>( 412 | json_array, 413 | path_array.downcast_dict::().unwrap(), 414 | jiter_find, 415 | ) 416 | } 417 | DataType::Dictionary(_, value_type) if value_type.as_ref() == &DataType::LargeUtf8 => { 418 | let path_array = cast_to_large_dictionary(path_array.as_any_dictionary())?; 419 | inner::<_, R>( 420 | json_array, 421 | path_array.downcast_dict::().unwrap(), 422 | jiter_find, 423 | ) 424 | } 425 | DataType::Dictionary(_, value_type) if value_type.as_ref() == &DataType::Utf8View => { 426 | let path_array = cast_to_large_dictionary(path_array.as_any_dictionary())?; 427 | inner::<_, R>( 428 | json_array, 429 | path_array.downcast_dict::().unwrap(), 430 | jiter_find, 431 | ) 432 | } 433 | // for integer dictionaries, cast them directly to the inner type because it basically costs 434 | // the same as building a new key array anyway 435 | DataType::Dictionary(_, value_type) if value_type.as_ref() == &DataType::Int64 => inner::<_, R>( 436 | json_array, 437 | cast(path_array, &DataType::Int64)?.as_primitive::(), 438 | jiter_find, 439 | ), 440 | DataType::Dictionary(_, value_type) if value_type.as_ref() == &DataType::UInt64 => inner::<_, R>( 441 | json_array, 442 | cast(path_array, &DataType::UInt64)?.as_primitive::(), 443 | jiter_find, 444 | ), 445 | // for basic types, just consume directly 446 | DataType::Utf8 => inner::<_, R>(json_array, path_array.as_string::(), jiter_find), 447 | DataType::LargeUtf8 => inner::<_, R>(json_array, path_array.as_string::(), jiter_find), 448 | DataType::Utf8View => inner::<_, R>(json_array, path_array.as_string_view(), jiter_find), 449 | DataType::Int64 => inner::<_, R>(json_array, path_array.as_primitive::(), jiter_find), 450 | DataType::UInt64 => inner::<_, R>(json_array, path_array.as_primitive::(), jiter_find), 451 | other => { 452 | exec_err!( 453 | "unexpected second argument type, expected string or int array, got {:?}", 454 | other 455 | ) 456 | } 457 | } 458 | } 459 | 460 | fn extract_json_scalar(scalar: &ScalarValue) -> DataFusionResult> { 461 | match scalar { 462 | ScalarValue::Dictionary(_, b) => extract_json_scalar(b.as_ref()), 463 | ScalarValue::Utf8(s) | ScalarValue::Utf8View(s) | ScalarValue::LargeUtf8(s) => Ok(s.as_deref()), 464 | ScalarValue::Union(type_id_value, union_fields, _) => { 465 | Ok(json_from_union_scalar(type_id_value.as_ref(), union_fields)) 466 | } 467 | _ => { 468 | exec_err!("unexpected first argument type, expected string or JSON union") 469 | } 470 | } 471 | } 472 | 473 | fn is_object_lookup(path: &[JsonPath]) -> bool { 474 | if let Some(first) = path.first() { 475 | matches!(first, JsonPath::Key(_)) 476 | } else { 477 | false 478 | } 479 | } 480 | 481 | fn is_object_lookup_array(data_type: &DataType) -> bool { 482 | match data_type { 483 | DataType::Dictionary(_, value_type) => is_object_lookup_array(value_type), 484 | DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => true, 485 | _ => false, 486 | } 487 | } 488 | 489 | /// Cast an array to a dictionary with i64 indices. 490 | /// 491 | /// According to the 492 | /// recommendation is to avoid unsigned indices due to technologies like the JVM making it harder to 493 | /// support unsigned integers. 494 | /// 495 | /// So we'll just use i64 as the largest signed integer type. 496 | fn cast_to_large_dictionary(dict_array: &dyn AnyDictionaryArray) -> DataFusionResult> { 497 | let keys = downcast_array(&cast(dict_array.keys(), &DataType::Int64)?); 498 | Ok(DictionaryArray::::new(keys, dict_array.values().clone())) 499 | } 500 | 501 | /// Wrap an array as a dictionary with i64 indices. 502 | fn wrap_as_large_dictionary(original: &dyn AnyDictionaryArray, new_values: ArrayRef) -> DictionaryArray { 503 | assert_eq!(original.keys().len(), new_values.len()); 504 | let mut keys = 505 | PrimitiveArray::from_iter_values(0i64..original.keys().len().try_into().expect("keys out of i64 range")); 506 | if is_json_union(new_values.data_type()) { 507 | // JSON union: post-process the array to set keys to null where the union member is null 508 | let type_ids = new_values.as_union().type_ids(); 509 | keys = mask_dictionary_keys(&keys, type_ids); 510 | } 511 | DictionaryArray::new(keys, new_values) 512 | } 513 | 514 | pub fn jiter_json_find<'j>(opt_json: Option<&'j str>, path: &[JsonPath]) -> Option<(Jiter<'j>, Peek)> { 515 | let json_str = opt_json?; 516 | let mut jiter = Jiter::new(json_str.as_bytes()); 517 | let mut peek = jiter.peek().ok()?; 518 | for element in path { 519 | match element { 520 | JsonPath::Key(key) if peek == Peek::Object => { 521 | let mut next_key = jiter.known_object().ok()??; 522 | 523 | while next_key != *key { 524 | jiter.next_skip().ok()?; 525 | next_key = jiter.next_key().ok()??; 526 | } 527 | 528 | peek = jiter.peek().ok()?; 529 | } 530 | JsonPath::Index(index) if peek == Peek::Array => { 531 | let mut array_item = jiter.known_array().ok()??; 532 | 533 | for _ in 0..*index { 534 | jiter.known_skip(array_item).ok()?; 535 | array_item = jiter.array_step().ok()??; 536 | } 537 | 538 | peek = array_item; 539 | } 540 | _ => { 541 | return None; 542 | } 543 | } 544 | } 545 | Some((jiter, peek)) 546 | } 547 | 548 | macro_rules! get_err { 549 | () => { 550 | Err(GetError) 551 | }; 552 | } 553 | pub(crate) use get_err; 554 | 555 | pub struct GetError; 556 | 557 | impl From for GetError { 558 | fn from(_: JiterError) -> Self { 559 | GetError 560 | } 561 | } 562 | 563 | impl From for GetError { 564 | fn from(_: Utf8Error) -> Self { 565 | GetError 566 | } 567 | } 568 | 569 | /// Set keys to null where the union member is null. 570 | /// 571 | /// This is a workaround to 572 | /// - i.e. that dictionary null is most reliably done if the keys are null. 573 | /// 574 | /// That said, doing this might also be an optimization for cases like null-checking without needing 575 | /// to check the value union array. 576 | fn mask_dictionary_keys(keys: &PrimitiveArray, type_ids: &[i8]) -> PrimitiveArray { 577 | let mut null_mask = vec![true; keys.len()]; 578 | for (i, k) in keys.iter().enumerate() { 579 | match k { 580 | // if the key is non-null and value is non-null, don't mask it out 581 | Some(k) if type_ids[k.as_usize()] != TYPE_ID_NULL => {} 582 | // i.e. key is null or value is null here 583 | _ => null_mask[i] = false, 584 | } 585 | } 586 | PrimitiveArray::new(keys.values().clone(), Some(null_mask.into())) 587 | } 588 | -------------------------------------------------------------------------------- /src/common_macros.rs: -------------------------------------------------------------------------------- 1 | /// Creates external API `ScalarUDF` for an array UDF. Specifically, creates 2 | /// 3 | /// Creates a singleton `ScalarUDF` of the `$udf_impl` function named `$expr_fn_name _udf` and a 4 | /// function named `$expr_fn_name _udf` which returns that function. 5 | /// 6 | /// This is used to ensure creating the list of `ScalarUDF` only happens once. 7 | /// 8 | /// # Arguments 9 | /// * `udf_impl`: name of the [`ScalarUDFImpl`] 10 | /// * `expr_fn_name`: name of the `expr_fn` function to be created 11 | /// * `arg`: 0 or more named arguments for the function 12 | /// * `doc`: documentation string for the function 13 | /// 14 | /// Copied mostly from, `/datafusion/functions-array/src/macros.rs`. 15 | /// 16 | /// [`ScalarUDFImpl`]: datafusion_expr::ScalarUDFImpl 17 | macro_rules! make_udf_function { 18 | ($udf_impl:ty, $expr_fn_name:ident, $($arg:ident)*, $doc:expr) => { 19 | paste::paste! { 20 | #[doc = $doc] 21 | #[must_use] pub fn $expr_fn_name($($arg: datafusion::logical_expr::Expr),*) -> datafusion::logical_expr::Expr { 22 | datafusion::logical_expr::Expr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf( 23 | [< $expr_fn_name _udf >](), 24 | vec![$($arg),*], 25 | )) 26 | } 27 | 28 | /// Singleton instance of [`$udf_impl`], ensures the UDF is only created once 29 | /// named for example `STATIC_JSON_OBJ_CONTAINS` 30 | static [< STATIC_ $expr_fn_name:upper >]: std::sync::OnceLock> = 31 | std::sync::OnceLock::new(); 32 | 33 | /// ScalarFunction that returns a [`ScalarUDF`] for [`$udf_impl`] 34 | /// 35 | /// [`ScalarUDF`]: datafusion::logical_expr::ScalarUDF 36 | pub fn [< $expr_fn_name _udf >]() -> std::sync::Arc { 37 | [< STATIC_ $expr_fn_name:upper >] 38 | .get_or_init(|| { 39 | std::sync::Arc::new(datafusion::logical_expr::ScalarUDF::new_from_impl( 40 | <$udf_impl>::default(), 41 | )) 42 | }) 43 | .clone() 44 | } 45 | } 46 | }; 47 | } 48 | 49 | pub(crate) use make_udf_function; 50 | -------------------------------------------------------------------------------- /src/common_union.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | use std::sync::{Arc, OnceLock}; 3 | 4 | use datafusion::arrow::array::{ 5 | Array, ArrayRef, AsArray, BooleanArray, Float64Array, Int64Array, NullArray, StringArray, UnionArray, 6 | }; 7 | use datafusion::arrow::buffer::{Buffer, ScalarBuffer}; 8 | use datafusion::arrow::datatypes::{DataType, Field, UnionFields, UnionMode}; 9 | use datafusion::arrow::error::ArrowError; 10 | use datafusion::common::ScalarValue; 11 | 12 | pub fn is_json_union(data_type: &DataType) -> bool { 13 | match data_type { 14 | DataType::Union(fields, UnionMode::Sparse) => fields == &union_fields(), 15 | _ => false, 16 | } 17 | } 18 | 19 | /// Extract nested JSON from a `JsonUnion` `UnionArray` 20 | /// 21 | /// # Arguments 22 | /// * `array` - The `UnionArray` to extract the nested JSON from 23 | /// * `object_lookup` - If `true`, extract from the "object" member of the union, 24 | /// otherwise extract from the "array" member 25 | pub(crate) fn nested_json_array(array: &ArrayRef, object_lookup: bool) -> Option<&StringArray> { 26 | nested_json_array_ref(array, object_lookup).map(AsArray::as_string) 27 | } 28 | 29 | pub(crate) fn nested_json_array_ref(array: &ArrayRef, object_lookup: bool) -> Option<&ArrayRef> { 30 | let union_array: &UnionArray = array.as_any().downcast_ref::()?; 31 | let type_id = if object_lookup { TYPE_ID_OBJECT } else { TYPE_ID_ARRAY }; 32 | Some(union_array.child(type_id)) 33 | } 34 | 35 | /// Extract a JSON string from a `JsonUnion` scalar 36 | pub(crate) fn json_from_union_scalar<'a>( 37 | type_id_value: Option<&'a (i8, Box)>, 38 | fields: &UnionFields, 39 | ) -> Option<&'a str> { 40 | if let Some((type_id, value)) = type_id_value { 41 | // we only want to take the ScalarValue string if the type_id indicates the value represents nested JSON 42 | if fields == &union_fields() && (*type_id == TYPE_ID_ARRAY || *type_id == TYPE_ID_OBJECT) { 43 | if let ScalarValue::Utf8(s) = value.as_ref() { 44 | return s.as_deref(); 45 | } 46 | } 47 | } 48 | None 49 | } 50 | 51 | #[derive(Debug)] 52 | pub(crate) struct JsonUnion { 53 | bools: Vec>, 54 | ints: Vec>, 55 | floats: Vec>, 56 | strings: Vec>, 57 | arrays: Vec>, 58 | objects: Vec>, 59 | type_ids: Vec, 60 | index: usize, 61 | length: usize, 62 | } 63 | 64 | impl JsonUnion { 65 | pub fn new(length: usize) -> Self { 66 | Self { 67 | bools: vec![None; length], 68 | ints: vec![None; length], 69 | floats: vec![None; length], 70 | strings: vec![None; length], 71 | arrays: vec![None; length], 72 | objects: vec![None; length], 73 | type_ids: vec![TYPE_ID_NULL; length], 74 | index: 0, 75 | length, 76 | } 77 | } 78 | 79 | pub fn data_type() -> DataType { 80 | DataType::Union(union_fields(), UnionMode::Sparse) 81 | } 82 | 83 | pub fn push(&mut self, field: JsonUnionField) { 84 | self.type_ids[self.index] = field.type_id(); 85 | match field { 86 | JsonUnionField::JsonNull => (), 87 | JsonUnionField::Bool(value) => self.bools[self.index] = Some(value), 88 | JsonUnionField::Int(value) => self.ints[self.index] = Some(value), 89 | JsonUnionField::Float(value) => self.floats[self.index] = Some(value), 90 | JsonUnionField::Str(value) => self.strings[self.index] = Some(value), 91 | JsonUnionField::Array(value) => self.arrays[self.index] = Some(value), 92 | JsonUnionField::Object(value) => self.objects[self.index] = Some(value), 93 | } 94 | self.index += 1; 95 | debug_assert!(self.index <= self.length); 96 | } 97 | 98 | pub fn push_none(&mut self) { 99 | self.index += 1; 100 | debug_assert!(self.index <= self.length); 101 | } 102 | } 103 | 104 | /// So we can do `collect::()` 105 | impl FromIterator> for JsonUnion { 106 | fn from_iter>>(iter: I) -> Self { 107 | let inner = iter.into_iter(); 108 | let (lower, upper) = inner.size_hint(); 109 | let mut union = Self::new(upper.unwrap_or(lower)); 110 | 111 | for opt_field in inner { 112 | if let Some(union_field) = opt_field { 113 | union.push(union_field); 114 | } else { 115 | union.push_none(); 116 | } 117 | } 118 | union 119 | } 120 | } 121 | 122 | impl TryFrom for UnionArray { 123 | type Error = ArrowError; 124 | 125 | fn try_from(value: JsonUnion) -> Result { 126 | let children: Vec> = vec![ 127 | Arc::new(NullArray::new(value.length)), 128 | Arc::new(BooleanArray::from(value.bools)), 129 | Arc::new(Int64Array::from(value.ints)), 130 | Arc::new(Float64Array::from(value.floats)), 131 | Arc::new(StringArray::from(value.strings)), 132 | Arc::new(StringArray::from(value.arrays)), 133 | Arc::new(StringArray::from(value.objects)), 134 | ]; 135 | UnionArray::try_new(union_fields(), Buffer::from_vec(value.type_ids).into(), None, children) 136 | } 137 | } 138 | 139 | #[derive(Debug)] 140 | pub(crate) enum JsonUnionField { 141 | JsonNull, 142 | Bool(bool), 143 | Int(i64), 144 | Float(f64), 145 | Str(String), 146 | Array(String), 147 | Object(String), 148 | } 149 | 150 | pub(crate) const TYPE_ID_NULL: i8 = 0; 151 | const TYPE_ID_BOOL: i8 = 1; 152 | const TYPE_ID_INT: i8 = 2; 153 | const TYPE_ID_FLOAT: i8 = 3; 154 | const TYPE_ID_STR: i8 = 4; 155 | const TYPE_ID_ARRAY: i8 = 5; 156 | const TYPE_ID_OBJECT: i8 = 6; 157 | 158 | fn union_fields() -> UnionFields { 159 | static FIELDS: OnceLock = OnceLock::new(); 160 | FIELDS 161 | .get_or_init(|| { 162 | let json_metadata: HashMap = 163 | HashMap::from_iter(vec![("is_json".to_string(), "true".to_string())]); 164 | UnionFields::from_iter([ 165 | (TYPE_ID_NULL, Arc::new(Field::new("null", DataType::Null, true))), 166 | (TYPE_ID_BOOL, Arc::new(Field::new("bool", DataType::Boolean, false))), 167 | (TYPE_ID_INT, Arc::new(Field::new("int", DataType::Int64, false))), 168 | (TYPE_ID_FLOAT, Arc::new(Field::new("float", DataType::Float64, false))), 169 | (TYPE_ID_STR, Arc::new(Field::new("str", DataType::Utf8, false))), 170 | ( 171 | TYPE_ID_ARRAY, 172 | Arc::new(Field::new("array", DataType::Utf8, false).with_metadata(json_metadata.clone())), 173 | ), 174 | ( 175 | TYPE_ID_OBJECT, 176 | Arc::new(Field::new("object", DataType::Utf8, false).with_metadata(json_metadata.clone())), 177 | ), 178 | ]) 179 | }) 180 | .clone() 181 | } 182 | 183 | impl JsonUnionField { 184 | fn type_id(&self) -> i8 { 185 | match self { 186 | Self::JsonNull => TYPE_ID_NULL, 187 | Self::Bool(_) => TYPE_ID_BOOL, 188 | Self::Int(_) => TYPE_ID_INT, 189 | Self::Float(_) => TYPE_ID_FLOAT, 190 | Self::Str(_) => TYPE_ID_STR, 191 | Self::Array(_) => TYPE_ID_ARRAY, 192 | Self::Object(_) => TYPE_ID_OBJECT, 193 | } 194 | } 195 | 196 | pub fn scalar_value(f: Option) -> ScalarValue { 197 | ScalarValue::Union( 198 | f.map(|f| (f.type_id(), Box::new(f.into()))), 199 | union_fields(), 200 | UnionMode::Sparse, 201 | ) 202 | } 203 | } 204 | 205 | impl From for ScalarValue { 206 | fn from(value: JsonUnionField) -> Self { 207 | match value { 208 | JsonUnionField::JsonNull => Self::Null, 209 | JsonUnionField::Bool(b) => Self::Boolean(Some(b)), 210 | JsonUnionField::Int(i) => Self::Int64(Some(i)), 211 | JsonUnionField::Float(f) => Self::Float64(Some(f)), 212 | JsonUnionField::Str(s) | JsonUnionField::Array(s) | JsonUnionField::Object(s) => Self::Utf8(Some(s)), 213 | } 214 | } 215 | } 216 | 217 | pub struct JsonUnionEncoder { 218 | boolean: BooleanArray, 219 | int: Int64Array, 220 | float: Float64Array, 221 | string: StringArray, 222 | array: StringArray, 223 | object: StringArray, 224 | type_ids: ScalarBuffer, 225 | } 226 | 227 | impl JsonUnionEncoder { 228 | #[must_use] 229 | pub fn from_union(union: UnionArray) -> Option { 230 | if is_json_union(union.data_type()) { 231 | let (_, type_ids, _, c) = union.into_parts(); 232 | Some(Self { 233 | boolean: c[1].as_boolean().clone(), 234 | int: c[2].as_primitive().clone(), 235 | float: c[3].as_primitive().clone(), 236 | string: c[4].as_string().clone(), 237 | array: c[5].as_string().clone(), 238 | object: c[6].as_string().clone(), 239 | type_ids, 240 | }) 241 | } else { 242 | None 243 | } 244 | } 245 | 246 | #[must_use] 247 | #[allow(clippy::len_without_is_empty)] 248 | pub fn len(&self) -> usize { 249 | self.type_ids.len() 250 | } 251 | 252 | /// Get the encodable value for a given index 253 | /// 254 | /// # Panics 255 | /// 256 | /// Panics if the idx is outside the union values or an invalid type id exists in the union. 257 | #[must_use] 258 | pub fn get_value(&self, idx: usize) -> JsonUnionValue { 259 | let type_id = self.type_ids[idx]; 260 | match type_id { 261 | TYPE_ID_NULL => JsonUnionValue::JsonNull, 262 | TYPE_ID_BOOL => JsonUnionValue::Bool(self.boolean.value(idx)), 263 | TYPE_ID_INT => JsonUnionValue::Int(self.int.value(idx)), 264 | TYPE_ID_FLOAT => JsonUnionValue::Float(self.float.value(idx)), 265 | TYPE_ID_STR => JsonUnionValue::Str(self.string.value(idx)), 266 | TYPE_ID_ARRAY => JsonUnionValue::Array(self.array.value(idx)), 267 | TYPE_ID_OBJECT => JsonUnionValue::Object(self.object.value(idx)), 268 | _ => panic!("Invalid type_id: {type_id}, not a valid JSON type"), 269 | } 270 | } 271 | } 272 | 273 | #[derive(Debug, PartialEq)] 274 | pub enum JsonUnionValue<'a> { 275 | JsonNull, 276 | Bool(bool), 277 | Int(i64), 278 | Float(f64), 279 | Str(&'a str), 280 | Array(&'a str), 281 | Object(&'a str), 282 | } 283 | 284 | #[cfg(test)] 285 | mod test { 286 | use super::*; 287 | 288 | #[test] 289 | fn test_json_union() { 290 | let json_union = JsonUnion::from_iter(vec![ 291 | Some(JsonUnionField::JsonNull), 292 | Some(JsonUnionField::Bool(true)), 293 | Some(JsonUnionField::Bool(false)), 294 | Some(JsonUnionField::Int(42)), 295 | Some(JsonUnionField::Float(42.0)), 296 | Some(JsonUnionField::Str("foo".to_string())), 297 | Some(JsonUnionField::Array("[42]".to_string())), 298 | Some(JsonUnionField::Object(r#"{"foo": 42}"#.to_string())), 299 | None, 300 | ]); 301 | 302 | let union_array = UnionArray::try_from(json_union).unwrap(); 303 | let encoder = JsonUnionEncoder::from_union(union_array).unwrap(); 304 | 305 | let values_after: Vec<_> = (0..encoder.len()).map(|idx| encoder.get_value(idx)).collect(); 306 | assert_eq!( 307 | values_after, 308 | vec![ 309 | JsonUnionValue::JsonNull, 310 | JsonUnionValue::Bool(true), 311 | JsonUnionValue::Bool(false), 312 | JsonUnionValue::Int(42), 313 | JsonUnionValue::Float(42.0), 314 | JsonUnionValue::Str("foo"), 315 | JsonUnionValue::Array("[42]"), 316 | JsonUnionValue::Object(r#"{"foo": 42}"#), 317 | JsonUnionValue::JsonNull, 318 | ] 319 | ); 320 | } 321 | } 322 | -------------------------------------------------------------------------------- /src/json_as_text.rs: -------------------------------------------------------------------------------- 1 | use std::any::Any; 2 | use std::sync::Arc; 3 | 4 | use datafusion::arrow::array::{ArrayRef, StringArray, StringBuilder}; 5 | use datafusion::arrow::datatypes::DataType; 6 | use datafusion::common::{Result as DataFusionResult, ScalarValue}; 7 | use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; 8 | use jiter::Peek; 9 | 10 | use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath}; 11 | use crate::common_macros::make_udf_function; 12 | 13 | make_udf_function!( 14 | JsonAsText, 15 | json_as_text, 16 | json_data path, 17 | r#"Get any value from a JSON string by its "path", represented as a string"# 18 | ); 19 | 20 | #[derive(Debug)] 21 | pub(super) struct JsonAsText { 22 | signature: Signature, 23 | aliases: [String; 1], 24 | } 25 | 26 | impl Default for JsonAsText { 27 | fn default() -> Self { 28 | Self { 29 | signature: Signature::variadic_any(Volatility::Immutable), 30 | aliases: ["json_as_text".to_string()], 31 | } 32 | } 33 | } 34 | 35 | impl ScalarUDFImpl for JsonAsText { 36 | fn as_any(&self) -> &dyn Any { 37 | self 38 | } 39 | 40 | fn name(&self) -> &str { 41 | self.aliases[0].as_str() 42 | } 43 | 44 | fn signature(&self) -> &Signature { 45 | &self.signature 46 | } 47 | 48 | fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult { 49 | return_type_check(arg_types, self.name(), DataType::Utf8) 50 | } 51 | 52 | fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { 53 | invoke::(&args.args, jiter_json_as_text) 54 | } 55 | 56 | fn aliases(&self) -> &[String] { 57 | &self.aliases 58 | } 59 | } 60 | 61 | impl InvokeResult for StringArray { 62 | type Item = String; 63 | 64 | type Builder = StringBuilder; 65 | 66 | const ACCEPT_DICT_RETURN: bool = true; 67 | 68 | fn builder(capacity: usize) -> Self::Builder { 69 | StringBuilder::with_capacity(capacity, 0) 70 | } 71 | 72 | fn append_value(builder: &mut Self::Builder, value: Option) { 73 | builder.append_option(value); 74 | } 75 | 76 | fn finish(mut builder: Self::Builder) -> DataFusionResult { 77 | Ok(Arc::new(builder.finish())) 78 | } 79 | 80 | fn scalar(value: Option) -> ScalarValue { 81 | ScalarValue::Utf8(value) 82 | } 83 | } 84 | 85 | fn jiter_json_as_text(opt_json: Option<&str>, path: &[JsonPath]) -> Result { 86 | if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path) { 87 | match peek { 88 | Peek::Null => { 89 | jiter.known_null()?; 90 | get_err!() 91 | } 92 | Peek::String => Ok(jiter.known_str()?.to_owned()), 93 | _ => { 94 | let start = jiter.current_index(); 95 | jiter.known_skip(peek)?; 96 | let object_slice = jiter.slice_to_current(start); 97 | let object_string = std::str::from_utf8(object_slice)?; 98 | Ok(object_string.to_owned()) 99 | } 100 | } 101 | } else { 102 | get_err!() 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /src/json_contains.rs: -------------------------------------------------------------------------------- 1 | use std::any::Any; 2 | use std::sync::Arc; 3 | 4 | use datafusion::arrow::array::BooleanBuilder; 5 | use datafusion::arrow::datatypes::DataType; 6 | use datafusion::common::arrow::array::{ArrayRef, BooleanArray}; 7 | use datafusion::common::{plan_err, Result, ScalarValue}; 8 | use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; 9 | 10 | use crate::common::{invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath}; 11 | use crate::common_macros::make_udf_function; 12 | 13 | make_udf_function!( 14 | JsonContains, 15 | json_contains, 16 | json_data path, 17 | r#"Does the key/index exist within the JSON value as the specified "path"?"# 18 | ); 19 | 20 | #[derive(Debug)] 21 | pub(super) struct JsonContains { 22 | signature: Signature, 23 | aliases: [String; 1], 24 | } 25 | 26 | impl Default for JsonContains { 27 | fn default() -> Self { 28 | Self { 29 | signature: Signature::variadic_any(Volatility::Immutable), 30 | aliases: ["json_contains".to_string()], 31 | } 32 | } 33 | } 34 | 35 | impl ScalarUDFImpl for JsonContains { 36 | fn as_any(&self) -> &dyn Any { 37 | self 38 | } 39 | 40 | fn name(&self) -> &str { 41 | self.aliases[0].as_str() 42 | } 43 | 44 | fn signature(&self) -> &Signature { 45 | &self.signature 46 | } 47 | 48 | fn return_type(&self, arg_types: &[DataType]) -> Result { 49 | if arg_types.len() < 2 { 50 | plan_err!("The 'json_contains' function requires two or more arguments.") 51 | } else { 52 | return_type_check(arg_types, self.name(), DataType::Boolean).map(|_| DataType::Boolean) 53 | } 54 | } 55 | 56 | fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { 57 | invoke::(&args.args, jiter_json_contains) 58 | } 59 | 60 | fn aliases(&self) -> &[String] { 61 | &self.aliases 62 | } 63 | } 64 | 65 | impl InvokeResult for BooleanArray { 66 | type Item = bool; 67 | 68 | type Builder = BooleanBuilder; 69 | 70 | // Using boolean inside a dictionary is not an optimization! 71 | const ACCEPT_DICT_RETURN: bool = false; 72 | 73 | fn builder(capacity: usize) -> Self::Builder { 74 | BooleanBuilder::with_capacity(capacity) 75 | } 76 | 77 | fn append_value(builder: &mut Self::Builder, value: Option) { 78 | builder.append_option(value); 79 | } 80 | 81 | fn finish(mut builder: Self::Builder) -> Result { 82 | Ok(Arc::new(builder.finish())) 83 | } 84 | 85 | fn scalar(value: Option) -> ScalarValue { 86 | ScalarValue::Boolean(value) 87 | } 88 | } 89 | 90 | #[allow(clippy::unnecessary_wraps)] 91 | fn jiter_json_contains(json_data: Option<&str>, path: &[JsonPath]) -> Result { 92 | Ok(jiter_json_find(json_data, path).is_some()) 93 | } 94 | -------------------------------------------------------------------------------- /src/json_get.rs: -------------------------------------------------------------------------------- 1 | use std::any::Any; 2 | use std::sync::Arc; 3 | 4 | use datafusion::arrow::array::ArrayRef; 5 | use datafusion::arrow::array::UnionArray; 6 | use datafusion::arrow::datatypes::DataType; 7 | use datafusion::common::Result as DataFusionResult; 8 | use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; 9 | use datafusion::scalar::ScalarValue; 10 | use jiter::{Jiter, NumberAny, NumberInt, Peek}; 11 | 12 | use crate::common::InvokeResult; 13 | use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath}; 14 | use crate::common_macros::make_udf_function; 15 | use crate::common_union::{JsonUnion, JsonUnionField}; 16 | 17 | make_udf_function!( 18 | JsonGet, 19 | json_get, 20 | json_data path, 21 | r#"Get a value from a JSON string by its "path""# 22 | ); 23 | 24 | // build_typed_get!(JsonGet, "json_get", Union, Float64Array, jiter_json_get_float); 25 | 26 | #[derive(Debug)] 27 | pub(super) struct JsonGet { 28 | signature: Signature, 29 | aliases: [String; 1], 30 | } 31 | 32 | impl Default for JsonGet { 33 | fn default() -> Self { 34 | Self { 35 | signature: Signature::variadic_any(Volatility::Immutable), 36 | aliases: ["json_get".to_string()], 37 | } 38 | } 39 | } 40 | 41 | impl ScalarUDFImpl for JsonGet { 42 | fn as_any(&self) -> &dyn Any { 43 | self 44 | } 45 | 46 | fn name(&self) -> &str { 47 | self.aliases[0].as_str() 48 | } 49 | 50 | fn signature(&self) -> &Signature { 51 | &self.signature 52 | } 53 | 54 | fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult { 55 | return_type_check(arg_types, self.name(), JsonUnion::data_type()) 56 | } 57 | 58 | fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { 59 | invoke::(&args.args, jiter_json_get_union) 60 | } 61 | 62 | fn aliases(&self) -> &[String] { 63 | &self.aliases 64 | } 65 | } 66 | 67 | impl InvokeResult for JsonUnion { 68 | type Item = JsonUnionField; 69 | 70 | type Builder = JsonUnion; 71 | 72 | const ACCEPT_DICT_RETURN: bool = true; 73 | 74 | fn builder(capacity: usize) -> Self::Builder { 75 | JsonUnion::new(capacity) 76 | } 77 | 78 | fn append_value(builder: &mut Self::Builder, value: Option) { 79 | if let Some(value) = value { 80 | builder.push(value); 81 | } else { 82 | builder.push_none(); 83 | } 84 | } 85 | 86 | fn finish(builder: Self::Builder) -> DataFusionResult { 87 | let array: UnionArray = builder.try_into()?; 88 | Ok(Arc::new(array) as ArrayRef) 89 | } 90 | 91 | fn scalar(value: Option) -> ScalarValue { 92 | JsonUnionField::scalar_value(value) 93 | } 94 | } 95 | 96 | fn jiter_json_get_union(opt_json: Option<&str>, path: &[JsonPath]) -> Result { 97 | if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path) { 98 | build_union(&mut jiter, peek) 99 | } else { 100 | get_err!() 101 | } 102 | } 103 | 104 | fn build_union(jiter: &mut Jiter, peek: Peek) -> Result { 105 | match peek { 106 | Peek::Null => { 107 | jiter.known_null()?; 108 | Ok(JsonUnionField::JsonNull) 109 | } 110 | Peek::True | Peek::False => { 111 | let value = jiter.known_bool(peek)?; 112 | Ok(JsonUnionField::Bool(value)) 113 | } 114 | Peek::String => { 115 | let value = jiter.known_str()?; 116 | Ok(JsonUnionField::Str(value.to_owned())) 117 | } 118 | Peek::Array => { 119 | let start = jiter.current_index(); 120 | jiter.known_skip(peek)?; 121 | let array_slice = jiter.slice_to_current(start); 122 | let array_string = std::str::from_utf8(array_slice)?; 123 | Ok(JsonUnionField::Array(array_string.to_owned())) 124 | } 125 | Peek::Object => { 126 | let start = jiter.current_index(); 127 | jiter.known_skip(peek)?; 128 | let object_slice = jiter.slice_to_current(start); 129 | let object_string = std::str::from_utf8(object_slice)?; 130 | Ok(JsonUnionField::Object(object_string.to_owned())) 131 | } 132 | _ => match jiter.known_number(peek)? { 133 | NumberAny::Int(NumberInt::Int(value)) => Ok(JsonUnionField::Int(value)), 134 | NumberAny::Int(NumberInt::BigInt(_)) => todo!("BigInt not supported yet"), 135 | NumberAny::Float(value) => Ok(JsonUnionField::Float(value)), 136 | }, 137 | } 138 | } 139 | -------------------------------------------------------------------------------- /src/json_get_bool.rs: -------------------------------------------------------------------------------- 1 | use std::any::Any; 2 | 3 | use datafusion::arrow::array::BooleanArray; 4 | use datafusion::arrow::datatypes::DataType; 5 | use datafusion::common::Result as DataFusionResult; 6 | use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; 7 | use jiter::Peek; 8 | 9 | use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath}; 10 | use crate::common_macros::make_udf_function; 11 | 12 | make_udf_function!( 13 | JsonGetBool, 14 | json_get_bool, 15 | json_data path, 16 | r#"Get an boolean value from a JSON string by its "path""# 17 | ); 18 | 19 | #[derive(Debug)] 20 | pub(super) struct JsonGetBool { 21 | signature: Signature, 22 | aliases: [String; 1], 23 | } 24 | 25 | impl Default for JsonGetBool { 26 | fn default() -> Self { 27 | Self { 28 | signature: Signature::variadic_any(Volatility::Immutable), 29 | aliases: ["json_get_bool".to_string()], 30 | } 31 | } 32 | } 33 | 34 | impl ScalarUDFImpl for JsonGetBool { 35 | fn as_any(&self) -> &dyn Any { 36 | self 37 | } 38 | 39 | fn name(&self) -> &str { 40 | self.aliases[0].as_str() 41 | } 42 | 43 | fn signature(&self) -> &Signature { 44 | &self.signature 45 | } 46 | 47 | fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult { 48 | return_type_check(arg_types, self.name(), DataType::Boolean).map(|_| DataType::Boolean) 49 | } 50 | 51 | fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { 52 | invoke::(&args.args, jiter_json_get_bool) 53 | } 54 | 55 | fn aliases(&self) -> &[String] { 56 | &self.aliases 57 | } 58 | } 59 | 60 | fn jiter_json_get_bool(json_data: Option<&str>, path: &[JsonPath]) -> Result { 61 | if let Some((mut jiter, peek)) = jiter_json_find(json_data, path) { 62 | match peek { 63 | Peek::True | Peek::False => Ok(jiter.known_bool(peek)?), 64 | _ => get_err!(), 65 | } 66 | } else { 67 | get_err!() 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /src/json_get_float.rs: -------------------------------------------------------------------------------- 1 | use std::any::Any; 2 | use std::sync::Arc; 3 | 4 | use datafusion::arrow::array::{ArrayRef, Float64Array, Float64Builder}; 5 | use datafusion::arrow::datatypes::DataType; 6 | use datafusion::common::{Result as DataFusionResult, ScalarValue}; 7 | use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; 8 | use jiter::{NumberAny, Peek}; 9 | 10 | use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath}; 11 | use crate::common_macros::make_udf_function; 12 | 13 | make_udf_function!( 14 | JsonGetFloat, 15 | json_get_float, 16 | json_data path, 17 | r#"Get a float value from a JSON string by its "path""# 18 | ); 19 | 20 | #[derive(Debug)] 21 | pub(super) struct JsonGetFloat { 22 | signature: Signature, 23 | aliases: [String; 1], 24 | } 25 | 26 | impl Default for JsonGetFloat { 27 | fn default() -> Self { 28 | Self { 29 | signature: Signature::variadic_any(Volatility::Immutable), 30 | aliases: ["json_get_float".to_string()], 31 | } 32 | } 33 | } 34 | 35 | impl ScalarUDFImpl for JsonGetFloat { 36 | fn as_any(&self) -> &dyn Any { 37 | self 38 | } 39 | 40 | fn name(&self) -> &str { 41 | self.aliases[0].as_str() 42 | } 43 | 44 | fn signature(&self) -> &Signature { 45 | &self.signature 46 | } 47 | 48 | fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult { 49 | return_type_check(arg_types, self.name(), DataType::Float64) 50 | } 51 | 52 | fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { 53 | invoke::(&args.args, jiter_json_get_float) 54 | } 55 | 56 | fn aliases(&self) -> &[String] { 57 | &self.aliases 58 | } 59 | } 60 | 61 | impl InvokeResult for Float64Array { 62 | type Item = f64; 63 | 64 | type Builder = Float64Builder; 65 | 66 | // Cheaper to produce a float array rather than dict-encoded floats 67 | const ACCEPT_DICT_RETURN: bool = false; 68 | 69 | fn builder(capacity: usize) -> Self::Builder { 70 | Float64Builder::with_capacity(capacity) 71 | } 72 | 73 | fn append_value(builder: &mut Self::Builder, value: Option) { 74 | builder.append_option(value); 75 | } 76 | 77 | fn finish(mut builder: Self::Builder) -> DataFusionResult { 78 | Ok(Arc::new(builder.finish())) 79 | } 80 | 81 | fn scalar(value: Option) -> ScalarValue { 82 | ScalarValue::Float64(value) 83 | } 84 | } 85 | 86 | fn jiter_json_get_float(json_data: Option<&str>, path: &[JsonPath]) -> Result { 87 | if let Some((mut jiter, peek)) = jiter_json_find(json_data, path) { 88 | match peek { 89 | // numbers are represented by everything else in peek, hence doing it this way 90 | Peek::Null 91 | | Peek::True 92 | | Peek::False 93 | | Peek::Minus 94 | | Peek::Infinity 95 | | Peek::NaN 96 | | Peek::String 97 | | Peek::Array 98 | | Peek::Object => get_err!(), 99 | _ => match jiter.known_number(peek)? { 100 | NumberAny::Float(f) => Ok(f), 101 | NumberAny::Int(int) => Ok(int.into()), 102 | }, 103 | } 104 | } else { 105 | get_err!() 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /src/json_get_int.rs: -------------------------------------------------------------------------------- 1 | use std::any::Any; 2 | use std::sync::Arc; 3 | 4 | use datafusion::arrow::array::{ArrayRef, Int64Array, Int64Builder}; 5 | use datafusion::arrow::datatypes::DataType; 6 | use datafusion::common::{Result as DataFusionResult, ScalarValue}; 7 | use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; 8 | use jiter::{NumberInt, Peek}; 9 | 10 | use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath}; 11 | use crate::common_macros::make_udf_function; 12 | 13 | make_udf_function!( 14 | JsonGetInt, 15 | json_get_int, 16 | json_data path, 17 | r#"Get an integer value from a JSON string by its "path""# 18 | ); 19 | 20 | #[derive(Debug)] 21 | pub(super) struct JsonGetInt { 22 | signature: Signature, 23 | aliases: [String; 1], 24 | } 25 | 26 | impl Default for JsonGetInt { 27 | fn default() -> Self { 28 | Self { 29 | signature: Signature::variadic_any(Volatility::Immutable), 30 | aliases: ["json_get_int".to_string()], 31 | } 32 | } 33 | } 34 | 35 | impl ScalarUDFImpl for JsonGetInt { 36 | fn as_any(&self) -> &dyn Any { 37 | self 38 | } 39 | 40 | fn name(&self) -> &str { 41 | self.aliases[0].as_str() 42 | } 43 | 44 | fn signature(&self) -> &Signature { 45 | &self.signature 46 | } 47 | 48 | fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult { 49 | return_type_check(arg_types, self.name(), DataType::Int64) 50 | } 51 | 52 | fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { 53 | invoke::(&args.args, jiter_json_get_int) 54 | } 55 | 56 | fn aliases(&self) -> &[String] { 57 | &self.aliases 58 | } 59 | } 60 | 61 | impl InvokeResult for Int64Array { 62 | type Item = i64; 63 | 64 | type Builder = Int64Builder; 65 | 66 | // Cheaper to return an int array rather than dict-encoded ints 67 | const ACCEPT_DICT_RETURN: bool = false; 68 | 69 | fn builder(capacity: usize) -> Self::Builder { 70 | Int64Builder::with_capacity(capacity) 71 | } 72 | 73 | fn append_value(builder: &mut Self::Builder, value: Option) { 74 | builder.append_option(value); 75 | } 76 | 77 | fn finish(mut builder: Self::Builder) -> DataFusionResult { 78 | Ok(Arc::new(builder.finish())) 79 | } 80 | 81 | fn scalar(value: Option) -> ScalarValue { 82 | ScalarValue::Int64(value) 83 | } 84 | } 85 | 86 | fn jiter_json_get_int(json_data: Option<&str>, path: &[JsonPath]) -> Result { 87 | if let Some((mut jiter, peek)) = jiter_json_find(json_data, path) { 88 | match peek { 89 | // numbers are represented by everything else in peek, hence doing it this way 90 | Peek::Null 91 | | Peek::True 92 | | Peek::False 93 | | Peek::Minus 94 | | Peek::Infinity 95 | | Peek::NaN 96 | | Peek::String 97 | | Peek::Array 98 | | Peek::Object => get_err!(), 99 | _ => match jiter.known_int(peek)? { 100 | NumberInt::Int(i) => Ok(i), 101 | NumberInt::BigInt(_) => get_err!(), 102 | }, 103 | } 104 | } else { 105 | get_err!() 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /src/json_get_json.rs: -------------------------------------------------------------------------------- 1 | use std::any::Any; 2 | 3 | use datafusion::arrow::array::StringArray; 4 | use datafusion::arrow::datatypes::DataType; 5 | use datafusion::common::Result as DataFusionResult; 6 | use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; 7 | 8 | use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath}; 9 | use crate::common_macros::make_udf_function; 10 | 11 | make_udf_function!( 12 | JsonGetJson, 13 | json_get_json, 14 | json_data path, 15 | r#"Get a nested raw JSON string from a JSON string by its "path""# 16 | ); 17 | 18 | #[derive(Debug)] 19 | pub(super) struct JsonGetJson { 20 | signature: Signature, 21 | aliases: [String; 1], 22 | } 23 | 24 | impl Default for JsonGetJson { 25 | fn default() -> Self { 26 | Self { 27 | signature: Signature::variadic_any(Volatility::Immutable), 28 | aliases: ["json_get_json".to_string()], 29 | } 30 | } 31 | } 32 | 33 | impl ScalarUDFImpl for JsonGetJson { 34 | fn as_any(&self) -> &dyn Any { 35 | self 36 | } 37 | 38 | fn name(&self) -> &str { 39 | self.aliases[0].as_str() 40 | } 41 | 42 | fn signature(&self) -> &Signature { 43 | &self.signature 44 | } 45 | 46 | fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult { 47 | return_type_check(arg_types, self.name(), DataType::Utf8) 48 | } 49 | 50 | fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { 51 | invoke::(&args.args, jiter_json_get_json) 52 | } 53 | 54 | fn aliases(&self) -> &[String] { 55 | &self.aliases 56 | } 57 | } 58 | 59 | fn jiter_json_get_json(opt_json: Option<&str>, path: &[JsonPath]) -> Result { 60 | if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path) { 61 | let start = jiter.current_index(); 62 | jiter.known_skip(peek)?; 63 | let object_slice = jiter.slice_to_current(start); 64 | let object_string = std::str::from_utf8(object_slice)?; 65 | Ok(object_string.to_owned()) 66 | } else { 67 | get_err!() 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /src/json_get_str.rs: -------------------------------------------------------------------------------- 1 | use std::any::Any; 2 | 3 | use datafusion::arrow::array::StringArray; 4 | use datafusion::arrow::datatypes::DataType; 5 | use datafusion::common::Result as DataFusionResult; 6 | use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; 7 | use jiter::Peek; 8 | 9 | use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath}; 10 | use crate::common_macros::make_udf_function; 11 | 12 | make_udf_function!( 13 | JsonGetStr, 14 | json_get_str, 15 | json_data path, 16 | r#"Get a string value from a JSON string by its "path""# 17 | ); 18 | 19 | #[derive(Debug)] 20 | pub(super) struct JsonGetStr { 21 | signature: Signature, 22 | aliases: [String; 1], 23 | } 24 | 25 | impl Default for JsonGetStr { 26 | fn default() -> Self { 27 | Self { 28 | signature: Signature::variadic_any(Volatility::Immutable), 29 | aliases: ["json_get_str".to_string()], 30 | } 31 | } 32 | } 33 | 34 | impl ScalarUDFImpl for JsonGetStr { 35 | fn as_any(&self) -> &dyn Any { 36 | self 37 | } 38 | 39 | fn name(&self) -> &str { 40 | self.aliases[0].as_str() 41 | } 42 | 43 | fn signature(&self) -> &Signature { 44 | &self.signature 45 | } 46 | 47 | fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult { 48 | return_type_check(arg_types, self.name(), DataType::Utf8) 49 | } 50 | 51 | fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { 52 | invoke::(&args.args, jiter_json_get_str) 53 | } 54 | 55 | fn aliases(&self) -> &[String] { 56 | &self.aliases 57 | } 58 | } 59 | 60 | fn jiter_json_get_str(json_data: Option<&str>, path: &[JsonPath]) -> Result { 61 | if let Some((mut jiter, peek)) = jiter_json_find(json_data, path) { 62 | match peek { 63 | Peek::String => Ok(jiter.known_str()?.to_owned()), 64 | _ => get_err!(), 65 | } 66 | } else { 67 | get_err!() 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /src/json_length.rs: -------------------------------------------------------------------------------- 1 | use std::any::Any; 2 | use std::sync::Arc; 3 | 4 | use datafusion::arrow::array::{ArrayRef, UInt64Array, UInt64Builder}; 5 | use datafusion::arrow::datatypes::DataType; 6 | use datafusion::common::{Result as DataFusionResult, ScalarValue}; 7 | use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; 8 | use jiter::Peek; 9 | 10 | use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath}; 11 | use crate::common_macros::make_udf_function; 12 | 13 | make_udf_function!( 14 | JsonLength, 15 | json_length, 16 | json_data path, 17 | r"Get the length of the array or object at the given path." 18 | ); 19 | 20 | #[derive(Debug)] 21 | pub(super) struct JsonLength { 22 | signature: Signature, 23 | aliases: [String; 2], 24 | } 25 | 26 | impl Default for JsonLength { 27 | fn default() -> Self { 28 | Self { 29 | signature: Signature::variadic_any(Volatility::Immutable), 30 | aliases: ["json_length".to_string(), "json_len".to_string()], 31 | } 32 | } 33 | } 34 | 35 | impl ScalarUDFImpl for JsonLength { 36 | fn as_any(&self) -> &dyn Any { 37 | self 38 | } 39 | 40 | fn name(&self) -> &str { 41 | self.aliases[0].as_str() 42 | } 43 | 44 | fn signature(&self) -> &Signature { 45 | &self.signature 46 | } 47 | 48 | fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult { 49 | return_type_check(arg_types, self.name(), DataType::UInt64) 50 | } 51 | 52 | fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { 53 | invoke::(&args.args, jiter_json_length) 54 | } 55 | 56 | fn aliases(&self) -> &[String] { 57 | &self.aliases 58 | } 59 | } 60 | 61 | impl InvokeResult for UInt64Array { 62 | type Item = u64; 63 | 64 | type Builder = UInt64Builder; 65 | 66 | // cheaper to return integers without dict-encoding them 67 | const ACCEPT_DICT_RETURN: bool = false; 68 | 69 | fn builder(capacity: usize) -> Self::Builder { 70 | UInt64Builder::with_capacity(capacity) 71 | } 72 | 73 | fn append_value(builder: &mut Self::Builder, value: Option) { 74 | builder.append_option(value); 75 | } 76 | 77 | fn finish(mut builder: Self::Builder) -> DataFusionResult { 78 | Ok(Arc::new(builder.finish())) 79 | } 80 | 81 | fn scalar(value: Option) -> ScalarValue { 82 | ScalarValue::UInt64(value) 83 | } 84 | } 85 | 86 | fn jiter_json_length(opt_json: Option<&str>, path: &[JsonPath]) -> Result { 87 | if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path) { 88 | match peek { 89 | Peek::Array => { 90 | let mut peek_opt = jiter.known_array()?; 91 | let mut length: u64 = 0; 92 | while let Some(peek) = peek_opt { 93 | jiter.known_skip(peek)?; 94 | length += 1; 95 | peek_opt = jiter.array_step()?; 96 | } 97 | Ok(length) 98 | } 99 | Peek::Object => { 100 | let mut opt_key = jiter.known_object()?; 101 | 102 | let mut length: u64 = 0; 103 | while opt_key.is_some() { 104 | jiter.next_skip()?; 105 | length += 1; 106 | opt_key = jiter.next_key()?; 107 | } 108 | Ok(length) 109 | } 110 | _ => get_err!(), 111 | } 112 | } else { 113 | get_err!() 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /src/json_object_keys.rs: -------------------------------------------------------------------------------- 1 | use std::any::Any; 2 | use std::sync::Arc; 3 | 4 | use datafusion::arrow::array::{ArrayRef, ListBuilder, StringBuilder}; 5 | use datafusion::arrow::datatypes::{DataType, Field}; 6 | use datafusion::common::{Result as DataFusionResult, ScalarValue}; 7 | use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; 8 | use jiter::Peek; 9 | 10 | use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath}; 11 | use crate::common_macros::make_udf_function; 12 | 13 | make_udf_function!( 14 | JsonObjectKeys, 15 | json_object_keys, 16 | json_data path, 17 | r"Get the keys of a JSON object as an array." 18 | ); 19 | 20 | #[derive(Debug)] 21 | pub(super) struct JsonObjectKeys { 22 | signature: Signature, 23 | aliases: [String; 2], 24 | } 25 | 26 | impl Default for JsonObjectKeys { 27 | fn default() -> Self { 28 | Self { 29 | signature: Signature::variadic_any(Volatility::Immutable), 30 | aliases: ["json_object_keys".to_string(), "json_keys".to_string()], 31 | } 32 | } 33 | } 34 | 35 | impl ScalarUDFImpl for JsonObjectKeys { 36 | fn as_any(&self) -> &dyn Any { 37 | self 38 | } 39 | 40 | fn name(&self) -> &str { 41 | self.aliases[0].as_str() 42 | } 43 | 44 | fn signature(&self) -> &Signature { 45 | &self.signature 46 | } 47 | 48 | fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult { 49 | return_type_check( 50 | arg_types, 51 | self.name(), 52 | DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))), 53 | ) 54 | } 55 | 56 | fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { 57 | invoke::(&args.args, jiter_json_object_keys) 58 | } 59 | 60 | fn aliases(&self) -> &[String] { 61 | &self.aliases 62 | } 63 | } 64 | 65 | /// Struct used to build a `ListArray` from the result of `jiter_json_object_keys`. 66 | #[derive(Debug)] 67 | struct BuildListArray; 68 | 69 | impl InvokeResult for BuildListArray { 70 | type Item = Vec; 71 | 72 | type Builder = ListBuilder; 73 | 74 | const ACCEPT_DICT_RETURN: bool = true; 75 | 76 | fn builder(capacity: usize) -> Self::Builder { 77 | let values_builder = StringBuilder::new(); 78 | ListBuilder::with_capacity(values_builder, capacity) 79 | } 80 | 81 | fn append_value(builder: &mut Self::Builder, value: Option) { 82 | builder.append_option(value.map(|v| v.into_iter().map(Some))); 83 | } 84 | 85 | fn finish(mut builder: Self::Builder) -> DataFusionResult { 86 | Ok(Arc::new(builder.finish())) 87 | } 88 | 89 | fn scalar(value: Option) -> ScalarValue { 90 | keys_to_scalar(value) 91 | } 92 | } 93 | 94 | fn keys_to_scalar(opt_keys: Option>) -> ScalarValue { 95 | let values_builder = StringBuilder::new(); 96 | let mut builder = ListBuilder::new(values_builder); 97 | if let Some(keys) = opt_keys { 98 | for value in keys { 99 | builder.values().append_value(value); 100 | } 101 | builder.append(true); 102 | } else { 103 | builder.append(false); 104 | } 105 | let array = builder.finish(); 106 | ScalarValue::List(Arc::new(array)) 107 | } 108 | 109 | fn jiter_json_object_keys(opt_json: Option<&str>, path: &[JsonPath]) -> Result, GetError> { 110 | if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path) { 111 | match peek { 112 | Peek::Object => { 113 | let mut opt_key = jiter.known_object()?; 114 | 115 | let mut keys = Vec::new(); 116 | while let Some(key) = opt_key { 117 | keys.push(key.to_string()); 118 | jiter.next_skip()?; 119 | opt_key = jiter.next_key()?; 120 | } 121 | Ok(keys) 122 | } 123 | _ => get_err!(), 124 | } 125 | } else { 126 | get_err!() 127 | } 128 | } 129 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | use log::debug; 2 | use std::sync::Arc; 3 | 4 | use datafusion::common::Result; 5 | use datafusion::execution::FunctionRegistry; 6 | use datafusion::logical_expr::ScalarUDF; 7 | 8 | mod common; 9 | mod common_macros; 10 | mod common_union; 11 | mod json_as_text; 12 | mod json_contains; 13 | mod json_get; 14 | mod json_get_bool; 15 | mod json_get_float; 16 | mod json_get_int; 17 | mod json_get_json; 18 | mod json_get_str; 19 | mod json_length; 20 | mod json_object_keys; 21 | mod rewrite; 22 | 23 | pub use common_union::{JsonUnionEncoder, JsonUnionValue}; 24 | 25 | pub mod functions { 26 | pub use crate::json_as_text::json_as_text; 27 | pub use crate::json_contains::json_contains; 28 | pub use crate::json_get::json_get; 29 | pub use crate::json_get_bool::json_get_bool; 30 | pub use crate::json_get_float::json_get_float; 31 | pub use crate::json_get_int::json_get_int; 32 | pub use crate::json_get_json::json_get_json; 33 | pub use crate::json_get_str::json_get_str; 34 | pub use crate::json_length::json_length; 35 | pub use crate::json_object_keys::json_object_keys; 36 | } 37 | 38 | pub mod udfs { 39 | pub use crate::json_as_text::json_as_text_udf; 40 | pub use crate::json_contains::json_contains_udf; 41 | pub use crate::json_get::json_get_udf; 42 | pub use crate::json_get_bool::json_get_bool_udf; 43 | pub use crate::json_get_float::json_get_float_udf; 44 | pub use crate::json_get_int::json_get_int_udf; 45 | pub use crate::json_get_json::json_get_json_udf; 46 | pub use crate::json_get_str::json_get_str_udf; 47 | pub use crate::json_length::json_length_udf; 48 | pub use crate::json_object_keys::json_object_keys_udf; 49 | } 50 | 51 | /// Register all JSON UDFs, and [`rewrite::JsonFunctionRewriter`] with the provided [`FunctionRegistry`]. 52 | /// 53 | /// # Arguments 54 | /// 55 | /// * `registry`: `FunctionRegistry` to register the UDFs 56 | /// 57 | /// # Errors 58 | /// 59 | /// Returns an error if the UDFs cannot be registered or if the rewriter cannot be registered. 60 | pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { 61 | let functions: Vec> = vec![ 62 | json_get::json_get_udf(), 63 | json_get_bool::json_get_bool_udf(), 64 | json_get_float::json_get_float_udf(), 65 | json_get_int::json_get_int_udf(), 66 | json_get_json::json_get_json_udf(), 67 | json_as_text::json_as_text_udf(), 68 | json_get_str::json_get_str_udf(), 69 | json_contains::json_contains_udf(), 70 | json_length::json_length_udf(), 71 | json_object_keys::json_object_keys_udf(), 72 | ]; 73 | functions.into_iter().try_for_each(|udf| { 74 | let existing_udf = registry.register_udf(udf)?; 75 | if let Some(existing_udf) = existing_udf { 76 | debug!("Overwrite existing UDF: {}", existing_udf.name()); 77 | } 78 | Ok(()) as Result<()> 79 | })?; 80 | registry.register_function_rewrite(Arc::new(rewrite::JsonFunctionRewriter))?; 81 | registry.register_expr_planner(Arc::new(rewrite::JsonExprPlanner))?; 82 | 83 | Ok(()) 84 | } 85 | -------------------------------------------------------------------------------- /src/rewrite.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use datafusion::arrow::datatypes::DataType; 4 | use datafusion::common::config::ConfigOptions; 5 | use datafusion::common::tree_node::Transformed; 6 | use datafusion::common::Column; 7 | use datafusion::common::DFSchema; 8 | use datafusion::common::Result; 9 | use datafusion::logical_expr::expr::{Alias, Cast, Expr, ScalarFunction}; 10 | use datafusion::logical_expr::expr_rewriter::FunctionRewrite; 11 | use datafusion::logical_expr::planner::{ExprPlanner, PlannerResult, RawBinaryExpr}; 12 | use datafusion::logical_expr::sqlparser::ast::BinaryOperator; 13 | use datafusion::logical_expr::ScalarUDF; 14 | use datafusion::scalar::ScalarValue; 15 | 16 | #[derive(Debug)] 17 | pub(crate) struct JsonFunctionRewriter; 18 | 19 | impl FunctionRewrite for JsonFunctionRewriter { 20 | fn name(&self) -> &'static str { 21 | "JsonFunctionRewriter" 22 | } 23 | 24 | fn rewrite(&self, expr: Expr, _schema: &DFSchema, _config: &ConfigOptions) -> Result> { 25 | let transform = match &expr { 26 | Expr::Cast(cast) => optimise_json_get_cast(cast), 27 | Expr::ScalarFunction(func) => unnest_json_calls(func), 28 | _ => None, 29 | }; 30 | Ok(transform.unwrap_or_else(|| Transformed::no(expr))) 31 | } 32 | } 33 | 34 | /// This replaces `get_json(foo, bar)::int` with `json_get_int(foo, bar)` so the JSON function can take care of 35 | /// extracting the right value type from JSON without the need to materialize the JSON union. 36 | fn optimise_json_get_cast(cast: &Cast) -> Option> { 37 | let scalar_func = extract_scalar_function(&cast.expr)?; 38 | if scalar_func.func.name() != "json_get" { 39 | return None; 40 | } 41 | let func = match &cast.data_type { 42 | DataType::Boolean => crate::json_get_bool::json_get_bool_udf(), 43 | DataType::Float64 | DataType::Float32 | DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => { 44 | crate::json_get_float::json_get_float_udf() 45 | } 46 | DataType::Int64 | DataType::Int32 => crate::json_get_int::json_get_int_udf(), 47 | DataType::Utf8 => crate::json_get_str::json_get_str_udf(), 48 | _ => return None, 49 | }; 50 | Some(Transformed::yes(Expr::ScalarFunction(ScalarFunction { 51 | func, 52 | args: scalar_func.args.clone(), 53 | }))) 54 | } 55 | 56 | // Replace nested JSON functions e.g. `json_get(json_get(col, 'foo'), 'bar')` with `json_get(col, 'foo', 'bar')` 57 | fn unnest_json_calls(func: &ScalarFunction) -> Option> { 58 | if !matches!( 59 | func.func.name(), 60 | "json_get" 61 | | "json_get_bool" 62 | | "json_get_float" 63 | | "json_get_int" 64 | | "json_get_json" 65 | | "json_get_str" 66 | | "json_as_text" 67 | ) { 68 | return None; 69 | } 70 | let mut outer_args_iter = func.args.iter(); 71 | let first_arg = outer_args_iter.next()?; 72 | let inner_func = extract_scalar_function(first_arg)?; 73 | 74 | // both json_get and json_as_text would produce new JSON to be processed by the outer 75 | // function so can be inlined 76 | if !matches!(inner_func.func.name(), "json_get" | "json_as_text") { 77 | return None; 78 | } 79 | 80 | let mut args = inner_func.args.clone(); 81 | args.extend(outer_args_iter.cloned()); 82 | // See #23, unnest only when all lookup arguments are literals 83 | if args.iter().skip(1).all(|arg| matches!(arg, Expr::Literal(_))) { 84 | Some(Transformed::yes(Expr::ScalarFunction(ScalarFunction { 85 | func: func.func.clone(), 86 | args, 87 | }))) 88 | } else { 89 | None 90 | } 91 | } 92 | 93 | fn extract_scalar_function(expr: &Expr) -> Option<&ScalarFunction> { 94 | match expr { 95 | Expr::ScalarFunction(func) => Some(func), 96 | Expr::Alias(alias) => extract_scalar_function(&alias.expr), 97 | _ => None, 98 | } 99 | } 100 | 101 | #[derive(Debug, Clone, Copy)] 102 | enum JsonOperator { 103 | Arrow, 104 | LongArrow, 105 | Question, 106 | } 107 | 108 | impl TryFrom<&BinaryOperator> for JsonOperator { 109 | type Error = (); 110 | 111 | fn try_from(op: &BinaryOperator) -> Result { 112 | match op { 113 | BinaryOperator::Arrow => Ok(JsonOperator::Arrow), 114 | BinaryOperator::LongArrow => Ok(JsonOperator::LongArrow), 115 | BinaryOperator::Question => Ok(JsonOperator::Question), 116 | _ => Err(()), 117 | } 118 | } 119 | } 120 | 121 | impl From for Arc { 122 | fn from(op: JsonOperator) -> Arc { 123 | match op { 124 | JsonOperator::Arrow => crate::udfs::json_get_udf(), 125 | JsonOperator::LongArrow => crate::udfs::json_as_text_udf(), 126 | JsonOperator::Question => crate::udfs::json_contains_udf(), 127 | } 128 | } 129 | } 130 | 131 | impl std::fmt::Display for JsonOperator { 132 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 133 | match self { 134 | JsonOperator::Arrow => write!(f, "->"), 135 | JsonOperator::LongArrow => write!(f, "->>"), 136 | JsonOperator::Question => write!(f, "?"), 137 | } 138 | } 139 | } 140 | 141 | /// Convert an Expr to a String representatiion for use in alias names. 142 | fn expr_to_sql_repr(expr: &Expr) -> String { 143 | match expr { 144 | Expr::Column(Column { 145 | name, 146 | relation, 147 | spans: _, 148 | }) => relation 149 | .as_ref() 150 | .map_or_else(|| name.clone(), |r| format!("{r}.{name}")), 151 | Expr::Alias(alias) => alias.name.clone(), 152 | Expr::Literal(scalar) => match scalar { 153 | ScalarValue::Utf8(Some(v)) | ScalarValue::Utf8View(Some(v)) | ScalarValue::LargeUtf8(Some(v)) => { 154 | format!("'{v}'") 155 | } 156 | ScalarValue::UInt8(Some(v)) => v.to_string(), 157 | ScalarValue::UInt16(Some(v)) => v.to_string(), 158 | ScalarValue::UInt32(Some(v)) => v.to_string(), 159 | ScalarValue::UInt64(Some(v)) => v.to_string(), 160 | ScalarValue::Int8(Some(v)) => v.to_string(), 161 | ScalarValue::Int16(Some(v)) => v.to_string(), 162 | ScalarValue::Int32(Some(v)) => v.to_string(), 163 | ScalarValue::Int64(Some(v)) => v.to_string(), 164 | _ => scalar.to_string(), 165 | }, 166 | Expr::Cast(cast) => expr_to_sql_repr(&cast.expr), 167 | _ => expr.to_string(), 168 | } 169 | } 170 | 171 | /// Implement a custom SQL planner to replace postgres JSON operators with custom UDFs 172 | #[derive(Debug, Default)] 173 | pub struct JsonExprPlanner; 174 | 175 | impl ExprPlanner for JsonExprPlanner { 176 | fn plan_binary_op(&self, expr: RawBinaryExpr, _schema: &DFSchema) -> Result> { 177 | let Ok(op) = JsonOperator::try_from(&expr.op) else { 178 | return Ok(PlannerResult::Original(expr)); 179 | }; 180 | 181 | let left_repr = expr_to_sql_repr(&expr.left); 182 | let right_repr = expr_to_sql_repr(&expr.right); 183 | 184 | let alias_name = format!("{left_repr} {op} {right_repr}"); 185 | 186 | // we put the alias in so that default column titles are `foo -> bar` instead of `json_get(foo, bar)` 187 | Ok(PlannerResult::Planned(Expr::Alias(Alias::new( 188 | Expr::ScalarFunction(ScalarFunction { 189 | func: op.into(), 190 | args: vec![expr.left, expr.right], 191 | }), 192 | None::<&str>, 193 | alias_name, 194 | )))) 195 | } 196 | } 197 | -------------------------------------------------------------------------------- /tests/main.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use datafusion::arrow::array::{Array, ArrayRef, DictionaryArray, RecordBatch}; 4 | use datafusion::arrow::datatypes::{Field, Int64Type, Int8Type, Schema}; 5 | use datafusion::arrow::{array::StringDictionaryBuilder, datatypes::DataType}; 6 | use datafusion::assert_batches_eq; 7 | use datafusion::common::ScalarValue; 8 | use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs}; 9 | use datafusion::prelude::SessionContext; 10 | use datafusion_functions_json::udfs::json_get_str_udf; 11 | use utils::{create_context, display_val, logical_plan, run_query, run_query_dict, run_query_large, run_query_params}; 12 | 13 | mod utils; 14 | 15 | #[tokio::test] 16 | async fn test_json_contains() { 17 | let expected = [ 18 | "+------------------+-------------------------------------------+", 19 | "| name | json_contains(test.json_data,Utf8(\"foo\")) |", 20 | "+------------------+-------------------------------------------+", 21 | "| object_foo | true |", 22 | "| object_foo_array | true |", 23 | "| object_foo_obj | true |", 24 | "| object_foo_null | true |", 25 | "| object_bar | false |", 26 | "| list_foo | false |", 27 | "| invalid_json | false |", 28 | "+------------------+-------------------------------------------+", 29 | ]; 30 | 31 | let batches = run_query("select name, json_contains(json_data, 'foo') from test") 32 | .await 33 | .unwrap(); 34 | assert_batches_eq!(expected, &batches); 35 | } 36 | 37 | #[tokio::test] 38 | async fn test_json_contains_array() { 39 | let sql = "select json_contains('[1, 2, 3]', 2)"; 40 | let batches = run_query(sql).await.unwrap(); 41 | assert_eq!(display_val(batches).await, (DataType::Boolean, "true".to_string())); 42 | 43 | let sql = "select json_contains('[1, 2, 3]', 3)"; 44 | let batches = run_query(sql).await.unwrap(); 45 | assert_eq!(display_val(batches).await, (DataType::Boolean, "false".to_string())); 46 | } 47 | 48 | #[tokio::test] 49 | async fn test_json_contains_nested() { 50 | let sql = r#"select json_contains('[1, 2, {"foo": null}]', 2)"#; 51 | let batches = run_query(sql).await.unwrap(); 52 | assert_eq!(display_val(batches).await, (DataType::Boolean, "true".to_string())); 53 | 54 | let sql = r#"select json_contains('[1, 2, {"foo": null}]', 2, 'foo')"#; 55 | let batches = run_query(sql).await.unwrap(); 56 | assert_eq!(display_val(batches).await, (DataType::Boolean, "true".to_string())); 57 | 58 | let sql = r#"select json_contains('[1, 2, {"foo": null}]', 2, 'bar')"#; 59 | let batches = run_query(sql).await.unwrap(); 60 | assert_eq!(display_val(batches).await, (DataType::Boolean, "false".to_string())); 61 | } 62 | 63 | #[tokio::test] 64 | async fn test_json_get_union() { 65 | let batches = run_query("select name, json_get(json_data, 'foo') from test") 66 | .await 67 | .unwrap(); 68 | 69 | let expected = [ 70 | "+------------------+--------------------------------------+", 71 | "| name | json_get(test.json_data,Utf8(\"foo\")) |", 72 | "+------------------+--------------------------------------+", 73 | "| object_foo | {str=abc} |", 74 | "| object_foo_array | {array=[1]} |", 75 | "| object_foo_obj | {object={}} |", 76 | "| object_foo_null | {null=} |", 77 | "| object_bar | {null=} |", 78 | "| list_foo | {null=} |", 79 | "| invalid_json | {null=} |", 80 | "+------------------+--------------------------------------+", 81 | ]; 82 | assert_batches_eq!(expected, &batches); 83 | } 84 | 85 | #[tokio::test] 86 | async fn test_json_get_array() { 87 | let sql = "select json_get('[1, 2, 3]', 2)"; 88 | let batches = run_query(sql).await.unwrap(); 89 | let (value_type, value_repr) = display_val(batches).await; 90 | assert!(matches!(value_type, DataType::Union(_, _))); 91 | assert_eq!(value_repr, "{int=3}"); 92 | } 93 | 94 | #[tokio::test] 95 | async fn test_json_get_equals() { 96 | let e = run_query(r"select name, json_get(json_data, 'foo')='abc' from test") 97 | .await 98 | .unwrap_err(); 99 | 100 | // see https://github.com/apache/datafusion/issues/10180 101 | assert!(e 102 | .to_string() 103 | .starts_with("Error during planning: Cannot infer common argument type for comparison operation Union")); 104 | } 105 | 106 | #[tokio::test] 107 | async fn test_json_get_cast_equals() { 108 | let batches = run_query(r"select name, json_get(json_data, 'foo')::string='abc' from test") 109 | .await 110 | .unwrap(); 111 | 112 | let expected = [ 113 | "+------------------+----------------------------------------------------+", 114 | "| name | json_get(test.json_data,Utf8(\"foo\")) = Utf8(\"abc\") |", 115 | "+------------------+----------------------------------------------------+", 116 | "| object_foo | true |", 117 | "| object_foo_array | |", 118 | "| object_foo_obj | |", 119 | "| object_foo_null | |", 120 | "| object_bar | |", 121 | "| list_foo | |", 122 | "| invalid_json | |", 123 | "+------------------+----------------------------------------------------+", 124 | ]; 125 | assert_batches_eq!(expected, &batches); 126 | } 127 | 128 | #[tokio::test] 129 | async fn test_json_get_str() { 130 | let batches = run_query("select name, json_get_str(json_data, 'foo') from test") 131 | .await 132 | .unwrap(); 133 | 134 | let expected = [ 135 | "+------------------+------------------------------------------+", 136 | "| name | json_get_str(test.json_data,Utf8(\"foo\")) |", 137 | "+------------------+------------------------------------------+", 138 | "| object_foo | abc |", 139 | "| object_foo_array | |", 140 | "| object_foo_obj | |", 141 | "| object_foo_null | |", 142 | "| object_bar | |", 143 | "| list_foo | |", 144 | "| invalid_json | |", 145 | "+------------------+------------------------------------------+", 146 | ]; 147 | assert_batches_eq!(expected, &batches); 148 | } 149 | 150 | #[tokio::test] 151 | async fn test_json_get_str_equals() { 152 | let sql = "select name, json_get_str(json_data, 'foo')='abc' from test"; 153 | let batches = run_query(sql).await.unwrap(); 154 | 155 | let expected = [ 156 | "+------------------+--------------------------------------------------------+", 157 | "| name | json_get_str(test.json_data,Utf8(\"foo\")) = Utf8(\"abc\") |", 158 | "+------------------+--------------------------------------------------------+", 159 | "| object_foo | true |", 160 | "| object_foo_array | |", 161 | "| object_foo_obj | |", 162 | "| object_foo_null | |", 163 | "| object_bar | |", 164 | "| list_foo | |", 165 | "| invalid_json | |", 166 | "+------------------+--------------------------------------------------------+", 167 | ]; 168 | assert_batches_eq!(expected, &batches); 169 | } 170 | 171 | #[tokio::test] 172 | async fn test_json_get_str_int() { 173 | let sql = r#"select json_get_str('["a", "b", "c"]', 1)"#; 174 | let batches = run_query(sql).await.unwrap(); 175 | assert_eq!(display_val(batches).await, (DataType::Utf8, "b".to_string())); 176 | 177 | let sql = r#"select json_get_str('["a", "b", "c"]', 3)"#; 178 | let batches = run_query(sql).await.unwrap(); 179 | assert_eq!(display_val(batches).await, (DataType::Utf8, String::new())); 180 | } 181 | 182 | #[tokio::test] 183 | async fn test_json_get_str_path() { 184 | let sql = r#"select json_get_str('{"a": {"aa": "x", "ab: "y"}, "b": []}', 'a', 'aa')"#; 185 | let batches = run_query(sql).await.unwrap(); 186 | assert_eq!(display_val(batches).await, (DataType::Utf8, "x".to_string())); 187 | } 188 | 189 | #[tokio::test] 190 | async fn test_json_get_str_null() { 191 | let e = run_query(r"select json_get_str('{}', null)").await.unwrap_err(); 192 | 193 | assert_eq!( 194 | e.to_string(), 195 | "Error during planning: Unexpected argument type to 'json_get_str' at position 2, expected string or int, got Null." 196 | ); 197 | } 198 | 199 | #[tokio::test] 200 | async fn test_json_get_no_path() { 201 | let batches = run_query(r#"select json_get('"foo"')::string"#).await.unwrap(); 202 | assert_eq!(display_val(batches).await, (DataType::Utf8, "foo".to_string())); 203 | 204 | let batches = run_query(r"select json_get('123')::int").await.unwrap(); 205 | assert_eq!(display_val(batches).await, (DataType::Int64, "123".to_string())); 206 | 207 | let batches = run_query(r"select json_get('true')::int").await.unwrap(); 208 | assert_eq!(display_val(batches).await, (DataType::Int64, String::new())); 209 | } 210 | 211 | #[tokio::test] 212 | async fn test_json_get_int() { 213 | let batches = run_query(r"select json_get_int('[1, 2, 3]', 1)").await.unwrap(); 214 | assert_eq!(display_val(batches).await, (DataType::Int64, "2".to_string())); 215 | } 216 | 217 | #[tokio::test] 218 | async fn test_json_get_path() { 219 | let batches = run_query(r#"select json_get('{"i": 19}', 'i')::int<20"#).await.unwrap(); 220 | assert_eq!(display_val(batches).await, (DataType::Boolean, "true".to_string())); 221 | } 222 | 223 | #[tokio::test] 224 | async fn test_json_get_cast_int() { 225 | let sql = r#"select json_get('{"foo": 42}', 'foo')::int"#; 226 | let batches = run_query(sql).await.unwrap(); 227 | assert_eq!(display_val(batches).await, (DataType::Int64, "42".to_string())); 228 | 229 | // floats not allowed 230 | let sql = r#"select json_get('{"foo": 4.2}', 'foo')::int"#; 231 | let batches = run_query(sql).await.unwrap(); 232 | assert_eq!(display_val(batches).await, (DataType::Int64, String::new())); 233 | } 234 | 235 | #[tokio::test] 236 | async fn test_json_get_cast_int_path() { 237 | let sql = r#"select json_get('{"foo": [null, {"x": false, "bar": 73}}', 'foo', 1, 'bar')::int"#; 238 | let batches = run_query(sql).await.unwrap(); 239 | assert_eq!(display_val(batches).await, (DataType::Int64, "73".to_string())); 240 | } 241 | 242 | #[tokio::test] 243 | async fn test_json_get_int_lookup() { 244 | let sql = "select str_key, json_data from other where json_get_int(json_data, str_key) is not null"; 245 | let batches = run_query(sql).await.unwrap(); 246 | let expected = [ 247 | "+---------+---------------+", 248 | "| str_key | json_data |", 249 | "+---------+---------------+", 250 | "| foo | {\"foo\": 42} |", 251 | "+---------+---------------+", 252 | ]; 253 | assert_batches_eq!(expected, &batches); 254 | 255 | // lookup by int 256 | let sql = "select int_key, json_data from other where json_get_int(json_data, int_key) is not null"; 257 | let batches = run_query(sql).await.unwrap(); 258 | let expected = [ 259 | "+---------+-----------+", 260 | "| int_key | json_data |", 261 | "+---------+-----------+", 262 | "| 0 | [42] |", 263 | "+---------+-----------+", 264 | ]; 265 | assert_batches_eq!(expected, &batches); 266 | } 267 | 268 | #[tokio::test] 269 | async fn test_json_get_float() { 270 | let batches = run_query("select json_get_float('[1.5]', 0)").await.unwrap(); 271 | assert_eq!(display_val(batches).await, (DataType::Float64, "1.5".to_string())); 272 | 273 | // coerce int to float 274 | let batches = run_query("select json_get_float('[1]', 0)").await.unwrap(); 275 | assert_eq!(display_val(batches).await, (DataType::Float64, "1.0".to_string())); 276 | } 277 | 278 | #[tokio::test] 279 | async fn test_json_get_cast_float() { 280 | let sql = r#"select json_get('{"foo": 4.2e2}', 'foo')::float"#; 281 | let batches = run_query(sql).await.unwrap(); 282 | assert_eq!(display_val(batches).await, (DataType::Float64, "420.0".to_string())); 283 | } 284 | 285 | #[tokio::test] 286 | async fn test_json_get_cast_numeric() { 287 | let sql = r#"select json_get('{"foo": 4.2e2}', 'foo')::numeric"#; 288 | let batches = run_query(sql).await.unwrap(); 289 | assert_eq!(display_val(batches).await, (DataType::Float64, "420.0".to_string())); 290 | } 291 | 292 | #[tokio::test] 293 | async fn test_json_get_cast_numeric_equals() { 294 | let sql = r#"select json_get('{"foo": 420}', 'foo')::numeric = 420"#; 295 | let batches = run_query(sql).await.unwrap(); 296 | assert_eq!(display_val(batches).await, (DataType::Boolean, "true".to_string())); 297 | } 298 | 299 | #[tokio::test] 300 | async fn test_json_get_bool() { 301 | let batches = run_query("select json_get_bool('[true]', 0)").await.unwrap(); 302 | assert_eq!(display_val(batches).await, (DataType::Boolean, "true".to_string())); 303 | 304 | let batches = run_query(r#"select json_get_bool('{"": false}', '')"#).await.unwrap(); 305 | assert_eq!(display_val(batches).await, (DataType::Boolean, "false".to_string())); 306 | } 307 | 308 | #[tokio::test] 309 | async fn test_json_get_cast_bool() { 310 | let sql = r#"select json_get('{"foo": true}', 'foo')::bool"#; 311 | let batches = run_query(sql).await.unwrap(); 312 | assert_eq!(display_val(batches).await, (DataType::Boolean, "true".to_string())); 313 | } 314 | 315 | #[tokio::test] 316 | async fn test_json_get_json() { 317 | let batches = run_query("select name, json_get_json(json_data, 'foo') from test") 318 | .await 319 | .unwrap(); 320 | 321 | let expected = [ 322 | "+------------------+-------------------------------------------+", 323 | "| name | json_get_json(test.json_data,Utf8(\"foo\")) |", 324 | "+------------------+-------------------------------------------+", 325 | "| object_foo | \"abc\" |", 326 | "| object_foo_array | [1] |", 327 | "| object_foo_obj | {} |", 328 | "| object_foo_null | null |", 329 | "| object_bar | |", 330 | "| list_foo | |", 331 | "| invalid_json | |", 332 | "+------------------+-------------------------------------------+", 333 | ]; 334 | assert_batches_eq!(expected, &batches); 335 | } 336 | 337 | #[tokio::test] 338 | async fn test_json_get_json_float() { 339 | let sql = r#"select json_get_json('{"x": 4.2e-1}', 'x')"#; 340 | let batches = run_query(sql).await.unwrap(); 341 | assert_eq!(display_val(batches).await, (DataType::Utf8, "4.2e-1".to_string())); 342 | } 343 | 344 | #[tokio::test] 345 | async fn test_json_length_array() { 346 | let sql = "select json_length('[1, 2, 3]')"; 347 | let batches = run_query(sql).await.unwrap(); 348 | assert_eq!(display_val(batches).await, (DataType::UInt64, "3".to_string())); 349 | } 350 | 351 | #[tokio::test] 352 | async fn test_json_length_object() { 353 | let sql = r#"select json_length('{"a": 1, "b": 2, "c": 3}')"#; 354 | let batches = run_query(sql).await.unwrap(); 355 | assert_eq!(display_val(batches).await, (DataType::UInt64, "3".to_string())); 356 | 357 | let sql = r"select json_length('{}')"; 358 | let batches = run_query(sql).await.unwrap(); 359 | assert_eq!(display_val(batches).await, (DataType::UInt64, "0".to_string())); 360 | } 361 | 362 | #[tokio::test] 363 | async fn test_json_length_string() { 364 | let sql = r#"select json_length('"foobar"')"#; 365 | let batches = run_query(sql).await.unwrap(); 366 | assert_eq!(display_val(batches).await, (DataType::UInt64, String::new())); 367 | } 368 | 369 | #[tokio::test] 370 | async fn test_json_length_object_nested() { 371 | let sql = r#"select json_length('{"a": 1, "b": 2, "c": [1, 2]}', 'c')"#; 372 | let batches = run_query(sql).await.unwrap(); 373 | assert_eq!(display_val(batches).await, (DataType::UInt64, "2".to_string())); 374 | 375 | let sql = r#"select json_length('{"a": 1, "b": 2, "c": []}', 'b')"#; 376 | let batches = run_query(sql).await.unwrap(); 377 | assert_eq!(display_val(batches).await, (DataType::UInt64, String::new())); 378 | } 379 | 380 | #[tokio::test] 381 | async fn test_json_contains_large() { 382 | let expected = [ 383 | "+----------+", 384 | "| count(*) |", 385 | "+----------+", 386 | "| 4 |", 387 | "+----------+", 388 | ]; 389 | 390 | let batches = run_query_large("select count(*) from test where json_contains(json_data, 'foo')") 391 | .await 392 | .unwrap(); 393 | assert_batches_eq!(expected, &batches); 394 | } 395 | 396 | #[tokio::test] 397 | async fn test_json_contains_large_vec() { 398 | let expected = [ 399 | "+----------+", 400 | "| count(*) |", 401 | "+----------+", 402 | "| 0 |", 403 | "+----------+", 404 | ]; 405 | 406 | let batches = run_query_large("select count(*) from test where json_contains(json_data, name)") 407 | .await 408 | .unwrap(); 409 | assert_batches_eq!(expected, &batches); 410 | } 411 | 412 | #[tokio::test] 413 | async fn test_json_contains_large_both() { 414 | let expected = [ 415 | "+----------+", 416 | "| count(*) |", 417 | "+----------+", 418 | "| 0 |", 419 | "+----------+", 420 | ]; 421 | 422 | let batches = run_query_large("select count(*) from test where json_contains(json_data, json_data)") 423 | .await 424 | .unwrap(); 425 | assert_batches_eq!(expected, &batches); 426 | } 427 | 428 | #[tokio::test] 429 | async fn test_json_contains_large_params() { 430 | let expected = [ 431 | "+----------+", 432 | "| count(*) |", 433 | "+----------+", 434 | "| 4 |", 435 | "+----------+", 436 | ]; 437 | 438 | let sql = "select count(*) from test where json_contains(json_data, 'foo')"; 439 | let params = vec![ScalarValue::LargeUtf8(Some("foo".to_string()))]; 440 | let batches = run_query_params(sql, false, params).await.unwrap(); 441 | assert_batches_eq!(expected, &batches); 442 | } 443 | 444 | #[tokio::test] 445 | async fn test_json_contains_large_both_params() { 446 | let expected = [ 447 | "+----------+", 448 | "| count(*) |", 449 | "+----------+", 450 | "| 4 |", 451 | "+----------+", 452 | ]; 453 | 454 | let sql = "select count(*) from test where json_contains(json_data, 'foo')"; 455 | let params = vec![ScalarValue::LargeUtf8(Some("foo".to_string()))]; 456 | let batches = run_query_params(sql, true, params).await.unwrap(); 457 | assert_batches_eq!(expected, &batches); 458 | } 459 | 460 | #[tokio::test] 461 | async fn test_json_length_vec() { 462 | let sql = r"select name, json_len(json_data) as len from test"; 463 | let batches = run_query(sql).await.unwrap(); 464 | 465 | let expected = [ 466 | "+------------------+-----+", 467 | "| name | len |", 468 | "+------------------+-----+", 469 | "| object_foo | 1 |", 470 | "| object_foo_array | 1 |", 471 | "| object_foo_obj | 1 |", 472 | "| object_foo_null | 1 |", 473 | "| object_bar | 1 |", 474 | "| list_foo | 1 |", 475 | "| invalid_json | |", 476 | "+------------------+-----+", 477 | ]; 478 | assert_batches_eq!(expected, &batches); 479 | 480 | let batches = run_query_large(sql).await.unwrap(); 481 | assert_batches_eq!(expected, &batches); 482 | } 483 | 484 | #[tokio::test] 485 | async fn test_no_args() { 486 | let err = run_query(r"select json_len()").await.unwrap_err(); 487 | assert!(err 488 | .to_string() 489 | .contains("No function matches the given name and argument types 'json_length()'.")); 490 | } 491 | 492 | #[test] 493 | fn test_json_get_utf8() { 494 | let json_get_str = json_get_str_udf(); 495 | let args = &[ 496 | ColumnarValue::Scalar(ScalarValue::Utf8(Some( 497 | r#"{"a": {"aa": "x", "ab: "y"}, "b": []}"#.to_string(), 498 | ))), 499 | ColumnarValue::Scalar(ScalarValue::Utf8(Some("a".to_string()))), 500 | ColumnarValue::Scalar(ScalarValue::Utf8(Some("aa".to_string()))), 501 | ]; 502 | 503 | let ColumnarValue::Scalar(sv) = json_get_str 504 | .invoke_with_args(ScalarFunctionArgs { 505 | args: args.to_vec(), 506 | number_rows: 1, 507 | return_type: &DataType::Utf8, 508 | }) 509 | .unwrap() 510 | else { 511 | panic!("expected scalar") 512 | }; 513 | 514 | assert_eq!(sv, ScalarValue::Utf8(Some("x".to_string()))); 515 | } 516 | 517 | #[test] 518 | fn test_json_get_large_utf8() { 519 | let json_get_str = json_get_str_udf(); 520 | let args = &[ 521 | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some( 522 | r#"{"a": {"aa": "x", "ab: "y"}, "b": []}"#.to_string(), 523 | ))), 524 | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("a".to_string()))), 525 | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("aa".to_string()))), 526 | ]; 527 | 528 | let ColumnarValue::Scalar(sv) = json_get_str 529 | .invoke_with_args(ScalarFunctionArgs { 530 | args: args.to_vec(), 531 | number_rows: 1, 532 | return_type: &DataType::LargeUtf8, 533 | }) 534 | .unwrap() 535 | else { 536 | panic!("expected scalar") 537 | }; 538 | 539 | assert_eq!(sv, ScalarValue::Utf8(Some("x".to_string()))); 540 | } 541 | 542 | #[tokio::test] 543 | async fn test_json_get_union_scalar() { 544 | let expected = [ 545 | "+---------+", 546 | "| v |", 547 | "+---------+", 548 | "| {int=1} |", 549 | "+---------+", 550 | ]; 551 | 552 | let batches = run_query(r#"select json_get(json_get('{"x": {"y": 1}}', 'x'), 'y') as v"#) 553 | .await 554 | .unwrap(); 555 | assert_batches_eq!(expected, &batches); 556 | } 557 | 558 | #[tokio::test] 559 | async fn test_json_get_nested_collapsed() { 560 | let expected = [ 561 | "+------------------+---------+", 562 | "| name | v |", 563 | "+------------------+---------+", 564 | "| object_foo | {null=} |", 565 | "| object_foo_array | {int=1} |", 566 | "| object_foo_obj | {null=} |", 567 | "| object_foo_null | {null=} |", 568 | "| object_bar | {null=} |", 569 | "| list_foo | {null=} |", 570 | "| invalid_json | {null=} |", 571 | "+------------------+---------+", 572 | ]; 573 | 574 | let batches = run_query("select name, json_get(json_get(json_data, 'foo'), 0) v from test") 575 | .await 576 | .unwrap(); 577 | assert_batches_eq!(expected, &batches); 578 | } 579 | 580 | #[tokio::test] 581 | async fn test_json_get_cte() { 582 | // avoid auto-un-nesting with a CTE 583 | let sql = r" 584 | with t as (select name, json_get(json_data, 'foo') j from test) 585 | select name, json_get(j, 0) v from t 586 | "; 587 | let expected = [ 588 | "+------------------+---------+", 589 | "| name | v |", 590 | "+------------------+---------+", 591 | "| object_foo | {null=} |", 592 | "| object_foo_array | {int=1} |", 593 | "| object_foo_obj | {null=} |", 594 | "| object_foo_null | {null=} |", 595 | "| object_bar | {null=} |", 596 | "| list_foo | {null=} |", 597 | "| invalid_json | {null=} |", 598 | "+------------------+---------+", 599 | ]; 600 | 601 | let batches = run_query(sql).await.unwrap(); 602 | assert_batches_eq!(expected, &batches); 603 | } 604 | 605 | #[tokio::test] 606 | async fn test_plan_json_get_cte() { 607 | // avoid auto-unnesting with a CTE 608 | let sql = r" 609 | explain 610 | with t as (select name, json_get(json_data, 'foo') j from test) 611 | select name, json_get(j, 0) v from t 612 | "; 613 | let expected = [ 614 | "Projection: t.name, json_get(t.j, Int64(0)) AS v", 615 | " SubqueryAlias: t", 616 | " Projection: test.name, json_get(test.json_data, Utf8(\"foo\")) AS j", 617 | " TableScan: test projection=[name, json_data]", 618 | ]; 619 | 620 | let plan_lines = logical_plan(sql).await; 621 | assert_eq!(plan_lines, expected); 622 | } 623 | 624 | #[tokio::test] 625 | async fn test_json_get_unnest() { 626 | let sql = "select name, json_get(json_get(json_data, 'foo'), 0) v from test"; 627 | 628 | let expected = [ 629 | "+------------------+---------+", 630 | "| name | v |", 631 | "+------------------+---------+", 632 | "| object_foo | {null=} |", 633 | "| object_foo_array | {int=1} |", 634 | "| object_foo_obj | {null=} |", 635 | "| object_foo_null | {null=} |", 636 | "| object_bar | {null=} |", 637 | "| list_foo | {null=} |", 638 | "| invalid_json | {null=} |", 639 | "+------------------+---------+", 640 | ]; 641 | 642 | let batches = run_query(sql).await.unwrap(); 643 | assert_batches_eq!(expected, &batches); 644 | } 645 | 646 | #[tokio::test] 647 | async fn test_plan_json_get_unnest() { 648 | let sql = "explain select json_get(json_get(json_data, 'foo'), 0) v from test"; 649 | let expected = [ 650 | "Projection: json_get(test.json_data, Utf8(\"foo\"), Int64(0)) AS v", 651 | " TableScan: test projection=[json_data]", 652 | ]; 653 | 654 | let plan_lines = logical_plan(sql).await; 655 | assert_eq!(plan_lines, expected); 656 | } 657 | 658 | #[tokio::test] 659 | async fn test_json_get_int_unnest() { 660 | let sql = "select name, json_get(json_get(json_data, 'foo'), 0)::int v from test"; 661 | 662 | let expected = [ 663 | "+------------------+---+", 664 | "| name | v |", 665 | "+------------------+---+", 666 | "| object_foo | |", 667 | "| object_foo_array | 1 |", 668 | "| object_foo_obj | |", 669 | "| object_foo_null | |", 670 | "| object_bar | |", 671 | "| list_foo | |", 672 | "| invalid_json | |", 673 | "+------------------+---+", 674 | ]; 675 | 676 | let batches = run_query(sql).await.unwrap(); 677 | assert_batches_eq!(expected, &batches); 678 | } 679 | 680 | #[tokio::test] 681 | async fn test_plan_json_get_int_unnest() { 682 | let sql = "explain select json_get(json_get(json_data, 'foo'), 0)::int v from test"; 683 | let expected = [ 684 | "Projection: json_get_int(test.json_data, Utf8(\"foo\"), Int64(0)) AS v", 685 | " TableScan: test projection=[json_data]", 686 | ]; 687 | 688 | let plan_lines = logical_plan(sql).await; 689 | assert_eq!(plan_lines, expected); 690 | } 691 | 692 | #[tokio::test] 693 | async fn test_multiple_lookup_arrays() { 694 | let sql = "select json_get(json_data, str_key1, str_key2) v from more_nested"; 695 | let err = run_query(sql).await.unwrap_err(); 696 | assert_eq!( 697 | err.to_string(), 698 | "Execution error: More than 1 path element is not supported when querying JSON using an array." 699 | ); 700 | } 701 | 702 | #[tokio::test] 703 | async fn test_json_get_union_array_nested() { 704 | let sql = "select json_get(json_get(json_data, str_key1), str_key2) v from more_nested"; 705 | let expected = [ 706 | "+-------------+", 707 | "| v |", 708 | "+-------------+", 709 | "| {array=[0]} |", 710 | "| {null=} |", 711 | "| {null=} |", 712 | "+-------------+", 713 | ]; 714 | 715 | let batches = run_query(sql).await.unwrap(); 716 | assert_batches_eq!(expected, &batches); 717 | } 718 | 719 | #[tokio::test] 720 | async fn test_plan_json_get_union_array_nested() { 721 | let sql = "explain select json_get(json_get(json_data, str_key1), str_key2) v from more_nested"; 722 | // json_get is not un-nested because lookup types are not literals 723 | let expected = [ 724 | "Projection: json_get(json_get(more_nested.json_data, more_nested.str_key1), more_nested.str_key2) AS v", 725 | " TableScan: more_nested projection=[json_data, str_key1, str_key2]", 726 | ]; 727 | 728 | let plan_lines = logical_plan(sql).await; 729 | assert_eq!(plan_lines, expected); 730 | } 731 | 732 | #[tokio::test] 733 | async fn test_json_get_union_array_skip_double_nested() { 734 | let sql = 735 | "select json_data, json_get_int(json_get(json_get(json_data, str_key1), str_key2), int_key) v from more_nested"; 736 | let expected = [ 737 | "+--------------------------+---+", 738 | "| json_data | v |", 739 | "+--------------------------+---+", 740 | "| {\"foo\": {\"bar\": [0]}} | 0 |", 741 | "| {\"foo\": {\"bar\": [1]}} | |", 742 | "| {\"foo\": {\"bar\": null}} | |", 743 | "+--------------------------+---+", 744 | ]; 745 | 746 | let batches = run_query(sql).await.unwrap(); 747 | assert_batches_eq!(expected, &batches); 748 | } 749 | 750 | #[tokio::test] 751 | async fn test_arrow() { 752 | let batches = run_query("select name, json_data->'foo' from test").await.unwrap(); 753 | 754 | let expected = [ 755 | "+------------------+-------------------------+", 756 | "| name | test.json_data -> 'foo' |", 757 | "+------------------+-------------------------+", 758 | "| object_foo | {str=abc} |", 759 | "| object_foo_array | {array=[1]} |", 760 | "| object_foo_obj | {object={}} |", 761 | "| object_foo_null | {null=} |", 762 | "| object_bar | {null=} |", 763 | "| list_foo | {null=} |", 764 | "| invalid_json | {null=} |", 765 | "+------------------+-------------------------+", 766 | ]; 767 | assert_batches_eq!(expected, &batches); 768 | } 769 | 770 | #[tokio::test] 771 | async fn test_plan_arrow() { 772 | let lines = logical_plan(r"explain select json_data->'foo' from test").await; 773 | 774 | let expected = [ 775 | "Projection: json_get(test.json_data, Utf8(\"foo\")) AS test.json_data -> 'foo'", 776 | " TableScan: test projection=[json_data]", 777 | ]; 778 | 779 | assert_eq!(lines, expected); 780 | } 781 | 782 | #[tokio::test] 783 | async fn test_long_arrow() { 784 | let batches = run_query("select name, json_data->>'foo' from test").await.unwrap(); 785 | 786 | let expected = [ 787 | "+------------------+--------------------------+", 788 | "| name | test.json_data ->> 'foo' |", 789 | "+------------------+--------------------------+", 790 | "| object_foo | abc |", 791 | "| object_foo_array | [1] |", 792 | "| object_foo_obj | {} |", 793 | "| object_foo_null | |", 794 | "| object_bar | |", 795 | "| list_foo | |", 796 | "| invalid_json | |", 797 | "+------------------+--------------------------+", 798 | ]; 799 | assert_batches_eq!(expected, &batches); 800 | } 801 | 802 | #[tokio::test] 803 | async fn test_plan_long_arrow() { 804 | let lines = logical_plan(r"explain select json_data->>'foo' from test").await; 805 | 806 | let expected = [ 807 | "Projection: json_as_text(test.json_data, Utf8(\"foo\")) AS test.json_data ->> 'foo'", 808 | " TableScan: test projection=[json_data]", 809 | ]; 810 | 811 | assert_eq!(lines, expected); 812 | } 813 | 814 | #[tokio::test] 815 | async fn test_long_arrow_eq_str() { 816 | let batches = run_query(r"select name, (json_data->>'foo')='abc' from test") 817 | .await 818 | .unwrap(); 819 | 820 | let expected = [ 821 | "+------------------+----------------------------------------+", 822 | "| name | test.json_data ->> 'foo' = Utf8(\"abc\") |", 823 | "+------------------+----------------------------------------+", 824 | "| object_foo | true |", 825 | "| object_foo_array | false |", 826 | "| object_foo_obj | false |", 827 | "| object_foo_null | |", 828 | "| object_bar | |", 829 | "| list_foo | |", 830 | "| invalid_json | |", 831 | "+------------------+----------------------------------------+", 832 | ]; 833 | assert_batches_eq!(expected, &batches); 834 | } 835 | 836 | /// Test column name / alias creation with a cast in the needle / key 837 | #[tokio::test] 838 | async fn test_arrow_cast_key_text() { 839 | let sql = r#"select ('{"foo": 42}'->>('foo'::text))"#; 840 | let batches = run_query(sql).await.unwrap(); 841 | 842 | let expected = [ 843 | "+-------------------------+", 844 | "| '{\"foo\": 42}' ->> 'foo' |", 845 | "+-------------------------+", 846 | "| 42 |", 847 | "+-------------------------+", 848 | ]; 849 | 850 | assert_batches_eq!(expected, &batches); 851 | } 852 | 853 | #[tokio::test] 854 | async fn test_arrow_cast_int() { 855 | let sql = r#"select ('{"foo": 42}'->'foo')::int"#; 856 | let batches = run_query(sql).await.unwrap(); 857 | 858 | let expected = [ 859 | "+------------------------+", 860 | "| '{\"foo\": 42}' -> 'foo' |", 861 | "+------------------------+", 862 | "| 42 |", 863 | "+------------------------+", 864 | ]; 865 | assert_batches_eq!(expected, &batches); 866 | 867 | assert_eq!(display_val(batches).await, (DataType::Int64, "42".to_string())); 868 | } 869 | 870 | #[tokio::test] 871 | async fn test_plan_arrow_cast_int() { 872 | let lines = logical_plan(r"explain select (json_data->'foo')::int from test").await; 873 | 874 | let expected = [ 875 | "Projection: json_get_int(test.json_data, Utf8(\"foo\")) AS test.json_data -> 'foo'", 876 | " TableScan: test projection=[json_data]", 877 | ]; 878 | 879 | assert_eq!(lines, expected); 880 | } 881 | 882 | #[tokio::test] 883 | async fn test_arrow_double_nested() { 884 | let batches = run_query("select name, json_data->'foo'->0 from test").await.unwrap(); 885 | 886 | let expected = [ 887 | "+------------------+------------------------------+", 888 | "| name | test.json_data -> 'foo' -> 0 |", 889 | "+------------------+------------------------------+", 890 | "| object_foo | {null=} |", 891 | "| object_foo_array | {int=1} |", 892 | "| object_foo_obj | {null=} |", 893 | "| object_foo_null | {null=} |", 894 | "| object_bar | {null=} |", 895 | "| list_foo | {null=} |", 896 | "| invalid_json | {null=} |", 897 | "+------------------+------------------------------+", 898 | ]; 899 | assert_batches_eq!(expected, &batches); 900 | } 901 | 902 | #[tokio::test] 903 | async fn test_plan_arrow_double_nested() { 904 | let lines = logical_plan(r"explain select json_data->'foo'->0 from test").await; 905 | 906 | let expected = [ 907 | "Projection: json_get(test.json_data, Utf8(\"foo\"), Int64(0)) AS test.json_data -> 'foo' -> 0", 908 | " TableScan: test projection=[json_data]", 909 | ]; 910 | 911 | assert_eq!(lines, expected); 912 | } 913 | 914 | #[tokio::test] 915 | async fn test_double_arrow_double_nested() { 916 | let batches = run_query("select name, json_data->>'foo'->>0 from test").await.unwrap(); 917 | 918 | let expected = [ 919 | "+------------------+--------------------------------+", 920 | "| name | test.json_data ->> 'foo' ->> 0 |", 921 | "+------------------+--------------------------------+", 922 | "| object_foo | |", 923 | "| object_foo_array | 1 |", 924 | "| object_foo_obj | |", 925 | "| object_foo_null | |", 926 | "| object_bar | |", 927 | "| list_foo | |", 928 | "| invalid_json | |", 929 | "+------------------+--------------------------------+", 930 | ]; 931 | assert_batches_eq!(expected, &batches); 932 | } 933 | 934 | #[tokio::test] 935 | async fn test_plan_double_arrow_double_nested() { 936 | let lines = logical_plan(r"explain select json_data->>'foo'->>0 from test").await; 937 | 938 | let expected = [ 939 | "Projection: json_as_text(test.json_data, Utf8(\"foo\"), Int64(0)) AS test.json_data ->> 'foo' ->> 0", 940 | " TableScan: test projection=[json_data]", 941 | ]; 942 | 943 | assert_eq!(lines, expected); 944 | } 945 | 946 | #[tokio::test] 947 | async fn test_arrow_double_nested_cast() { 948 | let batches = run_query("select name, (json_data->'foo'->0)::int from test") 949 | .await 950 | .unwrap(); 951 | 952 | let expected = [ 953 | "+------------------+------------------------------+", 954 | "| name | test.json_data -> 'foo' -> 0 |", 955 | "+------------------+------------------------------+", 956 | "| object_foo | |", 957 | "| object_foo_array | 1 |", 958 | "| object_foo_obj | |", 959 | "| object_foo_null | |", 960 | "| object_bar | |", 961 | "| list_foo | |", 962 | "| invalid_json | |", 963 | "+------------------+------------------------------+", 964 | ]; 965 | assert_batches_eq!(expected, &batches); 966 | } 967 | 968 | #[tokio::test] 969 | async fn test_plan_arrow_double_nested_cast() { 970 | let lines = logical_plan(r"explain select (json_data->'foo'->0)::int from test").await; 971 | 972 | let expected = [ 973 | "Projection: json_get_int(test.json_data, Utf8(\"foo\"), Int64(0)) AS test.json_data -> 'foo' -> 0", 974 | " TableScan: test projection=[json_data]", 975 | ]; 976 | 977 | assert_eq!(lines, expected); 978 | } 979 | 980 | #[tokio::test] 981 | async fn test_double_arrow_double_nested_cast() { 982 | let batches = run_query("select name, (json_data->>'foo'->>0)::int from test") 983 | .await 984 | .unwrap(); 985 | 986 | let expected = [ 987 | "+------------------+--------------------------------+", 988 | "| name | test.json_data ->> 'foo' ->> 0 |", 989 | "+------------------+--------------------------------+", 990 | "| object_foo | |", 991 | "| object_foo_array | 1 |", 992 | "| object_foo_obj | |", 993 | "| object_foo_null | |", 994 | "| object_bar | |", 995 | "| list_foo | |", 996 | "| invalid_json | |", 997 | "+------------------+--------------------------------+", 998 | ]; 999 | assert_batches_eq!(expected, &batches); 1000 | } 1001 | 1002 | #[tokio::test] 1003 | async fn test_plan_double_arrow_double_nested_cast() { 1004 | let lines = logical_plan(r"explain select (json_data->>'foo'->>0)::int from test").await; 1005 | 1006 | // NB: json_as_text(..)::int is NOT the same as `json_get_int(..)`, hence the cast is not rewritten 1007 | let expected = [ 1008 | "Projection: CAST(json_as_text(test.json_data, Utf8(\"foo\"), Int64(0)) AS test.json_data ->> 'foo' ->> 0 AS Int32)", 1009 | " TableScan: test projection=[json_data]", 1010 | ]; 1011 | 1012 | assert_eq!(lines, expected); 1013 | } 1014 | 1015 | #[tokio::test] 1016 | async fn test_arrow_nested_columns() { 1017 | let expected = [ 1018 | "+-------------+", 1019 | "| v |", 1020 | "+-------------+", 1021 | "| {array=[0]} |", 1022 | "| {null=} |", 1023 | "| {null=} |", 1024 | "+-------------+", 1025 | ]; 1026 | 1027 | let sql = "select json_data->str_key1->str_key2 v from more_nested"; 1028 | let batches = run_query(sql).await.unwrap(); 1029 | assert_batches_eq!(expected, &batches); 1030 | } 1031 | 1032 | #[tokio::test] 1033 | async fn test_arrow_nested_double_columns() { 1034 | let expected = [ 1035 | "+---------+", 1036 | "| v |", 1037 | "+---------+", 1038 | "| {int=0} |", 1039 | "| {null=} |", 1040 | "| {null=} |", 1041 | "+---------+", 1042 | ]; 1043 | 1044 | let sql = "select json_data->str_key1->str_key2->int_key v from more_nested"; 1045 | let batches = run_query(sql).await.unwrap(); 1046 | assert_batches_eq!(expected, &batches); 1047 | } 1048 | 1049 | #[tokio::test] 1050 | async fn test_lexical_precedence_correct() { 1051 | #[rustfmt::skip] 1052 | let expected = [ 1053 | "+------+", 1054 | "| v |", 1055 | "+------+", 1056 | "| true |", 1057 | "+------+", 1058 | ]; 1059 | let sql = r#"select '{"a": "b"}'->>'a'='b' as v"#; 1060 | let batches = run_query(sql).await.unwrap(); 1061 | assert_batches_eq!(expected, &batches); 1062 | } 1063 | 1064 | #[tokio::test] 1065 | async fn test_question_mark_contains() { 1066 | let expected = [ 1067 | "+------------------+------------------------+", 1068 | "| name | test.json_data ? 'foo' |", 1069 | "+------------------+------------------------+", 1070 | "| object_foo | true |", 1071 | "| object_foo_array | true |", 1072 | "| object_foo_obj | true |", 1073 | "| object_foo_null | true |", 1074 | "| object_bar | false |", 1075 | "| list_foo | false |", 1076 | "| invalid_json | false |", 1077 | "+------------------+------------------------+", 1078 | ]; 1079 | 1080 | let batches = run_query("select name, json_data ? 'foo' from test").await.unwrap(); 1081 | assert_batches_eq!(expected, &batches); 1082 | } 1083 | 1084 | #[tokio::test] 1085 | async fn test_arrow_filter() { 1086 | let batches = run_query("select name from test where (json_data->>'foo') = 'abc'") 1087 | .await 1088 | .unwrap(); 1089 | 1090 | let expected = [ 1091 | "+------------+", 1092 | "| name |", 1093 | "+------------+", 1094 | "| object_foo |", 1095 | "+------------+", 1096 | ]; 1097 | assert_batches_eq!(expected, &batches); 1098 | } 1099 | 1100 | #[tokio::test] 1101 | async fn test_question_filter() { 1102 | let batches = run_query("select name from test where json_data ? 'foo'") 1103 | .await 1104 | .unwrap(); 1105 | 1106 | let expected = [ 1107 | "+------------------+", 1108 | "| name |", 1109 | "+------------------+", 1110 | "| object_foo |", 1111 | "| object_foo_array |", 1112 | "| object_foo_obj |", 1113 | "| object_foo_null |", 1114 | "+------------------+", 1115 | ]; 1116 | assert_batches_eq!(expected, &batches); 1117 | } 1118 | 1119 | #[tokio::test] 1120 | async fn test_json_get_union_is_null() { 1121 | let batches = run_query("select name, json_get(json_data, 'foo') is null from test") 1122 | .await 1123 | .unwrap(); 1124 | 1125 | let expected = [ 1126 | "+------------------+----------------------------------------------+", 1127 | "| name | json_get(test.json_data,Utf8(\"foo\")) IS NULL |", 1128 | "+------------------+----------------------------------------------+", 1129 | "| object_foo | false |", 1130 | "| object_foo_array | false |", 1131 | "| object_foo_obj | false |", 1132 | "| object_foo_null | true |", 1133 | "| object_bar | true |", 1134 | "| list_foo | true |", 1135 | "| invalid_json | true |", 1136 | "+------------------+----------------------------------------------+", 1137 | ]; 1138 | assert_batches_eq!(expected, &batches); 1139 | } 1140 | 1141 | #[tokio::test] 1142 | async fn test_json_get_union_is_not_null() { 1143 | let batches = run_query("select name, json_get(json_data, 'foo') is not null from test") 1144 | .await 1145 | .unwrap(); 1146 | 1147 | let expected = [ 1148 | "+------------------+--------------------------------------------------+", 1149 | "| name | json_get(test.json_data,Utf8(\"foo\")) IS NOT NULL |", 1150 | "+------------------+--------------------------------------------------+", 1151 | "| object_foo | true |", 1152 | "| object_foo_array | true |", 1153 | "| object_foo_obj | true |", 1154 | "| object_foo_null | false |", 1155 | "| object_bar | false |", 1156 | "| list_foo | false |", 1157 | "| invalid_json | false |", 1158 | "+------------------+--------------------------------------------------+", 1159 | ]; 1160 | assert_batches_eq!(expected, &batches); 1161 | } 1162 | 1163 | #[tokio::test] 1164 | async fn test_arrow_union_is_null() { 1165 | let batches = run_query("select name, (json_data->'foo') is null from test") 1166 | .await 1167 | .unwrap(); 1168 | 1169 | let expected = [ 1170 | "+------------------+---------------------------------+", 1171 | "| name | test.json_data -> 'foo' IS NULL |", 1172 | "+------------------+---------------------------------+", 1173 | "| object_foo | false |", 1174 | "| object_foo_array | false |", 1175 | "| object_foo_obj | false |", 1176 | "| object_foo_null | true |", 1177 | "| object_bar | true |", 1178 | "| list_foo | true |", 1179 | "| invalid_json | true |", 1180 | "+------------------+---------------------------------+", 1181 | ]; 1182 | assert_batches_eq!(expected, &batches); 1183 | } 1184 | 1185 | #[tokio::test] 1186 | async fn test_arrow_union_is_null_dict_encoded() { 1187 | let batches = run_query_dict("select name, (json_data->'foo') is null from test") 1188 | .await 1189 | .unwrap(); 1190 | 1191 | let expected = [ 1192 | "+------------------+---------------------------------+", 1193 | "| name | test.json_data -> 'foo' IS NULL |", 1194 | "+------------------+---------------------------------+", 1195 | "| object_foo | false |", 1196 | "| object_foo_array | false |", 1197 | "| object_foo_obj | false |", 1198 | "| object_foo_null | true |", 1199 | "| object_bar | true |", 1200 | "| list_foo | true |", 1201 | "| invalid_json | true |", 1202 | "+------------------+---------------------------------+", 1203 | ]; 1204 | assert_batches_eq!(expected, &batches); 1205 | } 1206 | 1207 | #[tokio::test] 1208 | async fn test_arrow_union_is_not_null() { 1209 | let batches = run_query("select name, (json_data->'foo') is not null from test") 1210 | .await 1211 | .unwrap(); 1212 | 1213 | let expected = [ 1214 | "+------------------+-------------------------------------+", 1215 | "| name | test.json_data -> 'foo' IS NOT NULL |", 1216 | "+------------------+-------------------------------------+", 1217 | "| object_foo | true |", 1218 | "| object_foo_array | true |", 1219 | "| object_foo_obj | true |", 1220 | "| object_foo_null | false |", 1221 | "| object_bar | false |", 1222 | "| list_foo | false |", 1223 | "| invalid_json | false |", 1224 | "+------------------+-------------------------------------+", 1225 | ]; 1226 | assert_batches_eq!(expected, &batches); 1227 | } 1228 | 1229 | #[tokio::test] 1230 | async fn test_arrow_union_is_not_null_dict_encoded() { 1231 | let batches = run_query_dict("select name, (json_data->'foo') is not null from test") 1232 | .await 1233 | .unwrap(); 1234 | 1235 | let expected = [ 1236 | "+------------------+-------------------------------------+", 1237 | "| name | test.json_data -> 'foo' IS NOT NULL |", 1238 | "+------------------+-------------------------------------+", 1239 | "| object_foo | true |", 1240 | "| object_foo_array | true |", 1241 | "| object_foo_obj | true |", 1242 | "| object_foo_null | false |", 1243 | "| object_bar | false |", 1244 | "| list_foo | false |", 1245 | "| invalid_json | false |", 1246 | "+------------------+-------------------------------------+", 1247 | ]; 1248 | assert_batches_eq!(expected, &batches); 1249 | } 1250 | 1251 | #[tokio::test] 1252 | async fn test_arrow_scalar_union_is_null() { 1253 | let batches = run_query( 1254 | r#" 1255 | select ('{"x": 1}'->'foo') is null as not_contains, 1256 | ('{"foo": 1}'->'foo') is null as contains_num, 1257 | ('{"foo": null}'->'foo') is null as contains_null"#, 1258 | ) 1259 | .await 1260 | .unwrap(); 1261 | 1262 | let expected = [ 1263 | "+--------------+--------------+---------------+", 1264 | "| not_contains | contains_num | contains_null |", 1265 | "+--------------+--------------+---------------+", 1266 | "| true | false | true |", 1267 | "+--------------+--------------+---------------+", 1268 | ]; 1269 | assert_batches_eq!(expected, &batches); 1270 | } 1271 | 1272 | #[tokio::test] 1273 | async fn test_long_arrow_cast() { 1274 | let batches = run_query("select (json_data->>'foo')::int from other").await.unwrap(); 1275 | 1276 | let expected = [ 1277 | "+---------------------------+", 1278 | "| other.json_data ->> 'foo' |", 1279 | "+---------------------------+", 1280 | "| 42 |", 1281 | "| 42 |", 1282 | "| |", 1283 | "| |", 1284 | "+---------------------------+", 1285 | ]; 1286 | assert_batches_eq!(expected, &batches); 1287 | } 1288 | 1289 | #[tokio::test] 1290 | async fn test_arrow_cast_numeric() { 1291 | let sql = r#"select ('{"foo": 420}'->'foo')::numeric = 420"#; 1292 | let batches = run_query(sql).await.unwrap(); 1293 | assert_eq!(display_val(batches).await, (DataType::Boolean, "true".to_string())); 1294 | } 1295 | 1296 | #[tokio::test] 1297 | async fn test_dict_haystack() { 1298 | let sql = "select json_get(json_data, 'foo') v from dicts"; 1299 | let expected = [ 1300 | "+-----------------------+", 1301 | "| v |", 1302 | "+-----------------------+", 1303 | "| {object={\"bar\": [0]}} |", 1304 | "| |", 1305 | "| |", 1306 | "| |", 1307 | "+-----------------------+", 1308 | ]; 1309 | 1310 | let batches = run_query(sql).await.unwrap(); 1311 | assert_batches_eq!(expected, &batches); 1312 | } 1313 | 1314 | fn check_for_null_dictionary_values(array: &dyn Array) { 1315 | let array = array.as_any().downcast_ref::>().unwrap(); 1316 | let keys_array = array.keys(); 1317 | let keys = keys_array 1318 | .iter() 1319 | .filter_map(|x| x.map(|v| usize::try_from(v).unwrap())) 1320 | .collect::>(); 1321 | let values_array = array.values(); 1322 | // no non-null keys should point to a null value 1323 | for i in 0..values_array.len() { 1324 | if values_array.is_null(i) { 1325 | // keys should not contain 1326 | if keys.contains(&i) { 1327 | #[allow(clippy::print_stdout)] 1328 | { 1329 | println!("keys: {keys:?}"); 1330 | println!("values: {values_array:?}"); 1331 | panic!("keys should not contain null values"); 1332 | } 1333 | } 1334 | } 1335 | } 1336 | } 1337 | 1338 | /// Test that we don't output nulls in dictionary values. 1339 | /// This can cause issues with arrow-rs and DataFusion; they expect nulls to be in keys. 1340 | #[tokio::test] 1341 | async fn test_dict_get_no_null_values() { 1342 | let ctx = build_dict_schema().await; 1343 | 1344 | let sql = "select json_get(x, 'baz') v from data"; 1345 | let expected = [ 1346 | "+------------+", 1347 | "| v |", 1348 | "+------------+", 1349 | "| |", 1350 | "| {str=fizz} |", 1351 | "| |", 1352 | "| {str=abcd} |", 1353 | "| |", 1354 | "| {str=fizz} |", 1355 | "| {str=fizz} |", 1356 | "| {str=fizz} |", 1357 | "| {str=fizz} |", 1358 | "| |", 1359 | "+------------+", 1360 | ]; 1361 | let batches = ctx.sql(sql).await.unwrap().collect().await.unwrap(); 1362 | assert_batches_eq!(expected, &batches); 1363 | for batch in batches { 1364 | check_for_null_dictionary_values(batch.column(0).as_ref()); 1365 | } 1366 | 1367 | let sql = "select json_get_str(x, 'baz') v from data"; 1368 | let expected = [ 1369 | "+------+", "| v |", "+------+", "| |", "| fizz |", "| |", "| abcd |", "| |", "| fizz |", 1370 | "| fizz |", "| fizz |", "| fizz |", "| |", "+------+", 1371 | ]; 1372 | let batches = ctx.sql(sql).await.unwrap().collect().await.unwrap(); 1373 | assert_batches_eq!(expected, &batches); 1374 | for batch in batches { 1375 | check_for_null_dictionary_values(batch.column(0).as_ref()); 1376 | } 1377 | } 1378 | 1379 | #[tokio::test] 1380 | async fn test_dict_haystack_filter() { 1381 | let sql = "select json_data v from dicts where json_get(json_data, 'foo') is not null"; 1382 | let expected = [ 1383 | "+-------------------------+", 1384 | "| v |", 1385 | "+-------------------------+", 1386 | "| {\"foo\": {\"bar\": [0]}} |", 1387 | "+-------------------------+", 1388 | ]; 1389 | 1390 | let batches = run_query(sql).await.unwrap(); 1391 | assert_batches_eq!(expected, &batches); 1392 | } 1393 | 1394 | #[tokio::test] 1395 | async fn test_dict_haystack_needle() { 1396 | let sql = "select json_get(json_get(json_data, str_key1), str_key2) v from dicts"; 1397 | let expected = [ 1398 | "+-------------+", 1399 | "| v |", 1400 | "+-------------+", 1401 | "| {array=[0]} |", 1402 | "| |", 1403 | "| |", 1404 | "| |", 1405 | "+-------------+", 1406 | ]; 1407 | 1408 | let batches = run_query(sql).await.unwrap(); 1409 | assert_batches_eq!(expected, &batches); 1410 | } 1411 | 1412 | #[tokio::test] 1413 | async fn test_dict_length() { 1414 | let sql = "select json_length(json_data) v from dicts"; 1415 | #[rustfmt::skip] 1416 | let expected = [ 1417 | "+---+", 1418 | "| v |", 1419 | "+---+", 1420 | "| 1 |", 1421 | "| 1 |", 1422 | "| 2 |", 1423 | "| 2 |", 1424 | "+---+", 1425 | ]; 1426 | 1427 | let batches = run_query(sql).await.unwrap(); 1428 | assert_batches_eq!(expected, &batches); 1429 | } 1430 | 1431 | #[tokio::test] 1432 | async fn test_dict_contains() { 1433 | let sql = "select json_contains(json_data, str_key2) v from dicts"; 1434 | let expected = [ 1435 | "+-------+", 1436 | "| v |", 1437 | "+-------+", 1438 | "| false |", 1439 | "| false |", 1440 | "| true |", 1441 | "| true |", 1442 | "+-------+", 1443 | ]; 1444 | 1445 | let batches = run_query(sql).await.unwrap(); 1446 | assert_batches_eq!(expected, &batches); 1447 | } 1448 | 1449 | #[tokio::test] 1450 | async fn test_dict_contains_where() { 1451 | let sql = "select str_key2 from dicts where json_contains(json_data, str_key2)"; 1452 | let expected = [ 1453 | "+----------+", 1454 | "| str_key2 |", 1455 | "+----------+", 1456 | "| spam |", 1457 | "| snap |", 1458 | "+----------+", 1459 | ]; 1460 | 1461 | let batches = run_query(sql).await.unwrap(); 1462 | assert_batches_eq!(expected, &batches); 1463 | } 1464 | 1465 | #[tokio::test] 1466 | async fn test_dict_get_int() { 1467 | let sql = "select json_get_int(json_data, str_key2) v from dicts"; 1468 | #[rustfmt::skip] 1469 | let expected = [ 1470 | "+---+", 1471 | "| v |", 1472 | "+---+", 1473 | "| |", 1474 | "| |", 1475 | "| 1 |", 1476 | "| 2 |", 1477 | "+---+", 1478 | ]; 1479 | 1480 | let batches = run_query(sql).await.unwrap(); 1481 | assert_batches_eq!(expected, &batches); 1482 | } 1483 | 1484 | async fn build_dict_schema() -> SessionContext { 1485 | let mut builder = StringDictionaryBuilder::::new(); 1486 | builder.append(r#"{"foo": "bar"}"#).unwrap(); 1487 | builder.append(r#"{"baz": "fizz"}"#).unwrap(); 1488 | builder.append("nah").unwrap(); 1489 | builder.append(r#"{"baz": "abcd"}"#).unwrap(); 1490 | builder.append_null(); 1491 | builder.append(r#"{"baz": "fizz"}"#).unwrap(); 1492 | builder.append(r#"{"baz": "fizz"}"#).unwrap(); 1493 | builder.append(r#"{"baz": "fizz"}"#).unwrap(); 1494 | builder.append(r#"{"baz": "fizz"}"#).unwrap(); 1495 | builder.append_null(); 1496 | 1497 | let dict = builder.finish(); 1498 | 1499 | assert_eq!(dict.len(), 10); 1500 | assert_eq!(dict.values().len(), 4); 1501 | 1502 | let array = Arc::new(dict) as ArrayRef; 1503 | 1504 | let schema = Arc::new(Schema::new(vec![Field::new( 1505 | "x", 1506 | DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)), 1507 | true, 1508 | )])); 1509 | 1510 | let data = RecordBatch::try_new(schema.clone(), vec![array]).unwrap(); 1511 | 1512 | let ctx = create_context().await.unwrap(); 1513 | ctx.register_batch("data", data).unwrap(); 1514 | ctx 1515 | } 1516 | 1517 | #[tokio::test] 1518 | async fn test_dict_filter() { 1519 | let ctx = build_dict_schema().await; 1520 | 1521 | let sql = "select json_get(x, 'baz') v from data"; 1522 | let expected = [ 1523 | "+------------+", 1524 | "| v |", 1525 | "+------------+", 1526 | "| |", 1527 | "| {str=fizz} |", 1528 | "| |", 1529 | "| {str=abcd} |", 1530 | "| |", 1531 | "| {str=fizz} |", 1532 | "| {str=fizz} |", 1533 | "| {str=fizz} |", 1534 | "| {str=fizz} |", 1535 | "| |", 1536 | "+------------+", 1537 | ]; 1538 | 1539 | let batches = ctx.sql(sql).await.unwrap().collect().await.unwrap(); 1540 | 1541 | assert_batches_eq!(expected, &batches); 1542 | } 1543 | 1544 | #[tokio::test] 1545 | async fn test_dict_filter_is_not_null() { 1546 | let ctx = build_dict_schema().await; 1547 | let sql = "select x from data where json_get(x, 'baz') is not null"; 1548 | let expected = [ 1549 | "+-----------------+", 1550 | "| x |", 1551 | "+-----------------+", 1552 | "| {\"baz\": \"fizz\"} |", 1553 | "| {\"baz\": \"abcd\"} |", 1554 | "| {\"baz\": \"fizz\"} |", 1555 | "| {\"baz\": \"fizz\"} |", 1556 | "| {\"baz\": \"fizz\"} |", 1557 | "| {\"baz\": \"fizz\"} |", 1558 | "+-----------------+", 1559 | ]; 1560 | 1561 | let batches = ctx.sql(sql).await.unwrap().collect().await.unwrap(); 1562 | 1563 | assert_batches_eq!(expected, &batches); 1564 | } 1565 | 1566 | #[tokio::test] 1567 | async fn test_dict_filter_contains() { 1568 | let ctx = build_dict_schema().await; 1569 | let sql = "select x from data where json_contains(x, 'baz')"; 1570 | let expected = [ 1571 | "+-----------------+", 1572 | "| x |", 1573 | "+-----------------+", 1574 | "| {\"baz\": \"fizz\"} |", 1575 | "| {\"baz\": \"abcd\"} |", 1576 | "| {\"baz\": \"fizz\"} |", 1577 | "| {\"baz\": \"fizz\"} |", 1578 | "| {\"baz\": \"fizz\"} |", 1579 | "| {\"baz\": \"fizz\"} |", 1580 | "+-----------------+", 1581 | ]; 1582 | 1583 | let batches = ctx.sql(sql).await.unwrap().collect().await.unwrap(); 1584 | 1585 | assert_batches_eq!(expected, &batches); 1586 | 1587 | // test with a boolean OR as well 1588 | let batches = ctx 1589 | .sql(&format!("{sql} or false")) 1590 | .await 1591 | .unwrap() 1592 | .collect() 1593 | .await 1594 | .unwrap(); 1595 | 1596 | assert_batches_eq!(expected, &batches); 1597 | } 1598 | 1599 | #[tokio::test] 1600 | async fn test_json_object_keys() { 1601 | let expected = [ 1602 | "+----------------------------------+", 1603 | "| json_object_keys(test.json_data) |", 1604 | "+----------------------------------+", 1605 | "| [foo] |", 1606 | "| [foo] |", 1607 | "| [foo] |", 1608 | "| [foo] |", 1609 | "| [bar] |", 1610 | "| |", 1611 | "| |", 1612 | "+----------------------------------+", 1613 | ]; 1614 | 1615 | let sql = "select json_object_keys(json_data) from test"; 1616 | let batches = run_query(sql).await.unwrap(); 1617 | assert_batches_eq!(expected, &batches); 1618 | 1619 | let sql = "select json_object_keys(json_data) from test"; 1620 | let batches = run_query_dict(sql).await.unwrap(); 1621 | assert_batches_eq!(expected, &batches); 1622 | 1623 | let sql = "select json_object_keys(json_data) from test"; 1624 | let batches = run_query_large(sql).await.unwrap(); 1625 | assert_batches_eq!(expected, &batches); 1626 | } 1627 | 1628 | #[tokio::test] 1629 | async fn test_json_object_keys_many() { 1630 | let expected = [ 1631 | "+-----------------------+", 1632 | "| v |", 1633 | "+-----------------------+", 1634 | "| [foo, bar, spam, ham] |", 1635 | "+-----------------------+", 1636 | ]; 1637 | 1638 | let sql = r#"select json_object_keys('{"foo": 1, "bar": 2.2, "spam": true, "ham": []}') as v"#; 1639 | let batches = run_query(sql).await.unwrap(); 1640 | assert_batches_eq!(expected, &batches); 1641 | } 1642 | 1643 | #[tokio::test] 1644 | async fn test_json_object_keys_nested() { 1645 | let json = r#"'{"foo": [{"bar": {"spam": true, "ham": []}}]}'"#; 1646 | 1647 | let sql = format!("select json_object_keys({json}) as v"); 1648 | let batches = run_query(&sql).await.unwrap(); 1649 | #[rustfmt::skip] 1650 | let expected = [ 1651 | "+-------+", 1652 | "| v |", 1653 | "+-------+", 1654 | "| [foo] |", 1655 | "+-------+", 1656 | ]; 1657 | assert_batches_eq!(expected, &batches); 1658 | 1659 | let sql = format!("select json_object_keys({json}, 'foo') as v"); 1660 | let batches = run_query(&sql).await.unwrap(); 1661 | #[rustfmt::skip] 1662 | let expected = [ 1663 | "+---+", 1664 | "| v |", 1665 | "+---+", 1666 | "| |", 1667 | "+---+", 1668 | ]; 1669 | assert_batches_eq!(expected, &batches); 1670 | 1671 | let sql = format!("select json_object_keys({json}, 'foo', 0) as v"); 1672 | let batches = run_query(&sql).await.unwrap(); 1673 | #[rustfmt::skip] 1674 | let expected = [ 1675 | "+-------+", 1676 | "| v |", 1677 | "+-------+", 1678 | "| [bar] |", 1679 | "+-------+", 1680 | ]; 1681 | assert_batches_eq!(expected, &batches); 1682 | 1683 | let sql = format!("select json_object_keys({json}, 'foo', 0, 'bar') as v"); 1684 | let batches = run_query(&sql).await.unwrap(); 1685 | #[rustfmt::skip] 1686 | let expected = [ 1687 | "+-------------+", 1688 | "| v |", 1689 | "+-------------+", 1690 | "| [spam, ham] |", 1691 | "+-------------+", 1692 | ]; 1693 | assert_batches_eq!(expected, &batches); 1694 | } 1695 | 1696 | #[tokio::test] 1697 | async fn test_lookup_literal_column_matrix() { 1698 | let sql = r#" 1699 | WITH json_columns AS ( 1700 | SELECT unnest(['{"a": 1}', '{"b": 2}']) as json_column 1701 | ), attr_names AS ( 1702 | -- this is deliberately a different length to json_columns 1703 | SELECT 1704 | unnest(['a', 'b', 'c']) as attr_name, 1705 | arrow_cast(unnest(['a', 'b', 'c']), 'Dictionary(Int32, Utf8)') as attr_name_dict 1706 | ) 1707 | SELECT 1708 | attr_name, 1709 | json_column, 1710 | 'a' = attr_name, 1711 | json_get('{"a": 1}', attr_name), -- literal lookup with column 1712 | json_get('{"a": 1}', attr_name_dict), -- literal lookup with dict column 1713 | json_get('{"a": 1}', 'a'), -- literal lookup with literal 1714 | json_get(json_column, attr_name), -- column lookup with column 1715 | json_get(json_column, attr_name_dict), -- column lookup with dict column 1716 | json_get(json_column, 'a') -- column lookup with literal 1717 | FROM json_columns, attr_names 1718 | "#; 1719 | 1720 | let expected = [ 1721 | "+-----------+-------------+----------------------------------+-------------------------------------------------+------------------------------------------------------+--------------------------------------+---------------------------------------------------------+--------------------------------------------------------------+----------------------------------------------+", 1722 | "| attr_name | json_column | Utf8(\"a\") = attr_names.attr_name | json_get(Utf8(\"{\"a\": 1}\"),attr_names.attr_name) | json_get(Utf8(\"{\"a\": 1}\"),attr_names.attr_name_dict) | json_get(Utf8(\"{\"a\": 1}\"),Utf8(\"a\")) | json_get(json_columns.json_column,attr_names.attr_name) | json_get(json_columns.json_column,attr_names.attr_name_dict) | json_get(json_columns.json_column,Utf8(\"a\")) |", 1723 | "+-----------+-------------+----------------------------------+-------------------------------------------------+------------------------------------------------------+--------------------------------------+---------------------------------------------------------+--------------------------------------------------------------+----------------------------------------------+", 1724 | "| a | {\"a\": 1} | true | {int=1} | {int=1} | {int=1} | {int=1} | {int=1} | {int=1} |", 1725 | "| b | {\"a\": 1} | false | {null=} | {null=} | {int=1} | {null=} | {null=} | {int=1} |", 1726 | "| c | {\"a\": 1} | false | {null=} | {null=} | {int=1} | {null=} | {null=} | {int=1} |", 1727 | "| a | {\"b\": 2} | true | {int=1} | {int=1} | {int=1} | {null=} | {null=} | {null=} |", 1728 | "| b | {\"b\": 2} | false | {null=} | {null=} | {int=1} | {int=2} | {int=2} | {null=} |", 1729 | "| c | {\"b\": 2} | false | {null=} | {null=} | {int=1} | {null=} | {null=} | {null=} |", 1730 | "+-----------+-------------+----------------------------------+-------------------------------------------------+------------------------------------------------------+--------------------------------------+---------------------------------------------------------+--------------------------------------------------------------+----------------------------------------------+", 1731 | ]; 1732 | 1733 | let batches = run_query(sql).await.unwrap(); 1734 | assert_batches_eq!(expected, &batches); 1735 | } 1736 | 1737 | #[tokio::test] 1738 | async fn test_lookup_literal_column_matrix_dictionaries() { 1739 | let sql = r#" 1740 | WITH json_columns AS ( 1741 | SELECT arrow_cast(unnest(['{"a": 1}', '{"b": 2}']), 'Dictionary(Int32, Utf8)') as json_column 1742 | ), attr_names AS ( 1743 | -- this is deliberately a different length to json_columns 1744 | SELECT 1745 | unnest(['a', 'b', 'c']) as attr_name, 1746 | arrow_cast(unnest(['a', 'b', 'c']), 'Dictionary(Int32, Utf8)') as attr_name_dict 1747 | ) 1748 | SELECT 1749 | attr_name, 1750 | json_column, 1751 | 'a' = attr_name, 1752 | json_get('{"a": 1}', attr_name), -- literal lookup with column 1753 | json_get('{"a": 1}', attr_name_dict), -- literal lookup with dict column 1754 | json_get('{"a": 1}', 'a'), -- literal lookup with literal 1755 | json_get(json_column, attr_name), -- column lookup with column 1756 | json_get(json_column, attr_name_dict), -- column lookup with dict column 1757 | json_get(json_column, 'a') -- column lookup with literal 1758 | FROM json_columns, attr_names 1759 | "#; 1760 | 1761 | // NB as compared to the non-dictionary case, we null out the dictionary keys if the return 1762 | // value is a dict, which is why we get true nulls instead of {null=} 1763 | let expected = [ 1764 | "+-----------+-------------+----------------------------------+-------------------------------------------------+------------------------------------------------------+--------------------------------------+---------------------------------------------------------+--------------------------------------------------------------+----------------------------------------------+", 1765 | "| attr_name | json_column | Utf8(\"a\") = attr_names.attr_name | json_get(Utf8(\"{\"a\": 1}\"),attr_names.attr_name) | json_get(Utf8(\"{\"a\": 1}\"),attr_names.attr_name_dict) | json_get(Utf8(\"{\"a\": 1}\"),Utf8(\"a\")) | json_get(json_columns.json_column,attr_names.attr_name) | json_get(json_columns.json_column,attr_names.attr_name_dict) | json_get(json_columns.json_column,Utf8(\"a\")) |", 1766 | "+-----------+-------------+----------------------------------+-------------------------------------------------+------------------------------------------------------+--------------------------------------+---------------------------------------------------------+--------------------------------------------------------------+----------------------------------------------+", 1767 | "| a | {\"a\": 1} | true | {int=1} | {int=1} | {int=1} | {int=1} | {int=1} | {int=1} |", 1768 | "| b | {\"a\": 1} | false | {null=} | {null=} | {int=1} | | | {int=1} |", 1769 | "| c | {\"a\": 1} | false | {null=} | {null=} | {int=1} | | | {int=1} |", 1770 | "| a | {\"b\": 2} | true | {int=1} | {int=1} | {int=1} | | | |", 1771 | "| b | {\"b\": 2} | false | {null=} | {null=} | {int=1} | {int=2} | {int=2} | |", 1772 | "| c | {\"b\": 2} | false | {null=} | {null=} | {int=1} | | | |", 1773 | "+-----------+-------------+----------------------------------+-------------------------------------------------+------------------------------------------------------+--------------------------------------+---------------------------------------------------------+--------------------------------------------------------------+----------------------------------------------+", 1774 | ]; 1775 | 1776 | let batches = run_query(sql).await.unwrap(); 1777 | assert_batches_eq!(expected, &batches); 1778 | } 1779 | -------------------------------------------------------------------------------- /tests/utils/mod.rs: -------------------------------------------------------------------------------- 1 | #![allow(dead_code)] 2 | use std::sync::Arc; 3 | 4 | use datafusion::arrow::array::{ 5 | ArrayRef, DictionaryArray, Int32Array, Int64Array, StringViewArray, UInt32Array, UInt64Array, UInt8Array, 6 | }; 7 | use datafusion::arrow::datatypes::{DataType, Field, Int32Type, Int64Type, Schema, UInt32Type, UInt8Type}; 8 | use datafusion::arrow::util::display::{ArrayFormatter, FormatOptions}; 9 | use datafusion::arrow::{array::LargeStringArray, array::StringArray, record_batch::RecordBatch}; 10 | use datafusion::common::ParamValues; 11 | use datafusion::error::Result; 12 | use datafusion::execution::context::SessionContext; 13 | use datafusion::prelude::SessionConfig; 14 | use datafusion_functions_json::register_all; 15 | 16 | pub async fn create_context() -> Result { 17 | let config = SessionConfig::new().set_str("datafusion.sql_parser.dialect", "postgres"); 18 | let mut ctx = SessionContext::new_with_config(config); 19 | register_all(&mut ctx)?; 20 | Ok(ctx) 21 | } 22 | 23 | #[expect(clippy::too_many_lines)] 24 | async fn create_test_table(large_utf8: bool, dict_encoded: bool) -> Result { 25 | let ctx = create_context().await?; 26 | 27 | let test_data = [ 28 | ("object_foo", r#" {"foo": "abc"} "#), 29 | ("object_foo_array", r#" {"foo": [1]} "#), 30 | ("object_foo_obj", r#" {"foo": {}} "#), 31 | ("object_foo_null", r#" {"foo": null} "#), 32 | ("object_bar", r#" {"bar": true} "#), 33 | ("list_foo", r#" ["foo"] "#), 34 | ("invalid_json", "is not json"), 35 | ]; 36 | let json_values = test_data.iter().map(|(_, json)| *json).collect::>(); 37 | let (mut json_data_type, mut json_array): (DataType, ArrayRef) = if large_utf8 { 38 | (DataType::LargeUtf8, Arc::new(LargeStringArray::from(json_values))) 39 | } else { 40 | (DataType::Utf8, Arc::new(StringArray::from(json_values))) 41 | }; 42 | 43 | if dict_encoded { 44 | json_data_type = DataType::Dictionary(DataType::Int32.into(), json_data_type.into()); 45 | json_array = Arc::new(DictionaryArray::::new( 46 | Int32Array::from_iter_values(0..(i32::try_from(json_array.len()).expect("fits in a i32"))), 47 | json_array, 48 | )); 49 | } 50 | 51 | let test_batch = RecordBatch::try_new( 52 | Arc::new(Schema::new(vec![ 53 | Field::new("name", DataType::Utf8, false), 54 | Field::new("json_data", json_data_type, false), 55 | ])), 56 | vec![ 57 | Arc::new(StringArray::from( 58 | test_data.iter().map(|(name, _)| *name).collect::>(), 59 | )), 60 | json_array, 61 | ], 62 | )?; 63 | ctx.register_batch("test", test_batch)?; 64 | 65 | let other_data = [ 66 | (r#" {"foo": 42} "#, "foo", 0), 67 | (r#" {"foo": 42} "#, "bar", 1), 68 | (r" [42] ", "foo", 0), 69 | (r" [42] ", "bar", 1), 70 | ]; 71 | let other_batch = RecordBatch::try_new( 72 | Arc::new(Schema::new(vec![ 73 | Field::new("json_data", DataType::Utf8, false), 74 | Field::new("str_key", DataType::Utf8, false), 75 | Field::new("int_key", DataType::Int64, false), 76 | ])), 77 | vec![ 78 | Arc::new(StringArray::from( 79 | other_data.iter().map(|(json, _, _)| *json).collect::>(), 80 | )), 81 | Arc::new(StringArray::from( 82 | other_data.iter().map(|(_, str_key, _)| *str_key).collect::>(), 83 | )), 84 | Arc::new(Int64Array::from( 85 | other_data.iter().map(|(_, _, int_key)| *int_key).collect::>(), 86 | )), 87 | ], 88 | )?; 89 | ctx.register_batch("other", other_batch)?; 90 | 91 | let more_nested = [ 92 | (r#" {"foo": {"bar": [0]}} "#, "foo", "bar", 0), 93 | (r#" {"foo": {"bar": [1]}} "#, "foo", "spam", 0), 94 | (r#" {"foo": {"bar": null}} "#, "foo", "bar", 0), 95 | ]; 96 | let more_nested_batch = RecordBatch::try_new( 97 | Arc::new(Schema::new(vec![ 98 | Field::new("json_data", DataType::Utf8, false), 99 | Field::new("str_key1", DataType::Utf8, false), 100 | Field::new( 101 | "str_key2", 102 | DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), 103 | false, 104 | ), 105 | Field::new("int_key", DataType::Int64, false), 106 | ])), 107 | vec![ 108 | Arc::new(StringArray::from( 109 | more_nested.iter().map(|(json, _, _, _)| *json).collect::>(), 110 | )), 111 | Arc::new(StringArray::from( 112 | more_nested 113 | .iter() 114 | .map(|(_, str_key1, _, _)| *str_key1) 115 | .collect::>(), 116 | )), 117 | Arc::new( 118 | more_nested 119 | .iter() 120 | .map(|(_, _, str_key2, _)| *str_key2) 121 | .collect::>(), 122 | ), 123 | Arc::new(Int64Array::from( 124 | more_nested 125 | .iter() 126 | .map(|(_, _, _, int_key)| *int_key) 127 | .collect::>(), 128 | )), 129 | ], 130 | )?; 131 | ctx.register_batch("more_nested", more_nested_batch)?; 132 | 133 | let dict_data = [ 134 | (r#" {"foo": {"bar": [0]}} "#, "foo", "bar", 0), 135 | (r#" {"bar": "snap"} "#, "foo", "spam", 0), 136 | (r#" {"spam": 1, "snap": 2} "#, "foo", "spam", 0), 137 | (r#" {"spam": 1, "snap": 2} "#, "foo", "snap", 0), 138 | ]; 139 | let dict_batch = RecordBatch::try_new( 140 | Arc::new(Schema::new(vec![ 141 | Field::new( 142 | "json_data", 143 | DataType::Dictionary(DataType::UInt32.into(), DataType::Utf8.into()), 144 | false, 145 | ), 146 | Field::new( 147 | "str_key1", 148 | DataType::Dictionary(DataType::UInt8.into(), DataType::LargeUtf8.into()), 149 | false, 150 | ), 151 | Field::new( 152 | "str_key2", 153 | DataType::Dictionary(DataType::UInt8.into(), DataType::Utf8View.into()), 154 | false, 155 | ), 156 | Field::new( 157 | "int_key", 158 | DataType::Dictionary(DataType::Int64.into(), DataType::UInt64.into()), 159 | false, 160 | ), 161 | ])), 162 | vec![ 163 | Arc::new(DictionaryArray::::new( 164 | UInt32Array::from_iter_values( 165 | dict_data 166 | .iter() 167 | .enumerate() 168 | .map(|(id, _)| u32::try_from(id).expect("fits in a u32")), 169 | ), 170 | Arc::new(StringArray::from( 171 | dict_data.iter().map(|(json, _, _, _)| *json).collect::>(), 172 | )), 173 | )), 174 | Arc::new(DictionaryArray::::new( 175 | UInt8Array::from_iter_values( 176 | dict_data 177 | .iter() 178 | .enumerate() 179 | .map(|(id, _)| u8::try_from(id).expect("fits in a u8")), 180 | ), 181 | Arc::new(LargeStringArray::from( 182 | dict_data 183 | .iter() 184 | .map(|(_, str_key1, _, _)| *str_key1) 185 | .collect::>(), 186 | )), 187 | )), 188 | Arc::new(DictionaryArray::::new( 189 | UInt8Array::from_iter_values( 190 | dict_data 191 | .iter() 192 | .enumerate() 193 | .map(|(id, _)| u8::try_from(id).expect("fits in a u8")), 194 | ), 195 | Arc::new(StringViewArray::from( 196 | dict_data 197 | .iter() 198 | .map(|(_, _, str_key2, _)| *str_key2) 199 | .collect::>(), 200 | )), 201 | )), 202 | Arc::new(DictionaryArray::::new( 203 | Int64Array::from_iter_values( 204 | dict_data 205 | .iter() 206 | .enumerate() 207 | .map(|(id, _)| i64::try_from(id).expect("fits in a i64")), 208 | ), 209 | Arc::new(UInt64Array::from_iter_values( 210 | dict_data 211 | .iter() 212 | .map(|(_, _, _, int_key)| u64::try_from(*int_key).expect("not negative")), 213 | )), 214 | )), 215 | ], 216 | )?; 217 | ctx.register_batch("dicts", dict_batch)?; 218 | 219 | Ok(ctx) 220 | } 221 | 222 | pub async fn run_query(sql: &str) -> Result> { 223 | let ctx = create_test_table(false, false).await?; 224 | ctx.sql(sql).await?.collect().await 225 | } 226 | 227 | pub async fn run_query_large(sql: &str) -> Result> { 228 | let ctx = create_test_table(true, false).await?; 229 | ctx.sql(sql).await?.collect().await 230 | } 231 | 232 | pub async fn run_query_dict(sql: &str) -> Result> { 233 | let ctx = create_test_table(false, true).await?; 234 | ctx.sql(sql).await?.collect().await 235 | } 236 | 237 | pub async fn run_query_params( 238 | sql: &str, 239 | large_utf8: bool, 240 | query_values: impl Into, 241 | ) -> Result> { 242 | let ctx = create_test_table(large_utf8, false).await?; 243 | ctx.sql(sql).await?.with_param_values(query_values)?.collect().await 244 | } 245 | 246 | pub async fn display_val(batch: Vec) -> (DataType, String) { 247 | assert_eq!(batch.len(), 1); 248 | let batch = batch.first().unwrap(); 249 | assert_eq!(batch.num_rows(), 1); 250 | let schema = batch.schema(); 251 | let schema_col = schema.field(0); 252 | let c = batch.column(0); 253 | let options = FormatOptions::default().with_display_error(true); 254 | let f = ArrayFormatter::try_new(c.as_ref(), &options).unwrap(); 255 | let repr = f.value(0).try_to_string().unwrap(); 256 | (schema_col.data_type().clone(), repr) 257 | } 258 | 259 | pub async fn logical_plan(sql: &str) -> Vec { 260 | let batches = run_query(sql).await.unwrap(); 261 | let plan_col = batches[0].column(1).as_any().downcast_ref::().unwrap(); 262 | let logical_plan = plan_col.value(0); 263 | logical_plan.split('\n').map(ToString::to_string).collect() 264 | } 265 | --------------------------------------------------------------------------------