├── .gitmodules ├── .gitignore ├── src ├── errors │ ├── mod.rs │ └── zarr_errors.rs ├── table │ ├── mod.rs │ ├── scanner.rs │ ├── opener.rs │ ├── config.rs │ └── table_provider.rs ├── zarr_store_opener │ ├── mod.rs │ ├── io_runtime.rs │ ├── filter.rs │ └── zarr_data_stream.rs └── lib.rs ├── README.md ├── .github └── workflows │ └── rust.yml ├── Cargo.toml ├── benches └── s3_bench.rs └── LICENSE /.gitmodules: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | Cargo.lock 2 | target -------------------------------------------------------------------------------- /src/errors/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod zarr_errors; 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # arrow-zarr 2 | Implementation of a query engine for zarr storage. -------------------------------------------------------------------------------- /src/table/mod.rs: -------------------------------------------------------------------------------- 1 | pub(crate) mod config; 2 | pub(crate) mod opener; 3 | pub(crate) mod scanner; 4 | pub(crate) mod table_provider; 5 | 6 | pub use config::ZarrTableConfig; 7 | pub use table_provider::{ZarrTable, ZarrTableFactory}; 8 | -------------------------------------------------------------------------------- /src/zarr_store_opener/mod.rs: -------------------------------------------------------------------------------- 1 | pub(crate) mod filter; 2 | pub(crate) mod io_runtime; 3 | pub(crate) mod zarr_data_stream; 4 | 5 | pub use filter::{ZarrArrowPredicate, ZarrChunkFilter}; 6 | pub use zarr_data_stream::ZarrRecordBatchStream; 7 | -------------------------------------------------------------------------------- /.github/workflows/rust.yml: -------------------------------------------------------------------------------- 1 | name: Rust 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | 9 | jobs: 10 | check-fmt: 11 | name: Check cargo fmt 12 | runs-on: ubuntu-latest 13 | container: 14 | image: amd64/rust 15 | steps: 16 | - uses: actions/checkout@v4 17 | - uses: dtolnay/rust-toolchain@master 18 | with: 19 | toolchain: nightly-2025-05-14 20 | components: rustfmt 21 | - name: Run check 22 | run: cargo +nightly-2025-05-14 fmt -- --check --unstable-features --config imports_granularity=Module,group_imports=StdExternalCrate 23 | 24 | check-clippy-all-features: 25 | name: Check cargo clippy 26 | runs-on: ubuntu-latest 27 | container: 28 | image: amd64/rust 29 | steps: 30 | - uses: actions/checkout@v4 31 | - uses: dtolnay/rust-toolchain@stable 32 | with: 33 | components: clippy 34 | - name: Run check 35 | run: cargo clippy --all-targets --all-features -- -D warnings 36 | 37 | # Check clippy without features, helps to catch missing feature configurations 38 | check-clippy-no-features: 39 | name: Check cargo clippy 40 | runs-on: ubuntu-latest 41 | container: 42 | image: amd64/rust 43 | steps: 44 | - uses: actions/checkout@v4 45 | - uses: dtolnay/rust-toolchain@stable 46 | with: 47 | components: clippy 48 | - name: Run check 49 | run: cargo clippy --all-targets -- -D warnings 50 | 51 | test: 52 | name: Run unit tests 53 | runs-on: ubuntu-latest 54 | steps: 55 | - uses: 
actions/checkout@v4 56 | - uses: dtolnay/rust-toolchain@stable 57 | - run: cargo test --features icechunk -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "arrow-zarr" 3 | version = "0.1.0" 4 | homepage = "https://github.com/datafusion-contrib/arrow-zarr" 5 | repository = "https://github.com/datafusion-contrib/arrow-zarr" 6 | authors = ["Maxime Dion "] 7 | license = "Apache-2.0" 8 | keywords = ["arrow"] 9 | edition = "2021" 10 | rust-version = "1.86" 11 | 12 | [features] 13 | default = [] 14 | icechunk = ["dep:icechunk", "dep:zarrs_icechunk"] 15 | 16 | [dependencies] 17 | arrow = { version = "55.2.0" } 18 | arrow-array = { version = "55.2.0" } 19 | arrow-schema = { version = "55.2.0" } 20 | async-stream = "0.3" 21 | async-trait = { version = "0.1.89" } 22 | bytes = { version = "1.10.1" } 23 | chrono = { version = "0.4.42" } 24 | datafusion = { version = "49.0.0" } 25 | futures = { version = "0.3.31" } 26 | futures-util = { version = "0.3.31" } 27 | icechunk = { version = "0.3.17", optional = true } 28 | itertools = { version = "0.14.0" } 29 | ndarray = { version = "^0.16.1" } 30 | object_store = { version = "0.12.0", features = ["aws", "gcp"] } 31 | tokio = { version = "1.46.1", features = ["rt", "macros", "net", "rt-multi-thread"] } 32 | tokio-test = { version = "0.4.4" } 33 | zarrs = { version = "0.22.1", features = ["async"] } 34 | zarrs_filesystem = { version = "0.3.0" } 35 | zarrs_icechunk = { version = "0.4.0", optional = true } 36 | zarrs_metadata = { version = "0.6.0" } 37 | zarrs_object_store = { version = "0.5.0" } 38 | zarrs_storage = { version = "0.4.0", features = ["async"] } 39 | 40 | [dev-dependencies] 41 | aws-config = { version = "1.5.18" } 42 | aws-sdk-s3 = { version = "1.78.0" } 43 | criterion = { version = "0.7.0", features = ["async_tokio"] } 44 | walkdir = { version = "2.5.0" } 45 | 46 | 47 | [[bench]] 48 | name = "s3_bench" 49 | harness = false 50 | required-features = ["icechunk"] -------------------------------------------------------------------------------- /src/zarr_store_opener/io_runtime.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use tokio::runtime::Handle; 4 | use tokio::sync::Notify; 5 | 6 | use crate::errors::zarr_errors::ZarrQueryResult; 7 | 8 | /// More or less copied from here, https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/thread_pools.rs 9 | /// with a few tweaks to make this a runtime for non blockinng i/o. 10 | pub(crate) struct IoRuntime { 11 | /// Handle is the tokio structure for interacting with a Runtime. 12 | handle: Handle, 13 | /// Signal to start shutting down. 14 | notify_shutdown: Arc, 15 | /// When thread is active, is Some. 16 | thread_join_handle: Option>, 17 | } 18 | 19 | impl Drop for IoRuntime { 20 | fn drop(&mut self) { 21 | // Notify the thread to shutdown. 22 | self.notify_shutdown.notify_one(); 23 | 24 | // TODO make sure that no tasks can be added to the runtime 25 | // past this point. 26 | 27 | if let Some(thread_join_handle) = self.thread_join_handle.take() { 28 | // If the thread is still running, we wait for it to finish. 29 | if let Err(e) = thread_join_handle.join() { 30 | eprintln!("Error joining CPU runtime thread: {e:?}",); 31 | } 32 | } 33 | } 34 | } 35 | 36 | impl IoRuntime { 37 | /// Create a new Tokio Runtime for non-blocking tasks. 
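    ///
    /// A minimal usage sketch (illustrative only, not a doctest from this crate; the
    /// async `fetch_chunk` helper is hypothetical and `JoinError` handling is glossed
    /// over):
    ///
    /// ```ignore
    /// let io = IoRuntime::try_new()?;
    /// // Work spawned through the handle runs on the dedicated I/O runtime's worker
    /// // thread; the returned JoinHandle can be awaited from the caller's own runtime.
    /// let bytes = io.handle().spawn(async { fetch_chunk().await }).await??;
    /// ```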
38 | pub(crate) fn try_new() -> ZarrQueryResult { 39 | let io_runtime = tokio::runtime::Builder::new_multi_thread() 40 | .worker_threads(1) 41 | .enable_time() 42 | .enable_io() 43 | .build()?; 44 | 45 | let handle = io_runtime.handle().clone(); 46 | let notify_shutdown = Arc::new(Notify::new()); 47 | let notify_shutdown_captured = Arc::clone(¬ify_shutdown); 48 | 49 | // The io_runtime runs and is dropped on a separate thread. 50 | let thread_join_handle = std::thread::spawn(move || { 51 | io_runtime.block_on(async move { 52 | notify_shutdown_captured.notified().await; 53 | }); 54 | // The io_runtime is dropped here, which will wait for all tasks 55 | // to complete. 56 | }); 57 | 58 | Ok(Self { 59 | handle, 60 | notify_shutdown, 61 | thread_join_handle: Some(thread_join_handle), 62 | }) 63 | } 64 | 65 | pub(crate) fn handle(&self) -> &Handle { 66 | &self.handle 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /src/errors/zarr_errors.rs: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 
17 | 
18 | use std::error::Error;
19 | 
20 | use arrow::error::ArrowError;
21 | use datafusion::error::DataFusionError;
22 | use zarrs::array::codec::CodecError;
23 | use zarrs::array::{ArrayCreateError, ArrayError};
24 | use zarrs_storage::{StorageError, StorePrefixError};
25 | 
26 | #[derive(Debug)]
27 | pub enum ZarrQueryError {
28 |     InvalidProjection(String),
29 |     InvalidType(String),
30 |     InvalidArrayShapes(String),
31 |     InvalidMetadata(String),
32 |     InvalidCompute(String),
33 |     RecordBatchError(Box),
34 |     Zarrs(Box),
35 |     Io(Box),
36 | }
37 | 
38 | impl std::fmt::Display for ZarrQueryError {
39 |     fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
40 |         match &self {
41 |             Self::InvalidProjection(msg) => write!(fmt, "Invalid projection: {msg}"),
42 |             Self::InvalidType(msg) => write!(fmt, "Invalid type: {msg}"),
43 |             Self::InvalidArrayShapes(msg) => write!(fmt, "Invalid array shapes: {msg}"),
44 |             Self::InvalidMetadata(msg) => write!(fmt, "Invalid metadata: {msg}"),
45 |             Self::InvalidCompute(msg) => write!(fmt, "Invalid compute: {msg}"),
46 |             Self::RecordBatchError(e) => write!(fmt, "A record batch call returned an error: {e}"),
47 |             Self::Zarrs(e) => write!(fmt, "A zarrs call returned an error: {e}"),
48 |             Self::Io(e) => write!(fmt, "An I/O call returned an error: {e}"),
49 |         }
50 |     }
51 | }
52 | 
53 | impl Error for ZarrQueryError {}
54 | 
55 | impl From<StorageError> for ZarrQueryError {
56 |     fn from(e: StorageError) -> ZarrQueryError {
57 |         ZarrQueryError::Zarrs(Box::new(e))
58 |     }
59 | }
60 | 
61 | impl From<StorePrefixError> for ZarrQueryError {
62 |     fn from(e: StorePrefixError) -> ZarrQueryError {
63 |         ZarrQueryError::Zarrs(Box::new(e))
64 |     }
65 | }
66 | 
67 | impl From<ArrayCreateError> for ZarrQueryError {
68 |     fn from(e: ArrayCreateError) -> ZarrQueryError {
69 |         ZarrQueryError::Zarrs(Box::new(e))
70 |     }
71 | }
72 | 
73 | impl From<CodecError> for ZarrQueryError {
74 |     fn from(e: CodecError) -> ZarrQueryError {
75 |         ZarrQueryError::Zarrs(Box::new(e))
76 |     }
77 | }
78 | 
79 | impl From<ArrayError> for ZarrQueryError {
80 |     fn from(e: ArrayError) -> ZarrQueryError {
81 |         ZarrQueryError::Zarrs(Box::new(e))
82 |     }
83 | }
84 | 
85 | impl From<ArrowError> for ZarrQueryError {
86 |     fn from(e: ArrowError) -> ZarrQueryError {
87 |         ZarrQueryError::RecordBatchError(Box::new(e))
88 |     }
89 | }
90 | 
91 | impl From<std::io::Error> for ZarrQueryError {
92 |     fn from(e: std::io::Error) -> ZarrQueryError {
93 |         ZarrQueryError::Io(Box::new(e))
94 |     }
95 | }
96 | 
97 | /// A specialized [`Result`] for [`ZarrQueryError`]s.
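///
/// A small illustrative sketch (the helper below is hypothetical, not part of this
/// crate) showing how the `From` impls above let `?` convert underlying errors:
///
/// ```ignore
/// fn read_config(path: &str) -> ZarrQueryResult<String> {
///     // the `?` operator converts the std::io::Error via the From impl defined above
///     Ok(std::fs::read_to_string(path)?)
/// }
/// ```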
98 | pub type ZarrQueryResult = Result; 99 | 100 | impl From for ArrowError { 101 | fn from(e: ZarrQueryError) -> ArrowError { 102 | ArrowError::ExternalError(Box::new(e)) 103 | } 104 | } 105 | 106 | impl From for DataFusionError { 107 | fn from(e: ZarrQueryError) -> DataFusionError { 108 | DataFusionError::External(Box::new(e)) 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /src/table/scanner.rs: -------------------------------------------------------------------------------- 1 | use std::any::Any; 2 | use std::sync::Arc; 3 | 4 | use datafusion::datasource::listing::PartitionedFile; 5 | use datafusion::datasource::physical_plan::{ 6 | FileGroup, FileScanConfigBuilder, FileSource, FileStream, 7 | }; 8 | use datafusion::execution::object_store::ObjectStoreUrl; 9 | use datafusion::physical_expr::{EquivalenceProperties, Partitioning}; 10 | use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; 11 | use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet; 12 | use datafusion::physical_plan::{ 13 | DisplayAs, DisplayFormatType, ExecutionPlan, PhysicalExpr, PlanProperties, 14 | SendableRecordBatchStream, 15 | }; 16 | use object_store::local::LocalFileSystem; 17 | 18 | use super::config::ZarrTableConfig; 19 | use super::opener::ZarrSource; 20 | 21 | #[derive(Debug, Clone)] 22 | pub struct ZarrScan { 23 | zarr_config: ZarrTableConfig, 24 | filters: Option>, 25 | plan_properties: PlanProperties, 26 | } 27 | 28 | impl ZarrScan { 29 | pub(crate) fn new( 30 | zarr_config: ZarrTableConfig, 31 | filters: Option>, 32 | ) -> Self { 33 | let plan_properties = PlanProperties::new( 34 | EquivalenceProperties::new(zarr_config.get_projected_schema_ref()), 35 | Partitioning::UnknownPartitioning(1), 36 | EmissionType::Incremental, 37 | Boundedness::Bounded, 38 | ); 39 | 40 | Self { 41 | zarr_config, 42 | filters, 43 | plan_properties, 44 | } 45 | } 46 | } 47 | 48 | impl DisplayAs for ZarrScan { 49 | fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { 50 | write!(f, "ZarrScan") 51 | } 52 | } 53 | 54 | impl ExecutionPlan for ZarrScan { 55 | fn name(&self) -> &str { 56 | "ZarrScan" 57 | } 58 | 59 | fn as_any(&self) -> &dyn Any { 60 | self 61 | } 62 | 63 | fn children(&self) -> Vec<&Arc> { 64 | vec![] 65 | } 66 | 67 | fn with_new_children( 68 | self: Arc, 69 | _children: Vec>, 70 | ) -> datafusion::error::Result> { 71 | Ok(self) 72 | } 73 | 74 | fn properties(&self) -> &PlanProperties { 75 | &self.plan_properties 76 | } 77 | 78 | fn repartitioned( 79 | &self, 80 | target_partitions: usize, 81 | _config: &datafusion::config::ConfigOptions, 82 | ) -> datafusion::error::Result>> { 83 | let mut new_plan = self.clone(); 84 | new_plan.plan_properties = new_plan 85 | .plan_properties 86 | .with_partitioning(Partitioning::UnknownPartitioning(target_partitions)); 87 | Ok(Some(Arc::new(new_plan))) 88 | } 89 | 90 | fn execute( 91 | &self, 92 | partition: usize, 93 | _context: Arc, 94 | ) -> datafusion::error::Result { 95 | let n_partitions = match self.plan_properties.partitioning { 96 | Partitioning::UnknownPartitioning(n) => n, 97 | _ => { 98 | return Err(datafusion::error::DataFusionError::Execution( 99 | "Only Unknown partitioning support for zarr scans".into(), 100 | )); 101 | } 102 | }; 103 | 104 | let zarr_source = 105 | ZarrSource::new(self.zarr_config.clone(), n_partitions, self.filters.clone()); 106 | let file_groups = vec![FileGroup::new(vec![PartitionedFile::new("", 0)])]; 107 | let 
file_scan_config = FileScanConfigBuilder::new( 108 | ObjectStoreUrl::parse("file://").unwrap(), 109 | self.zarr_config.get_schema_ref(), 110 | Arc::new(zarr_source.clone()), 111 | ) 112 | .with_file_groups(file_groups) 113 | .with_projection(self.zarr_config.get_projection()) 114 | .build(); 115 | 116 | let dummy_object_store = Arc::new(LocalFileSystem::new()); 117 | let file_opener = 118 | zarr_source.create_file_opener(dummy_object_store, &file_scan_config, partition); 119 | let metrics = ExecutionPlanMetricsSet::default(); 120 | 121 | // Note: the "partition" argument is hardcoded to 0 here. We are not making 122 | // use of most of the logic in the file stream, for example the partitioning 123 | // logic is handled in the zarr stream object, so we need to effectively 124 | // "disable" it in the file stream obejct by always setting it to 0. 125 | let file_stream = FileStream::new(&file_scan_config, 0, file_opener, &metrics).unwrap(); 126 | 127 | Ok(Box::pin(file_stream)) 128 | } 129 | } 130 | 131 | #[cfg(test)] 132 | mod scanner_tests { 133 | use std::collections::HashMap; 134 | 135 | use arrow::datatypes::Float64Type; 136 | use arrow_schema::DataType; 137 | use datafusion::config::ConfigOptions; 138 | use datafusion::datasource::listing::ListingTableUrl; 139 | use datafusion::prelude::SessionContext; 140 | use futures_util::TryStreamExt; 141 | 142 | use super::*; 143 | use crate::table::config::ZarrTableUrl; 144 | use crate::test_utils::{ 145 | get_local_zarr_store, validate_names_and_types, validate_primitive_column, 146 | }; 147 | 148 | #[tokio::test] 149 | async fn read_data_test() { 150 | let (wrapper, schema) = get_local_zarr_store(true, 0.0, "lat_lon_data_for_scan").await; 151 | let path = wrapper.get_store_path(); 152 | let table_url = ZarrTableUrl::ZarrStore(ListingTableUrl::parse(path).unwrap()); 153 | let config = ZarrTableConfig::new(table_url, schema); 154 | 155 | let session = SessionContext::new(); 156 | let scan = ZarrScan::new(config, None); 157 | let records: Vec<_> = scan 158 | .execute(0, session.task_ctx()) 159 | .unwrap() 160 | .try_collect() 161 | .await 162 | .unwrap(); 163 | 164 | let target_types = HashMap::from([ 165 | ("lat".to_string(), DataType::Float64), 166 | ("lon".to_string(), DataType::Float64), 167 | ("data".to_string(), DataType::Float64), 168 | ]); 169 | validate_names_and_types(&target_types, &records[0]); 170 | assert_eq!(records.len(), 9); 171 | 172 | // the top left chunk, full 3x3 173 | validate_primitive_column::( 174 | "lat", 175 | &records[0], 176 | &[35., 35., 35., 36., 36., 36., 37., 37., 37.], 177 | ); 178 | validate_primitive_column::( 179 | "lon", 180 | &records[0], 181 | &[ 182 | -120.0, -119.0, -118.0, -120.0, -119.0, -118.0, -120.0, -119.0, -118.0, 183 | ], 184 | ); 185 | validate_primitive_column::( 186 | "data", 187 | &records[0], 188 | &[0.0, 1.0, 2.0, 8.0, 9.0, 10.0, 16.0, 17.0, 18.0], 189 | ); 190 | } 191 | 192 | #[tokio::test] 193 | async fn read_partition_test() { 194 | let (wrapper, schema) = 195 | get_local_zarr_store(true, 0.0, "lat_lon_data_for_scan_with_partition").await; 196 | let path = wrapper.get_store_path(); 197 | let table_url = ZarrTableUrl::ZarrStore(ListingTableUrl::parse(path).unwrap()); 198 | let config = ZarrTableConfig::new(table_url, schema); 199 | 200 | let session = SessionContext::new(); 201 | let scan = ZarrScan::new(config, None); 202 | let scan = scan 203 | .repartitioned(2, &ConfigOptions::default()) 204 | .unwrap() 205 | .unwrap(); 206 | 207 | let records: Vec<_> = scan 208 | .execute(1, 
session.task_ctx()) 209 | .unwrap() 210 | .try_collect() 211 | .await 212 | .unwrap(); 213 | 214 | let target_types = HashMap::from([ 215 | ("lat".to_string(), DataType::Float64), 216 | ("lon".to_string(), DataType::Float64), 217 | ("data".to_string(), DataType::Float64), 218 | ]); 219 | validate_names_and_types(&target_types, &records[0]); 220 | assert_eq!(records.len(), 4); 221 | } 222 | } 223 | -------------------------------------------------------------------------------- /benches/s3_bench.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | use std::env; 3 | use std::hint::black_box; 4 | use std::sync::Arc; 5 | 6 | use arrow_zarr::table::ZarrTableFactory; 7 | use aws_config::{self, BehaviorVersion}; 8 | use aws_sdk_s3::types::{Delete, ObjectIdentifier}; 9 | use aws_sdk_s3::Client; 10 | use criterion::{criterion_group, criterion_main, Criterion}; 11 | use datafusion::datasource::listing::ListingTableUrl; 12 | use datafusion::execution::SessionStateBuilder; 13 | use datafusion::prelude::SessionContext; 14 | use icechunk::config::{S3Credentials, S3Options}; 15 | use icechunk::{ObjectStorage, Repository}; 16 | use ndarray::{Array, Array2}; 17 | use zarrs::array::{codec, ArrayBuilder, DataType, FillValue}; 18 | use zarrs::array_subset::ArraySubset; 19 | use zarrs_icechunk::AsyncIcechunkStore; 20 | use zarrs_storage::AsyncReadableWritableListableStorageTraits; 21 | 22 | async fn create_s3_icechunk(url: &str) -> Arc { 23 | let listing_url = ListingTableUrl::parse(url).unwrap(); 24 | let bucket = listing_url 25 | .object_store() 26 | .as_str() 27 | .replace("s3://", "") 28 | .trim_end_matches("/") 29 | .to_string(); 30 | 31 | let credentials = S3Credentials::FromEnv; 32 | let config = S3Options { 33 | region: env::var("AWS_DEFAULT_REGION").ok(), 34 | endpoint_url: None, 35 | anonymous: false, 36 | allow_http: false, 37 | force_path_style: false, 38 | network_stream_timeout_seconds: None, 39 | requester_pays: false, 40 | }; 41 | 42 | let store = ObjectStorage::new_s3( 43 | bucket, 44 | Some(listing_url.prefix().as_ref().to_string()), 45 | Some(credentials), 46 | Some(config), 47 | ) 48 | .await 49 | .unwrap(); 50 | 51 | let repo = Repository::create(None, Arc::new(store), HashMap::new()) 52 | .await 53 | .unwrap(); 54 | let session = repo.writable_session("main").await.unwrap(); 55 | 56 | Arc::new(AsyncIcechunkStore::new(session)) 57 | } 58 | 59 | fn get_lz4_compressor() -> codec::BloscCodec { 60 | codec::BloscCodec::new( 61 | codec::bytes_to_bytes::blosc::BloscCompressor::LZ4, 62 | 5.try_into().unwrap(), 63 | Some(0), 64 | codec::bytes_to_bytes::blosc::BloscShuffleMode::NoShuffle, 65 | Some(1), 66 | ) 67 | .unwrap() 68 | } 69 | 70 | async fn write_data_to_store( 71 | store: Arc, 72 | start_var_idx: usize, 73 | prefix: &str, 74 | ) { 75 | let n = 512; 76 | let fill_value: i64 = 0; 77 | let mut array_builder = ArrayBuilder::new( 78 | vec![n, n], 79 | [8, 8], 80 | DataType::Int64, 81 | FillValue::from(fill_value), 82 | ); 83 | 84 | let mut builder_ref = &mut array_builder; 85 | let codec = get_lz4_compressor(); 86 | builder_ref = builder_ref.bytes_to_bytes_codecs(vec![Arc::new(codec)]); 87 | 88 | let prefix = if prefix.is_empty() { 89 | prefix 90 | } else { 91 | &format!("/{}", prefix) 92 | }; 93 | for var_idx in start_var_idx..(start_var_idx + 8) { 94 | let arr = builder_ref 95 | .build(store.clone(), &format!("{}/var{}", prefix, var_idx)) 96 | .unwrap(); 97 | arr.async_store_metadata().await.unwrap(); 98 | 99 | let 
arr_data: Array2 = Array::from_vec((0..(n * n) as i64).step_by(1).collect()) 100 | .into_shape_with_order((n as usize, n as usize)) 101 | .unwrap(); 102 | arr.async_store_array_subset_ndarray( 103 | ArraySubset::new_with_ranges(&[0..n, 0..n]).start(), 104 | arr_data, 105 | ) 106 | .await 107 | .unwrap(); 108 | } 109 | } 110 | 111 | struct S3TestFixture { 112 | bucket: String, 113 | prefix: String, 114 | client: Client, 115 | session: SessionContext, 116 | } 117 | 118 | impl S3TestFixture { 119 | fn new() -> Self { 120 | let url = "s3://zarr-unit-tests/test_data_1"; 121 | let rt = tokio::runtime::Runtime::new().unwrap(); 122 | 123 | let (client, session) = rt.block_on(async { 124 | let store = create_s3_icechunk(url).await; 125 | write_data_to_store(store.clone(), 1, "").await; 126 | let _ = store 127 | .session() 128 | .write() 129 | .await 130 | .commit("some test data", None) 131 | .await 132 | .unwrap(); 133 | 134 | let config = aws_config::load_defaults(BehaviorVersion::latest()).await; 135 | let client = Client::new(&config); 136 | 137 | let mut state = SessionStateBuilder::new().build(); 138 | 139 | state 140 | .table_factories_mut() 141 | .insert("ICECHUNK_REPO".into(), Arc::new(ZarrTableFactory {})); 142 | let session = SessionContext::new_with_state(state.clone()); 143 | 144 | let query = format!( 145 | " 146 | CREATE EXTERNAL TABLE zarr_table 147 | STORED AS ICECHUNK_REPO LOCATION '{}' 148 | ", 149 | url 150 | ); 151 | session.sql(&query).await.unwrap(); 152 | 153 | (client, session) 154 | }); 155 | 156 | Self { 157 | bucket: "zarr-unit-tests".into(), 158 | prefix: "test_data_1".into(), 159 | client, 160 | session, 161 | } 162 | } 163 | 164 | fn get_session(&self) -> &SessionContext { 165 | &self.session 166 | } 167 | } 168 | 169 | impl Drop for S3TestFixture { 170 | fn drop(&mut self) { 171 | let rt = tokio::runtime::Runtime::new().unwrap(); 172 | 173 | rt.block_on(async { 174 | let objects = self 175 | .client 176 | .list_objects_v2() 177 | .bucket(self.bucket.clone()) 178 | .prefix(self.prefix.clone()) 179 | .send() 180 | .await 181 | .unwrap(); 182 | 183 | let to_delete: Vec<_> = objects 184 | .contents() 185 | .iter() 186 | .filter_map(|obj| { 187 | obj.key() 188 | .map(|k| ObjectIdentifier::builder().key(k).build().unwrap()) 189 | }) 190 | .collect(); 191 | 192 | if !to_delete.is_empty() { 193 | let delete = Delete::builder() 194 | .set_objects(Some(to_delete)) 195 | .build() 196 | .unwrap(); 197 | self.client 198 | .delete_objects() 199 | .bucket(self.bucket.clone()) 200 | .delete(delete) 201 | .send() 202 | .await 203 | .unwrap(); 204 | } 205 | }) 206 | } 207 | } 208 | 209 | async fn run_query(query: &str, session: SessionContext) { 210 | let df = session.sql(query).await.unwrap(); 211 | let _ = df.collect().await.unwrap(); 212 | } 213 | 214 | fn benchmark_query(c: &mut Criterion) { 215 | let rt = tokio::runtime::Runtime::new().unwrap(); 216 | let s3_fixture = S3TestFixture::new(); 217 | 218 | let mut group = c.benchmark_group("my_group"); 219 | group.sample_size(20); 220 | 221 | let session = s3_fixture.get_session().clone(); 222 | let query = " 223 | SELECT t1.*, t2.* 224 | FROM zarr_table as t1 225 | JOIN zarr_table as t2 226 | ON t1.var1 % 12 = 0 227 | AND t1.var1 < t2.var1 + 1 228 | AND t1.var1 >= t2.var1 - 1 229 | "; 230 | 231 | group.bench_function("benchmark 1", |b| { 232 | b.to_async(&rt) 233 | .iter(|| async { run_query(black_box(query), black_box(session.clone())).await }) 234 | }); 235 | 236 | let query = " 237 | SELECT * 238 | FROM zarr_table 239 | 240 | 
UNION ALL 241 | 242 | SELECT * 243 | FROM zarr_table 244 | "; 245 | group.bench_function("benchmark 2", |b| { 246 | b.to_async(&rt) 247 | .iter(|| async { run_query(black_box(query), black_box(session.clone())).await }) 248 | }); 249 | } 250 | 251 | criterion_group!(benches, benchmark_query); 252 | criterion_main!(benches); 253 | -------------------------------------------------------------------------------- /src/table/opener.rs: -------------------------------------------------------------------------------- 1 | use std::any::Any; 2 | use std::fmt; 3 | use std::fmt::Formatter; 4 | use std::sync::Arc; 5 | 6 | use arrow_schema::SchemaRef; 7 | use datafusion::common::Statistics; 8 | use datafusion::datasource::listing::PartitionedFile; 9 | use datafusion::datasource::physical_plan::{ 10 | FileMeta, FileOpenFuture, FileOpener, FileScanConfig, FileSource, 11 | }; 12 | use datafusion::error::{DataFusionError, Result as DfResult}; 13 | use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet; 14 | use datafusion::physical_plan::{DisplayFormatType, PhysicalExpr}; 15 | use futures::StreamExt; 16 | use object_store::ObjectStore; 17 | 18 | use super::config::ZarrTableConfig; 19 | use crate::zarr_store_opener::ZarrChunkFilter; 20 | use crate::ZarrRecordBatchStream; 21 | 22 | /// Implementation of [`FileOpener`] for zarr. 23 | pub(crate) struct ZarrOpener { 24 | config: ZarrTableConfig, 25 | n_partitions: usize, 26 | partition: usize, 27 | filter_expr: Option>, 28 | } 29 | 30 | impl ZarrOpener { 31 | fn new( 32 | config: ZarrTableConfig, 33 | n_partitions: usize, 34 | partition: usize, 35 | filter_expr: Option>, 36 | ) -> Self { 37 | Self { 38 | config, 39 | n_partitions, 40 | partition, 41 | filter_expr, 42 | } 43 | } 44 | } 45 | 46 | impl FileOpener for ZarrOpener { 47 | // We don't actually need any information about the file partitions, as those 48 | // don't really make sense for zarr. there's a high level store, and the data 49 | // is retrieved one chunk at a time, with all the logic inside the zarr stream. 50 | // There is the option to split the zarr chunks between some number of partitions, 51 | // but this is again handled inside the zarr stream. We are only implementing 52 | // this to re-use some of the datafusion functionalities. 53 | fn open(&self, _file_meta: FileMeta, _file: PartitionedFile) -> DfResult { 54 | let config = self.config.clone(); 55 | let (n_partitions, partition) = (self.n_partitions, self.partition); 56 | 57 | let filter = if let Some(filter_expr) = &self.filter_expr { 58 | Some(ZarrChunkFilter::new(filter_expr, config.get_schema_ref())?) 59 | } else { 60 | None 61 | }; 62 | 63 | let stream = Box::pin(async move { 64 | let (store, prefix) = config.get_store_pointer_and_prefix().await?; 65 | let inner_stream = ZarrRecordBatchStream::try_new( 66 | store, 67 | config.get_schema_ref(), 68 | prefix, 69 | config.get_projection(), 70 | n_partitions, 71 | partition, 72 | filter, 73 | ) 74 | .await 75 | .map_err(|e| DataFusionError::External(Box::new(e)))?; 76 | Ok(inner_stream.boxed()) 77 | }); 78 | 79 | Ok(stream) 80 | } 81 | } 82 | 83 | /// Implementation of [`FileSource`] for zarr. 
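///
/// A condensed usage sketch, mirroring the tests in this module (a `ZarrTableConfig`
/// named `zarr_config` and its Arrow `schema` are assumed to already exist):
///
/// ```ignore
/// let source = ZarrSource::new(zarr_config, /* n_partitions */ 1, /* filters */ None);
/// let scan_config = FileScanConfigBuilder::new(
///     ObjectStoreUrl::parse("file://")?,
///     schema,
///     Arc::new(source.clone()),
/// )
/// .with_file_groups(vec![FileGroup::new(vec![PartitionedFile::new("", 0)])])
/// .build();
/// // The object store passed here is a placeholder; all I/O happens inside the
/// // zarr record batch stream created by the opener.
/// let opener = source.create_file_opener(Arc::new(LocalFileSystem::new()), &scan_config, 0);
/// let metrics = ExecutionPlanMetricsSet::default();
/// let stream = FileStream::new(&scan_config, 0, opener, &metrics)?;
/// ```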
84 | #[derive(Clone)] 85 | pub(crate) struct ZarrSource { 86 | config: ZarrTableConfig, 87 | n_partitions: usize, 88 | exec_plan_metrics: ExecutionPlanMetricsSet, 89 | filter_expr: Option>, 90 | } 91 | 92 | impl ZarrSource { 93 | pub(crate) fn new( 94 | config: ZarrTableConfig, 95 | n_partitions: usize, 96 | filter_expr: Option>, 97 | ) -> Self { 98 | Self { 99 | config, 100 | n_partitions, 101 | exec_plan_metrics: ExecutionPlanMetricsSet::default(), 102 | filter_expr, 103 | } 104 | } 105 | } 106 | 107 | impl FileSource for ZarrSource { 108 | // Once again, we don't really need this, it's only so that 109 | // we can re-use some stuff from datafusion. 110 | fn create_file_opener( 111 | &self, 112 | _object_store: Arc, 113 | _base_config: &FileScanConfig, 114 | partition: usize, 115 | ) -> Arc { 116 | let file_opener = ZarrOpener::new( 117 | self.config.clone(), 118 | self.n_partitions, 119 | partition, 120 | self.filter_expr.clone(), 121 | ); 122 | Arc::new(file_opener) 123 | } 124 | 125 | fn as_any(&self) -> &dyn Any { 126 | self 127 | } 128 | 129 | // We don't really need most of the below functions, since we're 130 | // barely using this struct, but they are required by the trait. 131 | fn with_batch_size(&self, _batch_size: usize) -> Arc { 132 | Arc::new(self.clone()) 133 | } 134 | 135 | fn with_schema(&self, _schema: SchemaRef) -> Arc { 136 | Arc::new(self.clone()) 137 | } 138 | 139 | fn with_projection(&self, _config: &FileScanConfig) -> Arc { 140 | Arc::new(self.clone()) 141 | } 142 | 143 | fn with_statistics(&self, _statistics: Statistics) -> Arc { 144 | Arc::new(self.clone()) 145 | } 146 | 147 | fn metrics(&self) -> &ExecutionPlanMetricsSet { 148 | &self.exec_plan_metrics 149 | } 150 | 151 | fn statistics(&self) -> DfResult { 152 | Ok(Statistics::default()) 153 | } 154 | 155 | /// String representation of file source 156 | fn file_type(&self) -> &str { 157 | "zarr" 158 | } 159 | 160 | /// Format FileType specific information 161 | fn fmt_extra(&self, _t: DisplayFormatType, _f: &mut Formatter) -> fmt::Result { 162 | Ok(()) 163 | } 164 | } 165 | 166 | #[cfg(test)] 167 | mod file_opener_tests { 168 | use std::collections::HashMap; 169 | 170 | use arrow::datatypes::Float64Type; 171 | use arrow_schema::DataType; 172 | use datafusion::datasource::listing::ListingTableUrl; 173 | use datafusion::datasource::physical_plan::{FileGroup, FileScanConfigBuilder, FileStream}; 174 | use datafusion::execution::object_store::ObjectStoreUrl; 175 | use futures_util::TryStreamExt; 176 | use object_store::local::LocalFileSystem; 177 | 178 | use super::*; 179 | use crate::table::config::ZarrTableUrl; 180 | use crate::test_utils::{ 181 | get_local_zarr_store, validate_names_and_types, validate_primitive_column, 182 | }; 183 | 184 | #[tokio::test] 185 | async fn filestream_tests() { 186 | let (wrapper, schema) = get_local_zarr_store(true, 0.0, "data_for_file_opener").await; 187 | let path = wrapper.get_store_path(); 188 | let table_url = ZarrTableUrl::ZarrStore(ListingTableUrl::parse(path).unwrap()); 189 | 190 | let zarr_config = ZarrTableConfig::new(table_url, schema.clone()); 191 | let zarr_souce = ZarrSource::new(zarr_config, 1, None); 192 | 193 | let file_groups = vec![FileGroup::new(vec![PartitionedFile::new("", 0)])]; 194 | let file_scan_config = FileScanConfigBuilder::new( 195 | ObjectStoreUrl::parse("file://").unwrap(), 196 | schema, 197 | Arc::new(zarr_souce.clone()), 198 | ) 199 | .with_file_groups(file_groups) 200 | .build(); 201 | let dummy_object_store = Arc::new(LocalFileSystem::new()); 202 
| let file_opener = zarr_souce.create_file_opener(dummy_object_store, &file_scan_config, 0); 203 | 204 | let metrics = ExecutionPlanMetricsSet::default(); 205 | let file_stream = FileStream::new(&file_scan_config, 0, file_opener, &metrics).unwrap(); 206 | let records: Vec<_> = file_stream.try_collect().await.unwrap(); 207 | 208 | let target_types = HashMap::from([ 209 | ("lat".to_string(), DataType::Float64), 210 | ("lon".to_string(), DataType::Float64), 211 | ("data".to_string(), DataType::Float64), 212 | ]); 213 | validate_names_and_types(&target_types, &records[0]); 214 | assert_eq!(records.len(), 9); 215 | 216 | // the top left chunk, full 3x3 217 | validate_primitive_column::( 218 | "lat", 219 | &records[0], 220 | &[35., 35., 35., 36., 36., 36., 37., 37., 37.], 221 | ); 222 | validate_primitive_column::( 223 | "lon", 224 | &records[0], 225 | &[ 226 | -120.0, -119.0, -118.0, -120.0, -119.0, -118.0, -120.0, -119.0, -118.0, 227 | ], 228 | ); 229 | validate_primitive_column::( 230 | "data", 231 | &records[0], 232 | &[0.0, 1.0, 2.0, 8.0, 9.0, 10.0, 16.0, 17.0, 18.0], 233 | ); 234 | } 235 | } 236 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /src/table/config.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "icechunk")] 2 | use std::collections::HashMap; 3 | #[cfg(feature = "icechunk")] 4 | use std::env; 5 | use std::path::PathBuf; 6 | use std::sync::Arc; 7 | 8 | use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef}; 9 | use datafusion::datasource::listing::ListingTableUrl; 10 | use datafusion::error::{DataFusionError, Result as DfResult}; 11 | #[cfg(feature = "icechunk")] 12 | use icechunk::{ObjectStorage, Repository}; 13 | use object_store::aws::AmazonS3Builder; 14 | use object_store::local::LocalFileSystem; 15 | use zarrs::array::data_type::DataType as zarr_dtype; 16 | use zarrs::array::Array; 17 | use zarrs::registry::ExtensionAliases; 18 | #[cfg(feature = "icechunk")] 19 | use zarrs_icechunk::AsyncIcechunkStore; 20 | use zarrs_metadata::v3::MetadataV3; 21 | use zarrs_metadata::ArrayMetadata; 22 | use zarrs_object_store::AsyncObjectStore; 23 | use zarrs_storage::{AsyncReadableListableStorageTraits, StorePrefix}; 24 | /// A zarr table configuration. 25 | #[derive(Clone, Debug)] 26 | pub struct ZarrTableConfig { 27 | schema_ref: SchemaRef, 28 | table_url: ZarrTableUrl, 29 | projection: Option>, 30 | } 31 | 32 | impl ZarrTableConfig { 33 | pub(crate) fn new(table_url: ZarrTableUrl, schema_ref: SchemaRef) -> Self { 34 | Self { 35 | schema_ref, 36 | table_url, 37 | projection: None, 38 | } 39 | } 40 | 41 | pub(crate) async fn get_store_pointer_and_prefix( 42 | &self, 43 | ) -> DfResult<( 44 | Arc, 45 | Option, 46 | )> { 47 | self.table_url.get_store_pointer_and_prefix().await 48 | } 49 | 50 | pub(crate) fn with_projection(mut self, projection: Vec) -> Self { 51 | self.projection = Some(projection); 52 | self 53 | } 54 | 55 | pub(crate) fn get_projection(&self) -> Option> { 56 | self.projection.clone() 57 | } 58 | 59 | pub(crate) fn get_schema_ref(&self) -> SchemaRef { 60 | self.schema_ref.clone() 61 | } 62 | 63 | pub(crate) fn get_projected_schema_ref(&self) -> SchemaRef { 64 | if let Some(projection) = &self.projection { 65 | let projected_fields: Fields = projection 66 | .iter() 67 | .map(|&i| self.schema_ref.field(i).clone()) 68 | .collect(); 69 | Arc::new(Schema::new(projected_fields)) 70 | } else { 71 | self.schema_ref.clone() 72 | } 73 | } 74 | } 75 | 76 | /// We can create a table based on a directory with a supported zarr 77 | /// file/folder structure, or from an icechunk repo. 78 | #[derive(Clone, Debug)] 79 | pub(crate) enum ZarrTableUrl { 80 | ZarrStore(ListingTableUrl), 81 | #[cfg(feature = "icechunk")] 82 | IcechunkRepo(ListingTableUrl), 83 | } 84 | 85 | impl ZarrTableUrl { 86 | async fn get_store_pointer_and_prefix( 87 | &self, 88 | ) -> DfResult<( 89 | Arc, 90 | Option, 91 | )> { 92 | // the Option that is returned here requires some explanation. 93 | // for some remote stores, the full url is not used as a prefix when 94 | // writing and reading from the store. for example for aws s3, it 95 | // seems the bucket is extracted from the url, but not the rest, so 96 | // when reading from the store, you always need to provide a prefix 97 | // to get to the actual zarr store. but for local object stores, it 98 | // actually can store the prefix, to be applied when you read from 99 | // the store. so we need to sometimes return no prefix (None) and 100 | // sometimes return one (Some(prefix)). 
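        // As a concrete illustration (paths here are hypothetical): for a table at
        // file:///data/table, the LocalFileSystem store below is rooted at /data/table,
        // so a key like "lat/zarr.json" resolves directly and the prefix is None; for
        // s3://bucket/table, the store is rooted at the bucket, so Some("table") is
        // returned and callers must prepend it, e.g. "table/lat/zarr.json".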
101 | match self { 102 | // this is for the case of a directory with a zarr file structure inside. 103 | Self::ZarrStore(table_url) => match table_url.scheme() { 104 | "file" => { 105 | let path = PathBuf::from("/".to_owned() + table_url.prefix().as_ref()); 106 | let store = AsyncObjectStore::new(LocalFileSystem::new_with_prefix(path)?); 107 | Ok((Arc::new(store), None)) 108 | } 109 | "s3" => { 110 | let store = AmazonS3Builder::from_env() 111 | .with_url(table_url.get_url().as_str()) 112 | .build()?; 113 | let store = AsyncObjectStore::new(store); 114 | Ok((Arc::new(store), Some(table_url.prefix().to_string()))) 115 | } 116 | _ => Err(DataFusionError::Execution(format!( 117 | "Unsupported table url scheme {} for zarr stores", 118 | table_url.scheme() 119 | ))), 120 | }, 121 | 122 | // this is for the case of an icechunk repo. note that here we hard code 123 | // reading from the main branch, and "as of" now. 124 | #[cfg(feature = "icechunk")] 125 | Self::IcechunkRepo(table_url) => { 126 | let object_storage = match table_url.scheme() { 127 | "file" => { 128 | let path = PathBuf::from("/".to_owned() + table_url.prefix().as_ref()); 129 | ObjectStorage::new_local_filesystem(&path) 130 | .await 131 | .map_err(|e| DataFusionError::External(Box::new(e)))? 132 | } 133 | "s3" => { 134 | use icechunk::config::{S3Credentials, S3Options}; 135 | 136 | let bucket = table_url 137 | .object_store() 138 | .as_str() 139 | .replace("s3://", "") 140 | .trim_end_matches("/") 141 | .to_string(); 142 | let credentials = S3Credentials::FromEnv; 143 | let config = S3Options { 144 | region: env::var("AWS_DEFAULT_REGION").ok(), 145 | endpoint_url: None, 146 | anonymous: false, 147 | allow_http: false, 148 | force_path_style: false, 149 | network_stream_timeout_seconds: None, 150 | requester_pays: false, 151 | }; 152 | 153 | ObjectStorage::new_s3( 154 | bucket, 155 | Some(table_url.prefix().as_ref().to_string()), 156 | Some(credentials), 157 | Some(config), 158 | ) 159 | .await 160 | .map_err(|e| DataFusionError::External(Box::new(e)))? 161 | } 162 | _ => { 163 | return Err(DataFusionError::Execution(format!( 164 | "Unsupported table url scheme {} for icechunk repos", 165 | table_url.scheme() 166 | ))) 167 | } 168 | }; 169 | let repo = Repository::open(None, Arc::new(object_storage), HashMap::new()) 170 | .await 171 | .map_err(|e| DataFusionError::External(Box::new(e)))?; 172 | let session = repo 173 | .readonly_session(&icechunk::repository::VersionInfo::AsOf { 174 | branch: "main".into(), 175 | at: chrono::Utc::now(), 176 | }) 177 | .await 178 | .map_err(|e| DataFusionError::External(Box::new(e)))?; 179 | Ok((Arc::new(AsyncIcechunkStore::new(session)), None)) 180 | } 181 | } 182 | } 183 | 184 | pub(crate) async fn infer_schema(&self) -> DfResult { 185 | let (store, store_prefix) = self.get_store_pointer_and_prefix().await?; 186 | let store_prefix = store_prefix 187 | .as_ref() 188 | .map_or("".into(), |p| p.to_owned() + "/"); 189 | 190 | let prefixes = store 191 | .list_prefix( 192 | &StorePrefix::new(store_prefix.to_owned()) 193 | .map_err(|e| DataFusionError::External(Box::new(e)))?, 194 | ) 195 | .await 196 | .map_err(|e| DataFusionError::External(Box::new(e)))?; 197 | let mut fields = Vec::with_capacity(prefixes.len()); 198 | 199 | for prefix in prefixes { 200 | if prefix.as_str().contains("zarr.json") { 201 | let field_name = prefix.parent(); 202 | if field_name.as_str() == "" { 203 | continue; 204 | } 205 | 206 | // this is ugly, but I'm not sure there's a better way 207 | // to extract the array name... 
208 | let field_name_prefix = field_name.parent(); 209 | let mut field_name = field_name 210 | .as_str() 211 | .strip_suffix("/") 212 | .ok_or(DataFusionError::Execution( 213 | "Invalid directory name in zarr store".into(), 214 | ))? 215 | .to_string(); 216 | let read_prefix = field_name.clone(); 217 | if let Some(field_name_prefix) = field_name_prefix { 218 | let to_remove = field_name_prefix.as_str(); 219 | field_name = field_name.replace(to_remove, ""); 220 | } 221 | 222 | let arr = Array::async_open(store.clone(), &("/".to_owned() + &read_prefix)) 223 | .await 224 | .map_err(|e| DataFusionError::External(Box::new(e)))?; 225 | let meta = match arr.metadata() { 226 | ArrayMetadata::V3(meta) => Ok(meta), 227 | _ => Err(DataFusionError::Execution( 228 | "Only Zarr v3 metadata is supported".into(), 229 | )), 230 | }?; 231 | 232 | fields.push(Field::new( 233 | field_name, 234 | get_schema_type(&meta.data_type)?, 235 | true, 236 | )); 237 | } 238 | } 239 | 240 | Ok(Arc::new(Schema::new(Fields::from(fields)))) 241 | } 242 | } 243 | 244 | fn get_schema_type(value: &MetadataV3) -> DfResult { 245 | let data_type = zarr_dtype::from_metadata(value, &ExtensionAliases::default()) 246 | .map_err(|e| DataFusionError::External(Box::new(e)))?; 247 | 248 | match data_type { 249 | zarr_dtype::Bool => Ok(DataType::Boolean), 250 | zarr_dtype::UInt8 => Ok(DataType::UInt8), 251 | zarr_dtype::UInt16 => Ok(DataType::UInt16), 252 | zarr_dtype::UInt32 => Ok(DataType::UInt32), 253 | zarr_dtype::UInt64 => Ok(DataType::UInt64), 254 | zarr_dtype::Int8 => Ok(DataType::Int8), 255 | zarr_dtype::Int16 => Ok(DataType::Int16), 256 | zarr_dtype::Int32 => Ok(DataType::Int32), 257 | zarr_dtype::Int64 => Ok(DataType::Int64), 258 | zarr_dtype::Float32 => Ok(DataType::Float32), 259 | zarr_dtype::Float64 => Ok(DataType::Float64), 260 | zarr_dtype::String => Ok(DataType::Utf8), 261 | _ => Err(DataFusionError::Execution(format!( 262 | "Unsupported type {value} from zarr metadata" 263 | ))), 264 | } 265 | } 266 | 267 | #[cfg(test)] 268 | mod zarr_config_tests { 269 | use super::*; 270 | #[cfg(feature = "icechunk")] 271 | use crate::test_utils::get_local_icechunk_repo; 272 | use crate::test_utils::get_local_zarr_store; 273 | 274 | #[tokio::test] 275 | async fn schema_inference_tests() { 276 | // local zarr directory. 277 | let (wrapper, schema) = get_local_zarr_store(true, 0.0, "data_for_config_dir").await; 278 | let path = wrapper.get_store_path(); 279 | 280 | let table_url = ListingTableUrl::parse(path).unwrap(); 281 | let zarr_table_url = ZarrTableUrl::ZarrStore(table_url); 282 | let inferred_schema = zarr_table_url.infer_schema().await.unwrap(); 283 | assert_eq!(inferred_schema, schema); 284 | 285 | // local icechunk repo. 
286 | #[cfg(feature = "icechunk")] 287 | { 288 | let (wrapper, schema) = 289 | get_local_icechunk_repo(true, 0.0, "data_for_config_repo").await; 290 | let path = wrapper.get_store_path(); 291 | 292 | let table_url = ListingTableUrl::parse(path).unwrap(); 293 | let zarr_table_url = ZarrTableUrl::IcechunkRepo(table_url); 294 | let inferred_schema = zarr_table_url.infer_schema().await.unwrap(); 295 | assert_eq!(inferred_schema, schema); 296 | } 297 | } 298 | } 299 | -------------------------------------------------------------------------------- /src/zarr_store_opener/filter.rs: -------------------------------------------------------------------------------- 1 | use std::collections::BTreeSet; 2 | use std::sync::Arc; 3 | 4 | use arrow_array::{BooleanArray, RecordBatch}; 5 | use arrow_schema::{ArrowError, SchemaRef}; 6 | use datafusion::common::cast::as_boolean_array; 7 | use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}; 8 | use datafusion::common::Result as DfResult; 9 | use datafusion::physical_expr::expressions::Column; 10 | use datafusion::physical_expr::utils::reassign_predicate_columns; 11 | use datafusion::physical_expr::{split_conjunction, PhysicalExpr}; 12 | use itertools::Itertools; 13 | 14 | /// A predicate operating on [`RecordBatch`]. 15 | pub trait ZarrArrowPredicate: Send + 'static { 16 | /// Evaluate this predicate for the given [`RecordBatch`]. Rows that are `true` 17 | /// in the returned [`BooleanArray`] satisfy the predicate condition, whereas those 18 | /// that are `false` do not. The method should not return any `Null` values. 19 | /// Note that the [`RecordBatch`] is passed by reference and not consumed by 20 | /// the method. 21 | fn evaluate(&self, batch: &RecordBatch) -> Result; 22 | } 23 | 24 | struct ZarrFilterExpression { 25 | physical_expr: Arc, 26 | filter_schema: SchemaRef, 27 | required_columns: Vec, 28 | } 29 | 30 | impl ZarrFilterExpression { 31 | fn new(physical_expr: Arc, table_schema: SchemaRef) -> DfResult { 32 | // this part is to make sure that the indices for each column in the 33 | // predicate match the columns in the filter schema. the physical 34 | // expressions are created from the full table schema initally, but 35 | // the record batches will come in with the filter schema, that's why 36 | // this step is needed. 37 | let required_columns = pushdown_columns(&physical_expr, table_schema.clone())?; 38 | let filter_schema = table_schema.project(&required_columns)?; 39 | let physical_expr = reassign_predicate_columns(physical_expr, &filter_schema, true)?; 40 | 41 | Ok(Self { 42 | physical_expr, 43 | filter_schema: Arc::new(filter_schema), 44 | required_columns, 45 | }) 46 | } 47 | } 48 | 49 | impl ZarrArrowPredicate for ZarrFilterExpression { 50 | fn evaluate(&self, batch: &RecordBatch) -> Result { 51 | // if there was only one filter expression in the full chunk filter, 52 | // the record batch would come in with the right schema all the time 53 | // (because the caller would first check what schema is required for 54 | // filter, only evaluate those columns and pass in that record batch). 55 | // but there could be multiple expressions in the final, full chunk 56 | // filter, so in practice the record batch is a superset of what each 57 | // individual filter needs, hence we need to project it to make sure 58 | // field/column indices match between the expression and the record 59 | // batch schema. 
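        // As a worked example (column names are hypothetical): if the overall chunk
        // filter needs columns ["a", "b", "c"] but this particular expression only
        // references "c", the incoming batch carries all three columns while
        // self.filter_schema is just ["c"]; the projection below looks up "c" in the
        // incoming batch so that the expression's column indices line up.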
60 | let batch = batch.project( 61 | &(self 62 | .filter_schema 63 | .fields() 64 | .iter() 65 | .map(|f| batch.schema().index_of(f.name())) 66 | .collect::, _>>()?[..]), 67 | )?; 68 | 69 | match self 70 | .physical_expr 71 | .evaluate(&batch) 72 | .and_then(|v| v.into_array(batch.num_rows())) 73 | { 74 | Ok(array) => { 75 | let bool_arr = as_boolean_array(&array)?.clone(); 76 | Ok(bool_arr) 77 | } 78 | Err(e) => Err(ArrowError::ComputeError(format!( 79 | "Error evaluating filter predicate: {e:?}" 80 | ))), 81 | } 82 | } 83 | } 84 | 85 | /// A struct that implements TreeNodeRewriter to traverse a PhysicalExpr tree structure 86 | /// to determine which columns are required to evaluate it. 87 | struct PushdownChecker { 88 | // Indices into the table schema of the columns required to evaluate the expression 89 | required_columns: BTreeSet, 90 | table_schema: SchemaRef, 91 | } 92 | 93 | // Note that the zarr case is simpler than the other cases that support projected 94 | // columns or partition columns. There is not much we need to check here, columns 95 | // are just columns, we just need to check which columns are in which predicate. 96 | impl PushdownChecker { 97 | fn new(table_schema: SchemaRef) -> Self { 98 | Self { 99 | required_columns: BTreeSet::default(), 100 | table_schema, 101 | } 102 | } 103 | } 104 | 105 | impl TreeNodeVisitor<'_> for PushdownChecker { 106 | type Node = Arc; 107 | 108 | fn f_down(&mut self, node: &Self::Node) -> DfResult { 109 | if let Some(column) = node.as_any().downcast_ref::() { 110 | let idx = self.table_schema.index_of(column.name())?; 111 | self.required_columns.insert(idx); 112 | } 113 | 114 | Ok(TreeNodeRecursion::Continue) 115 | } 116 | } 117 | 118 | fn pushdown_columns(expr: &Arc, table_schema: SchemaRef) -> DfResult> { 119 | let mut checker = PushdownChecker::new(table_schema); 120 | expr.visit(&mut checker)?; 121 | Ok(checker.required_columns.into_iter().collect()) 122 | } 123 | 124 | /// A collection of one or more objects that implement [`ZarrArrowPredicate`]. The way 125 | /// filters are used for zarr store is by determining whether or not the a chunk needs to be 126 | /// read based on the predicate. First, only the columns needed in the predicate are read, 127 | /// then the predicate is evaluated, and if there is a least one row that satistifes the 128 | /// condition, the other variables that we requested are read. 129 | pub struct ZarrChunkFilter { 130 | /// A list of [`ZarrArrowPredicate`] 131 | predicates: Vec>, 132 | schema_ref: SchemaRef, 133 | } 134 | 135 | impl ZarrChunkFilter { 136 | pub fn new(expr: &Arc, table_schema: SchemaRef) -> Result { 137 | let predicate_exprs = split_conjunction(expr); 138 | let mut predicates: Vec> = 139 | Vec::with_capacity(predicate_exprs.len()); 140 | let mut schema_indices: Vec = Vec::new(); 141 | 142 | // we don't bother reorganizing filters to start with the cheaper ones 143 | // before we do the more expensive ones. in terms of the amount of data 144 | // read, it's the same for all the chunks. some operations might be 145 | // cheaper to check (computationally), not sure how e.g. the parquet 146 | // case handles this, I might revisit later to optimize things a bit. 
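        // e.g. a pushed down predicate such as `b < c AND a < d` has already
        // been split by `split_conjunction` into `b < c` and `a < d`, and each
        // of them gets wrapped below in its own ZarrFilterExpression with its
        // own projected filter schema.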
147 | for pred_expr in predicate_exprs { 148 | let filter_expr = ZarrFilterExpression::new(pred_expr.clone(), table_schema.clone())?; 149 | schema_indices.extend(filter_expr.required_columns.clone()); 150 | predicates.push(Box::new(filter_expr)); 151 | } 152 | 153 | schema_indices.sort(); 154 | schema_indices.dedup(); 155 | let table_schema = table_schema.project(&schema_indices)?; 156 | 157 | Ok(Self { 158 | predicates, 159 | schema_ref: Arc::new(table_schema), 160 | }) 161 | } 162 | 163 | pub fn schema_ref(&self) -> &SchemaRef { 164 | &self.schema_ref 165 | } 166 | 167 | /// Applies all the filters in the chunk filter and returns true only 168 | /// if all the filters return true for at least one row in the record 169 | /// batch. 170 | pub fn evaluate(&self, rec_batch: &RecordBatch) -> Result { 171 | let mut bool_arr: Option = None; 172 | for predicate in self.predicates.iter() { 173 | let mask = predicate.evaluate(rec_batch)?; 174 | if let Some(old_bool_arr) = bool_arr { 175 | bool_arr = Some(BooleanArray::from( 176 | old_bool_arr 177 | .iter() 178 | .zip(mask.iter()) 179 | .map(|(x, y)| x.unwrap() && y.unwrap()) 180 | .collect_vec(), 181 | )); 182 | } else { 183 | bool_arr = Some(mask); 184 | } 185 | } 186 | 187 | if let Some(bool_arr) = bool_arr { 188 | Ok(bool_arr.true_count() > 0) 189 | } else { 190 | Ok(true) 191 | } 192 | } 193 | } 194 | 195 | #[cfg(test)] 196 | mod filter_tests { 197 | use std::sync::Arc; 198 | 199 | use arrow_array::{Int32Array, RecordBatch}; 200 | use arrow_schema::{DataType, Field, Schema}; 201 | use datafusion::logical_expr::Operator; 202 | use datafusion::physical_expr::expressions::col; 203 | use datafusion::physical_plan::expressions::binary; 204 | 205 | use super::*; 206 | 207 | #[test] 208 | fn test_single_filter() { 209 | let schema = Schema::new(vec![ 210 | Field::new("a", DataType::Int32, false), 211 | Field::new("b", DataType::Int32, false), 212 | Field::new("c", DataType::Int32, false), 213 | ]); 214 | let a = Int32Array::from(vec![1, 2, 3, 4, 5, 6]); 215 | let b = Int32Array::from(vec![3, 3, 3, 3, 3, 3]); 216 | let c = Int32Array::from(vec![4, 4, 4, 4, 4, 4]); 217 | let batch = RecordBatch::try_new( 218 | Arc::new(schema.clone()), 219 | vec![Arc::new(a), Arc::new(b), Arc::new(c)], 220 | ) 221 | .unwrap(); 222 | 223 | // expression: "a > b" 224 | let expr = binary( 225 | col("a", &schema).unwrap(), 226 | Operator::Gt, 227 | col("b", &schema).unwrap(), 228 | &schema, 229 | ) 230 | .unwrap(); 231 | 232 | let filter = ZarrFilterExpression::new(expr, Arc::new(schema.clone())).unwrap(); 233 | let mask = filter.evaluate(&batch).unwrap(); 234 | 235 | assert_eq!(mask, vec![false, false, false, true, true, true].into()); 236 | 237 | // this test in particular is important because it applies the filter 238 | // to a record batch where the data is ordered differently than in the 239 | // physical expression for filter. 
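        // concretely: the pushdown columns for `c < a` are {a, c}, so the
        // filter schema is ["a", "c"] (table schema order) even though the
        // expression references "c" first; reassign_predicate_columns is what
        // keeps the remapped column indices (a -> 0, c -> 1) consistent with
        // that projected schema.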
240 | // expression: "c < a" 241 | let expr = binary( 242 | col("c", &schema).unwrap(), 243 | Operator::Lt, 244 | col("a", &schema).unwrap(), 245 | &schema, 246 | ) 247 | .unwrap(); 248 | 249 | let filter = ZarrFilterExpression::new(expr, Arc::new(schema)).unwrap(); 250 | let mask = filter.evaluate(&batch).unwrap(); 251 | 252 | assert_eq!(mask, vec![false, false, false, false, true, true].into()); 253 | } 254 | 255 | #[test] 256 | fn test_chunk_filter() { 257 | let schema = Schema::new(vec![ 258 | Field::new("a", DataType::Int32, false), 259 | Field::new("b", DataType::Int32, false), 260 | Field::new("c", DataType::Int32, false), 261 | Field::new("d", DataType::Int32, false), 262 | ]); 263 | let a = Int32Array::from(vec![1, 2, 3, 4, 5, 6]); 264 | let b = Int32Array::from(vec![3, 3, 3, 3, 3, 3]); 265 | let c = Int32Array::from(vec![1, 1, 2, 2, 4, 4]); 266 | let d = Int32Array::from(vec![2, 3, 1, 1, 1, 1]); 267 | let batch = RecordBatch::try_new( 268 | Arc::new(schema.clone()), 269 | vec![Arc::new(a), Arc::new(b), Arc::new(c), Arc::new(d)], 270 | ) 271 | .unwrap(); 272 | 273 | // expression: "b < c AND a < d" 274 | let expr = binary( 275 | binary( 276 | col("b", &schema).unwrap(), 277 | Operator::Lt, 278 | col("c", &schema).unwrap(), 279 | &schema, 280 | ) 281 | .unwrap(), 282 | Operator::And, 283 | binary( 284 | col("a", &schema).unwrap(), 285 | Operator::Lt, 286 | col("d", &schema).unwrap(), 287 | &schema, 288 | ) 289 | .unwrap(), 290 | &schema, 291 | ) 292 | .unwrap(); 293 | 294 | let chunk_filter = ZarrChunkFilter::new(&Arc::new(expr), Arc::new(schema.clone())).unwrap(); 295 | let filter_passed = chunk_filter.evaluate(&batch).unwrap(); 296 | assert!(!filter_passed); 297 | 298 | // expression: "b < c OR a < d" 299 | let expr = binary( 300 | binary( 301 | col("b", &schema).unwrap(), 302 | Operator::Lt, 303 | col("c", &schema).unwrap(), 304 | &schema, 305 | ) 306 | .unwrap(), 307 | Operator::Or, 308 | binary( 309 | col("a", &schema).unwrap(), 310 | Operator::Lt, 311 | col("d", &schema).unwrap(), 312 | &schema, 313 | ) 314 | .unwrap(), 315 | &schema, 316 | ) 317 | .unwrap(); 318 | 319 | let chunk_filter = ZarrChunkFilter::new(&Arc::new(expr), Arc::new(schema.clone())).unwrap(); 320 | let filter_passed = chunk_filter.evaluate(&batch).unwrap(); 321 | assert!(filter_passed); 322 | 323 | let expr = binary( 324 | col("b", &schema).unwrap(), 325 | Operator::Lt, 326 | col("c", &schema).unwrap(), 327 | &schema, 328 | ) 329 | .unwrap(); 330 | let chunk_filter = ZarrChunkFilter::new(&Arc::new(expr), Arc::new(schema.clone())).unwrap(); 331 | assert_eq!( 332 | vec!["b", "c"], 333 | chunk_filter 334 | .schema_ref() 335 | .fields() 336 | .iter() 337 | .map(|f| f.name()) 338 | .collect::>() 339 | ); 340 | } 341 | } 342 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | // Licensed to the Apache Software Foundation (ASF) under one 2 | // or more contributor license agreements. See the NOTICE file 3 | // distributed with this work for additional information 4 | // regarding copyright ownership. The ASF licenses this file 5 | // to you under the Apache License, Version 2.0 (the 6 | // "License"); you may not use this file except in compliance 7 | // with the License. 
You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, 12 | // software distributed under the License is distributed on an 13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | // KIND, either express or implied. See the License for the 15 | // specific language governing permissions and limitations 16 | // under the License. 17 | 18 | pub mod errors; 19 | pub mod table; 20 | pub mod zarr_store_opener; 21 | 22 | pub use zarr_store_opener::ZarrRecordBatchStream; 23 | 24 | #[cfg(test)] 25 | mod test_utils { 26 | use std::collections::HashMap; 27 | use std::fmt::Debug; 28 | use std::fs; 29 | use std::path::PathBuf; 30 | use std::sync::Arc; 31 | 32 | use arrow::buffer::ScalarBuffer; 33 | use arrow_array::cast::AsArray; 34 | use arrow_array::types::*; 35 | use arrow_array::RecordBatch; 36 | use arrow_schema::{DataType as ArrowDataType, Field, Schema, SchemaRef}; 37 | use futures::executor::block_on; 38 | #[cfg(feature = "icechunk")] 39 | use icechunk::{ObjectStorage, Repository}; 40 | use itertools::enumerate; 41 | use ndarray::{Array, Array1, Array2}; 42 | use object_store::local::LocalFileSystem; 43 | use walkdir::WalkDir; 44 | use zarrs::array::{codec, ArrayBuilder, DataType, FillValue}; 45 | use zarrs::array_subset::ArraySubset; 46 | #[cfg(feature = "icechunk")] 47 | use zarrs_icechunk::AsyncIcechunkStore; 48 | use zarrs_object_store::AsyncObjectStore; 49 | use zarrs_storage::{ 50 | AsyncReadableWritableListableStorageTraits, AsyncWritableStorageTraits, StorePrefix, 51 | }; 52 | 53 | // convenience class to make sure the local zarr stores get cleanup 54 | // after we're done running a test. 55 | pub(crate) struct LocalZarrStoreWrapper { 56 | store: Arc>, 57 | path: PathBuf, 58 | } 59 | 60 | impl LocalZarrStoreWrapper { 61 | pub(crate) fn new(store_name: String) -> Self { 62 | if store_name.is_empty() { 63 | panic!("name for test zarr store cannot be empty!") 64 | } 65 | 66 | let p = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join(store_name); 67 | fs::create_dir(p.clone()).unwrap(); 68 | let store = AsyncObjectStore::new(LocalFileSystem::new_with_prefix(p.clone()).unwrap()); 69 | Self { 70 | store: Arc::new(store), 71 | path: p, 72 | } 73 | } 74 | 75 | pub(crate) fn get_store(&self) -> Arc> { 76 | self.store.clone() 77 | } 78 | 79 | pub(crate) fn get_store_path(&self) -> String { 80 | self.path.as_os_str().to_str().unwrap().into() 81 | } 82 | } 83 | 84 | impl Drop for LocalZarrStoreWrapper { 85 | fn drop(&mut self) { 86 | let prefix = StorePrefix::new("").unwrap(); 87 | block_on(self.store.erase_prefix(&prefix)).unwrap(); 88 | 89 | while fs::exists(self.path.clone()).unwrap() { 90 | for d in WalkDir::new(self.path.clone()) { 91 | let _ = fs::remove_dir(d.unwrap().path()); 92 | } 93 | } 94 | } 95 | } 96 | 97 | // convenience class to make sure the local icechunk repos get cleanup 98 | // after we're done running a test. 
99 | #[cfg(feature = "icechunk")] 100 | pub(crate) struct LocalIcechunkRepoWrapper { 101 | store: Arc, 102 | path: PathBuf, 103 | } 104 | 105 | #[cfg(feature = "icechunk")] 106 | impl LocalIcechunkRepoWrapper { 107 | pub(crate) async fn new(store_name: String) -> Self { 108 | if store_name.is_empty() { 109 | panic!("name for test icechunk repo cannot be empty!") 110 | } 111 | let p = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join(store_name); 112 | fs::create_dir(p.clone()).unwrap(); 113 | let repo = Repository::create( 114 | None, 115 | Arc::new(ObjectStorage::new_local_filesystem(&p).await.unwrap()), 116 | HashMap::new(), 117 | ) 118 | .await 119 | .unwrap(); 120 | let session = repo.writable_session("main").await.unwrap(); 121 | Self { 122 | store: Arc::new(AsyncIcechunkStore::new(session)), 123 | path: p, 124 | } 125 | } 126 | 127 | pub(crate) fn get_store(&self) -> Arc { 128 | self.store.clone() 129 | } 130 | 131 | pub(crate) fn get_store_path(&self) -> String { 132 | self.path.to_str().unwrap().into() 133 | } 134 | } 135 | 136 | // TODO: Implement Drop. Just not sure how to do this cleanly yet. 137 | #[cfg(feature = "icechunk")] 138 | impl Drop for LocalIcechunkRepoWrapper { 139 | fn drop(&mut self) { 140 | if !self 141 | .path 142 | .to_str() 143 | .unwrap() 144 | .contains(env!("CARGO_MANIFEST_DIR")) 145 | { 146 | panic!("should not be deleting this icechunk repo!") 147 | } 148 | 149 | //delete the different icechunk repo components one at a time. 150 | fs::remove_dir_all(self.path.join("manifests")).unwrap(); 151 | fs::remove_dir_all(self.path.join("refs")).unwrap(); 152 | fs::remove_dir_all(self.path.join("snapshots")).unwrap(); 153 | fs::remove_dir_all(self.path.join("transactions")).unwrap(); 154 | fs::remove_dir(self.path.clone()).unwrap(); 155 | } 156 | } 157 | 158 | // helpers to create some test data on the fly. 
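    // the helpers below build small float64 zarr arrays directly with the
    // zarrs ArrayBuilder, compressed with blosc/LZ4; the lat/lon fixtures
    // further down use them to write an 8x8 "data" array chunked 3x3 along
    // with 1D "lat"/"lon" coordinate arrays of length 8 chunked by 3.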
159 | fn get_lz4_compressor() -> codec::BloscCodec { 160 | codec::BloscCodec::new( 161 | codec::bytes_to_bytes::blosc::BloscCompressor::LZ4, 162 | 5.try_into().unwrap(), 163 | Some(0), 164 | codec::bytes_to_bytes::blosc::BloscShuffleMode::NoShuffle, 165 | Some(1), 166 | ) 167 | .unwrap() 168 | } 169 | 170 | pub(crate) async fn write_1d_float_array( 171 | data: Vec, 172 | fillvalue: f64, 173 | shape: u64, 174 | chunk: u64, 175 | store: Arc, 176 | path: &str, 177 | dimensions: Option>, 178 | ) { 179 | let mut array_builder = ArrayBuilder::new( 180 | vec![shape], 181 | [chunk], 182 | DataType::Float64, 183 | FillValue::from(fillvalue), 184 | ); 185 | let mut builder_ref = &mut array_builder; 186 | let codec = get_lz4_compressor(); 187 | builder_ref = builder_ref.bytes_to_bytes_codecs(vec![Arc::new(codec)]); 188 | if let Some(dimensions) = dimensions { 189 | builder_ref = builder_ref.dimension_names(dimensions.into()); 190 | } 191 | 192 | let arr = builder_ref.build(store, path).unwrap(); 193 | arr.async_store_metadata().await.unwrap(); 194 | 195 | let arr_data: Array1 = Array::from_vec(data) 196 | .into_shape_with_order(shape as usize) 197 | .unwrap(); 198 | arr.async_store_array_subset_ndarray(&[0], arr_data) 199 | .await 200 | .unwrap(); 201 | } 202 | 203 | pub(crate) async fn write_2d_float_array( 204 | data: Option>, 205 | fillvalue: f64, 206 | shape: (u64, u64), 207 | chunk: (u64, u64), 208 | store: Arc, 209 | path: &str, 210 | dimensions: Option>, 211 | ) { 212 | let mut array_builder = ArrayBuilder::new( 213 | vec![shape.0, shape.1], 214 | [chunk.0, chunk.1], 215 | DataType::Float64, 216 | FillValue::from(fillvalue), 217 | ); 218 | 219 | let mut builder_ref = &mut array_builder; 220 | let codec = get_lz4_compressor(); 221 | builder_ref = builder_ref.bytes_to_bytes_codecs(vec![Arc::new(codec)]); 222 | if let Some(dimensions) = dimensions { 223 | builder_ref = builder_ref.dimension_names(dimensions.into()); 224 | } 225 | 226 | let arr = builder_ref.build(store, path).unwrap(); 227 | arr.async_store_metadata().await.unwrap(); 228 | 229 | if let Some(data) = data { 230 | let arr_data: Array2 = Array::from_vec(data) 231 | .into_shape_with_order((shape.0 as usize, shape.1 as usize)) 232 | .unwrap(); 233 | arr.async_store_array_subset_ndarray( 234 | ArraySubset::new_with_ranges(&[0..shape.0, 0..shape.1]).start(), 235 | arr_data, 236 | ) 237 | .await 238 | .unwrap(); 239 | } 240 | } 241 | 242 | // helpers to validate test data. 
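    // validate_primitive_column asserts that a named column matches a target
    // slice, validate_names_and_types checks a record batch's column names and
    // types against an expected map, and extract_col pulls a primitive
    // column's values out of a record batch for ad hoc assertions.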
243 | pub(crate) fn validate_primitive_column(col_name: &str, rec: &RecordBatch, targets: &[U]) 244 | where 245 | T: ArrowPrimitiveType, 246 | [U]: AsRef<[::Native]>, 247 | U: Debug, 248 | { 249 | let mut matched = false; 250 | for (idx, col) in enumerate(rec.schema().fields.iter()) { 251 | if col.name().as_str() == col_name { 252 | assert_eq!(rec.column(idx).as_primitive::().values(), targets); 253 | matched = true; 254 | } 255 | } 256 | assert!(matched); 257 | } 258 | 259 | pub(crate) fn validate_names_and_types( 260 | targets: &HashMap, 261 | rec: &RecordBatch, 262 | ) { 263 | let mut target_cols: Vec<&String> = targets.keys().collect(); 264 | let schema = rec.schema(); 265 | let from_rec: Vec<&String> = schema.fields.iter().map(|f| f.name()).collect(); 266 | 267 | target_cols.sort(); 268 | assert_eq!(from_rec, target_cols); 269 | 270 | for field in schema.fields.iter() { 271 | assert_eq!(field.data_type(), targets.get(field.name()).unwrap()); 272 | } 273 | } 274 | 275 | pub(crate) fn extract_col(col_name: &str, rec_batch: &RecordBatch) -> ScalarBuffer 276 | where 277 | T: ArrowPrimitiveType, 278 | { 279 | rec_batch 280 | .column_by_name(col_name) 281 | .unwrap() 282 | .as_primitive::() 283 | .values() 284 | .clone() 285 | } 286 | 287 | async fn write_lat_lon_data_to_store( 288 | store: Arc, 289 | write_data: bool, 290 | fillvalue: f64, 291 | ) { 292 | let lats = vec![35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0]; 293 | write_1d_float_array( 294 | lats, 295 | 0.0, 296 | 8, 297 | 3, 298 | store.clone(), 299 | "/lat", 300 | Some(["lat".into()].to_vec()), 301 | ) 302 | .await; 303 | 304 | let lons = vec![ 305 | -120.0, -119.0, -118.0, -117.0, -116.0, -115.0, -114.0, -113.0, 306 | ]; 307 | write_1d_float_array( 308 | lons, 309 | 0.0, 310 | 8, 311 | 3, 312 | store.clone(), 313 | "/lon", 314 | Some(["lon".into()].to_vec()), 315 | ) 316 | .await; 317 | 318 | let data: Option> = if write_data { 319 | Some((0..64).map(|i| i as f64).collect()) 320 | } else { 321 | None 322 | }; 323 | write_2d_float_array( 324 | data, 325 | fillvalue, 326 | (8, 8), 327 | (3, 3), 328 | store.clone(), 329 | "/data", 330 | Some(["lat".into(), "lon".into()].to_vec()), 331 | ) 332 | .await; 333 | } 334 | 335 | async fn write_mixed_dims_lat_lon_data_to_store( 336 | store: Arc, 337 | fillvalue: f64, 338 | ) { 339 | let lats = [ 340 | vec![35.0; 8], 341 | vec![36.0; 8], 342 | vec![37.0; 8], 343 | vec![38.0; 8], 344 | vec![39.0; 8], 345 | vec![40.0; 8], 346 | vec![41.0; 8], 347 | vec![42.0; 8], 348 | ] 349 | .concat(); 350 | write_2d_float_array( 351 | Some(lats), 352 | 0.0, 353 | (8, 8), 354 | (3, 3), 355 | store.clone(), 356 | "/lat", 357 | Some(["lat".into(), "lon".into()].to_vec()), 358 | ) 359 | .await; 360 | 361 | let lons = vec![ 362 | -120.0, -119.0, -118.0, -117.0, -116.0, -115.0, -114.0, -113.0, 363 | ]; 364 | write_1d_float_array( 365 | lons, 366 | 0.0, 367 | 8, 368 | 3, 369 | store.clone(), 370 | "/lon", 371 | Some(["lon".into()].to_vec()), 372 | ) 373 | .await; 374 | 375 | let data = (0..64).map(|i| i as f64).collect(); 376 | write_2d_float_array( 377 | Some(data), 378 | fillvalue, 379 | (8, 8), 380 | (3, 3), 381 | store.clone(), 382 | "/data", 383 | Some(["lat".into(), "lon".into()].to_vec()), 384 | ) 385 | .await; 386 | } 387 | 388 | pub(crate) async fn get_local_zarr_store( 389 | write_data: bool, 390 | fillvalue: f64, 391 | dir_name: &str, 392 | ) -> (LocalZarrStoreWrapper, SchemaRef) { 393 | let wrapper = LocalZarrStoreWrapper::new(dir_name.into()); 394 | let store = wrapper.get_store(); 395 | 396 | 
write_lat_lon_data_to_store(store, write_data, fillvalue).await; 397 | let schema = Arc::new(Schema::new(vec![ 398 | Field::new("data", ArrowDataType::Float64, true), 399 | Field::new("lat", ArrowDataType::Float64, true), 400 | Field::new("lon", ArrowDataType::Float64, true), 401 | ])); 402 | 403 | (wrapper, schema) 404 | } 405 | 406 | pub(crate) async fn get_local_zarr_store_mix_dims( 407 | fillvalue: f64, 408 | dir_name: &str, 409 | ) -> (LocalZarrStoreWrapper, SchemaRef) { 410 | let wrapper = LocalZarrStoreWrapper::new(dir_name.into()); 411 | let store = wrapper.get_store(); 412 | 413 | write_mixed_dims_lat_lon_data_to_store(store, fillvalue).await; 414 | let schema = Arc::new(Schema::new(vec![ 415 | Field::new("data", ArrowDataType::Float64, true), 416 | Field::new("lat", ArrowDataType::Float64, true), 417 | Field::new("lon", ArrowDataType::Float64, true), 418 | ])); 419 | 420 | (wrapper, schema) 421 | } 422 | 423 | #[cfg(feature = "icechunk")] 424 | pub(crate) async fn get_local_icechunk_repo( 425 | write_data: bool, 426 | fillvalue: f64, 427 | dir_name: &str, 428 | ) -> (LocalIcechunkRepoWrapper, SchemaRef) { 429 | let wrapper = LocalIcechunkRepoWrapper::new(dir_name.into()).await; 430 | let store = wrapper.get_store(); 431 | 432 | write_lat_lon_data_to_store(store.clone(), write_data, fillvalue).await; 433 | let _ = store 434 | .session() 435 | .write() 436 | .await 437 | .commit("some test data", None) 438 | .await 439 | .unwrap(); 440 | let schema = Arc::new(Schema::new(vec![ 441 | Field::new("data", ArrowDataType::Float64, true), 442 | Field::new("lat", ArrowDataType::Float64, true), 443 | Field::new("lon", ArrowDataType::Float64, true), 444 | ])); 445 | 446 | (wrapper, schema) 447 | } 448 | } 449 | -------------------------------------------------------------------------------- /src/table/table_provider.rs: -------------------------------------------------------------------------------- 1 | use std::any::Any; 2 | use std::fmt::Debug; 3 | use std::sync::Arc; 4 | 5 | use async_trait::async_trait; 6 | use datafusion::arrow::datatypes::{Schema, SchemaRef}; 7 | use datafusion::catalog::{Session, TableProviderFactory}; 8 | use datafusion::common::ToDFSchema; 9 | use datafusion::datasource::listing::ListingTableUrl; 10 | use datafusion::datasource::{TableProvider, TableType}; 11 | use datafusion::error::{DataFusionError, Result as DfResult}; 12 | use datafusion::logical_expr::utils::conjunction; 13 | use datafusion::logical_expr::{CreateExternalTable, Expr, TableProviderFilterPushDown}; 14 | use datafusion::physical_expr::create_physical_expr; 15 | use datafusion::physical_plan::ExecutionPlan; 16 | 17 | use super::config::ZarrTableConfig; 18 | use super::scanner::ZarrScan; 19 | use crate::table::config::ZarrTableUrl; 20 | 21 | /// The table provider for zarr stores. 
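///
/// A rough sketch of how the provider might be used (the store path is
/// hypothetical and the snippet is not compiled as a doc test):
///
/// ```ignore
/// use std::sync::Arc;
///
/// use arrow_zarr::table::ZarrTable;
/// use datafusion::prelude::SessionContext;
///
/// // build a provider from a local zarr store and register it with datafusion.
/// let table = ZarrTable::from_path("/tmp/example_store.zarr".to_string()).await;
/// let ctx = SessionContext::new();
/// ctx.register_table("zarr_table", Arc::new(table))?;
///
/// // the table can then be queried with regular SQL.
/// let batches = ctx
///     .sql("SELECT lat, lon, data FROM zarr_table WHERE lat < 40.0")
///     .await?
///     .collect()
///     .await?;
/// ```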
22 | pub struct ZarrTable { 23 | table_config: ZarrTableConfig, 24 | } 25 | 26 | impl Debug for ZarrTable { 27 | fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 28 | Ok(()) 29 | } 30 | } 31 | 32 | impl ZarrTable { 33 | pub fn new(table_config: ZarrTableConfig) -> Self { 34 | Self { table_config } 35 | } 36 | 37 | pub async fn from_path(path: String) -> Self { 38 | let table_url = ListingTableUrl::parse(path).unwrap(); 39 | // TODO(alxmrs): Figure out how to optionally support icechunk 40 | let zarr_url = ZarrTableUrl::ZarrStore(table_url); 41 | let schema = zarr_url.infer_schema().await.unwrap(); 42 | let table_config = ZarrTableConfig::new(zarr_url, schema); 43 | Self { table_config } 44 | } 45 | } 46 | 47 | #[async_trait] 48 | impl TableProvider for ZarrTable { 49 | fn as_any(&self) -> &dyn Any { 50 | self 51 | } 52 | 53 | fn schema(&self) -> SchemaRef { 54 | self.table_config.get_schema_ref() 55 | } 56 | 57 | fn table_type(&self) -> TableType { 58 | TableType::Base 59 | } 60 | 61 | // there's no projected columns or partitions with the zarr data, 62 | // so really all we have are arrays that are present in all the data 63 | // chunks. there's not much to check here, we do use the filter 64 | // pushdown to avoid reading entire chunk, so pretty much all the 65 | // available arrays can be used as Inexact filters. 66 | fn supports_filters_pushdown( 67 | &self, 68 | filters: &[&Expr], 69 | ) -> datafusion::error::Result> { 70 | Ok(vec![TableProviderFilterPushDown::Inexact; filters.len()]) 71 | } 72 | 73 | async fn scan( 74 | &self, 75 | state: &dyn Session, 76 | projection: Option<&Vec>, 77 | filters: &[Expr], 78 | _limit: Option, 79 | ) -> datafusion::error::Result> { 80 | let mut filters_physical_expr = None; 81 | if let Some(filters) = conjunction(filters.to_vec()) { 82 | filters_physical_expr = Some(create_physical_expr( 83 | &filters, 84 | &self.table_config.get_schema_ref().to_dfschema()?, 85 | state.execution_props(), 86 | )?); 87 | } 88 | 89 | let mut config = self.table_config.clone(); 90 | if let Some(proj) = projection { 91 | config = config.with_projection(proj.to_vec()); 92 | } 93 | let scanner = ZarrScan::new(config, filters_physical_expr); 94 | 95 | Ok(Arc::new(scanner)) 96 | } 97 | } 98 | 99 | /// The factory for the zarr table. 
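///
/// The factory is meant to be registered against a session state so that
/// `CREATE EXTERNAL TABLE ... STORED AS ZARR_STORE` (or `ICECHUNK_REPO`, with
/// the icechunk feature enabled) resolves to a [`ZarrTable`]. A sketch that
/// mirrors the tests below (the location is hypothetical, the snippet is not
/// compiled as a doc test):
///
/// ```ignore
/// use std::sync::Arc;
///
/// use arrow_zarr::table::ZarrTableFactory;
/// use datafusion::execution::SessionStateBuilder;
/// use datafusion::prelude::SessionContext;
///
/// let mut state = SessionStateBuilder::new().build();
/// state
///     .table_factories_mut()
///     .insert("ZARR_STORE".into(), Arc::new(ZarrTableFactory {}));
/// let ctx = SessionContext::new_with_state(state);
///
/// ctx.sql(
///     "CREATE EXTERNAL TABLE zarr_table STORED AS ZARR_STORE LOCATION '/tmp/example_store.zarr'",
/// )
/// .await?;
/// let df = ctx.sql("SELECT lat, lon FROM zarr_table LIMIT 10").await?;
/// ```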
100 | #[derive(Debug)] 101 | pub struct ZarrTableFactory {} 102 | 103 | #[async_trait] 104 | impl TableProviderFactory for ZarrTableFactory { 105 | async fn create( 106 | &self, 107 | _state: &dyn Session, 108 | cmd: &CreateExternalTable, 109 | ) -> DfResult> { 110 | let table_url = match cmd.file_type.as_str() { 111 | "ZARR_STORE" => ZarrTableUrl::ZarrStore(ListingTableUrl::parse(&cmd.location)?), 112 | #[cfg(feature = "icechunk")] 113 | "ICECHUNK_REPO" => ZarrTableUrl::IcechunkRepo(ListingTableUrl::parse(&cmd.location)?), 114 | _ => { 115 | return Err(DataFusionError::Execution(format!( 116 | "Unsupported file type {}", 117 | cmd.file_type 118 | ))) 119 | } 120 | }; 121 | 122 | let inferred_schema = table_url.infer_schema().await?; 123 | let schema = if cmd.schema.fields().is_empty() { 124 | inferred_schema 125 | } else { 126 | let provided_schema: Schema = cmd.schema.as_ref().into(); 127 | for field in provided_schema.fields() { 128 | let target_type = inferred_schema.field_with_name(field.name())?.data_type(); 129 | if field.data_type() != target_type { 130 | return Err(DataFusionError::Execution(format!( 131 | "Requested column {}'s type does not match data from store", 132 | field.name() 133 | ))); 134 | } 135 | } 136 | 137 | Arc::new(provided_schema) 138 | }; 139 | 140 | let zarr_config = ZarrTableConfig::new(table_url, schema); 141 | let table_provider = ZarrTable::new(zarr_config); 142 | Ok(Arc::new(table_provider)) 143 | } 144 | } 145 | 146 | #[cfg(test)] 147 | mod table_provider_tests { 148 | use std::collections::HashMap; 149 | 150 | use arrow::array::AsArray; 151 | use arrow::compute::concat_batches; 152 | use arrow::datatypes::Float64Type; 153 | use arrow_schema::DataType; 154 | use datafusion::execution::SessionStateBuilder; 155 | use datafusion::prelude::SessionContext; 156 | use futures_util::TryStreamExt; 157 | 158 | use super::*; 159 | use crate::table::table_provider::ZarrTable; 160 | #[cfg(feature = "icechunk")] 161 | use crate::test_utils::get_local_icechunk_repo; 162 | use crate::test_utils::{ 163 | extract_col, get_local_zarr_store, validate_names_and_types, validate_primitive_column, 164 | }; 165 | 166 | async fn read_and_validate(table_url: ZarrTableUrl, schema: SchemaRef) { 167 | let config = ZarrTableConfig::new(table_url, schema); 168 | 169 | let table_provider = ZarrTable::new(config); 170 | let state = SessionStateBuilder::new().build(); 171 | let session = SessionContext::new(); 172 | 173 | let scan = table_provider 174 | .scan(&state, None, &Vec::new(), None) 175 | .await 176 | .unwrap(); 177 | let records: Vec<_> = scan 178 | .execute(0, session.task_ctx()) 179 | .unwrap() 180 | .try_collect() 181 | .await 182 | .unwrap(); 183 | 184 | let target_types = HashMap::from([ 185 | ("lat".to_string(), DataType::Float64), 186 | ("lon".to_string(), DataType::Float64), 187 | ("data".to_string(), DataType::Float64), 188 | ]); 189 | validate_names_and_types(&target_types, &records[0]); 190 | assert_eq!(records.len(), 9); 191 | 192 | // the top left chunk, full 3x3 193 | validate_primitive_column::( 194 | "lat", 195 | &records[0], 196 | &[35., 35., 35., 36., 36., 36., 37., 37., 37.], 197 | ); 198 | validate_primitive_column::( 199 | "lon", 200 | &records[0], 201 | &[ 202 | -120.0, -119.0, -118.0, -120.0, -119.0, -118.0, -120.0, -119.0, -118.0, 203 | ], 204 | ); 205 | validate_primitive_column::( 206 | "data", 207 | &records[0], 208 | &[0.0, 1.0, 2.0, 8.0, 9.0, 10.0, 16.0, 17.0, 18.0], 209 | ); 210 | } 211 | 212 | #[tokio::test] 213 | async fn read_data_test() { 214 | 
// a zarr store in a local directory. 215 | let (wrapper, schema) = get_local_zarr_store(true, 0.0, "lat_lon_data_for_provider").await; 216 | let path = wrapper.get_store_path(); 217 | let table_url = ZarrTableUrl::ZarrStore(ListingTableUrl::parse(path).unwrap()); 218 | 219 | read_and_validate(table_url, schema).await; 220 | 221 | // a local icechunk repo. 222 | #[cfg(feature = "icechunk")] 223 | { 224 | let (wrapper, schema) = 225 | get_local_icechunk_repo(true, 0.0, "lat_lon_repo_for_provider").await; 226 | let path = wrapper.get_store_path(); 227 | let table_url = ZarrTableUrl::IcechunkRepo(ListingTableUrl::parse(path).unwrap()); 228 | 229 | read_and_validate(table_url, schema).await; 230 | } 231 | } 232 | 233 | #[tokio::test] 234 | async fn create_table_provider_test() { 235 | let (wrapper, _) = get_local_zarr_store(true, 0.0, "lat_lon_data_for_factory").await; 236 | let mut state = SessionStateBuilder::new().build(); 237 | let table_path = wrapper.get_store_path(); 238 | state 239 | .table_factories_mut() 240 | .insert("ZARR_STORE".into(), Arc::new(ZarrTableFactory {})); 241 | 242 | // create a table with 2 explicitly selected columns 243 | let query = format!( 244 | "CREATE EXTERNAL TABLE zarr_table_partial(lat double, lon double) STORED AS ZARR_STORE LOCATION '{}'", 245 | table_path, 246 | ); 247 | 248 | let session = SessionContext::new_with_state(state.clone()); 249 | session.sql(&query).await.unwrap(); 250 | 251 | // both columns are 1d coordinates. This should get resolved to 252 | // all combinations of lat with lon (8 lats, 8 lons -> 64 rows). 253 | let query = "SELECT lat, lon FROM zarr_table_partial"; 254 | let df = session.sql(query).await.unwrap(); 255 | let batches = df.collect().await.unwrap(); 256 | 257 | let schema = batches[0].schema(); 258 | let batch = concat_batches(&schema, &batches).unwrap(); 259 | assert_eq!(batch.num_columns(), 2); 260 | assert_eq!(batch.num_rows(), 64); 261 | 262 | // create a table, with 3 columns, lat, lon and data. 263 | let query = format!( 264 | "CREATE EXTERNAL TABLE zarr_table STORED AS ZARR_STORE LOCATION '{}'", 265 | table_path, 266 | ); 267 | 268 | let session = SessionContext::new_with_state(state.clone()); 269 | session.sql(&query).await.unwrap(); 270 | 271 | // a simple select statement with a limit. 272 | let query = "SELECT lat, lon FROM zarr_table LIMIT 10"; 273 | let df = session.sql(query).await.unwrap(); 274 | let batches = df.collect().await.unwrap(); 275 | 276 | let schema = batches[0].schema(); 277 | let batch = concat_batches(&schema, &batches).unwrap(); 278 | assert_eq!(batch.num_columns(), 2); 279 | assert_eq!(batch.num_rows(), 10); 280 | 281 | // a slightly more complex query involving a join. 
282 | let query = " 283 | WITH d1 AS ( 284 | SELECT lat, lon, data 285 | FROM zarr_table 286 | ), 287 | 288 | d2 AS ( 289 | SELECT lat, lon, data*2 as data2 290 | FROM zarr_table 291 | ) 292 | 293 | SELECT data, data2 294 | FROM d1 295 | JOIN d2 296 | ON d1.lat = d2.lat 297 | AND d1.lon = d2.lon 298 | "; 299 | let df = session.sql(query).await.unwrap(); 300 | let batches = df.collect().await.unwrap(); 301 | 302 | let schema = batches[0].schema(); 303 | let batch = concat_batches(&schema, &batches).unwrap(); 304 | 305 | let data1: Vec<_> = batch 306 | .column_by_name("data") 307 | .unwrap() 308 | .as_primitive::() 309 | .values() 310 | .iter() 311 | .map(|f| f * 2.0) 312 | .collect(); 313 | let data2 = batch 314 | .column_by_name("data2") 315 | .unwrap() 316 | .as_primitive::() 317 | .values() 318 | .to_vec(); 319 | assert_eq!(data1, data2); 320 | 321 | // create a table from an icechunk repo. 322 | #[cfg(feature = "icechunk")] 323 | { 324 | let (wrapper, _) = get_local_icechunk_repo(true, 0.0, "lat_lon_repo_for_factory").await; 325 | let table_path = wrapper.get_store_path(); 326 | state 327 | .table_factories_mut() 328 | .insert("ICECHUNK_REPO".into(), Arc::new(ZarrTableFactory {})); 329 | 330 | let query = format!( 331 | "CREATE EXTERNAL TABLE zarr_table_icechunk STORED AS ICECHUNK_REPO LOCATION '{}'", 332 | table_path, 333 | ); 334 | 335 | let session = SessionContext::new_with_state(state.clone()); 336 | session.sql(&query).await.unwrap(); 337 | 338 | let query = "SELECT lat, lon FROM zarr_table LIMIT 10"; 339 | let df = session.sql(query).await.unwrap(); 340 | let batches = df.collect().await.unwrap(); 341 | 342 | let schema = batches[0].schema(); 343 | let batch = concat_batches(&schema, &batches).unwrap(); 344 | assert_eq!(batch.num_columns(), 2); 345 | assert_eq!(batch.num_rows(), 10); 346 | } 347 | } 348 | 349 | #[tokio::test] 350 | async fn partial_coordinates_query() { 351 | let (wrapper, _) = 352 | get_local_zarr_store(true, 0.0, "lat_lon_data_partial_coord_query").await; 353 | let mut state = SessionStateBuilder::new().build(); 354 | let table_path = wrapper.get_store_path(); 355 | state 356 | .table_factories_mut() 357 | .insert("ZARR_STORE".into(), Arc::new(ZarrTableFactory {})); 358 | 359 | let query = format!( 360 | "CREATE EXTERNAL TABLE zarr_table STORED AS ZARR_STORE LOCATION '{}'", 361 | table_path, 362 | ); 363 | 364 | let session = SessionContext::new_with_state(state.clone()); 365 | session.sql(&query).await.unwrap(); 366 | 367 | // select the 2D data and only one of the 1D coordinates. This should get 368 | // resolved to the lon being brodacasted to match the 2D data. 
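        // concretely, lon holds 8 distinct values and data is 8x8, so each lon
        // value gets repeated once per lat row, which is why the assertions
        // below expect 64 rows.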
369 | let query = "SELECT data, lon FROM zarr_table"; 370 | let df = session.sql(query).await.unwrap(); 371 | let batches = df.collect().await.unwrap(); 372 | 373 | let schema = batches[0].schema(); 374 | let batch = concat_batches(&schema, &batches).unwrap(); 375 | assert_eq!(batch.num_columns(), 2); 376 | assert_eq!(batch.num_rows(), 64); 377 | } 378 | 379 | #[tokio::test] 380 | async fn query_with_filter() { 381 | let (wrapper, _) = get_local_zarr_store(true, 0.0, "lat_lon_data_filter_query").await; 382 | let mut state = SessionStateBuilder::new().build(); 383 | let table_path = wrapper.get_store_path(); 384 | state 385 | .table_factories_mut() 386 | .insert("ZARR_STORE".into(), Arc::new(ZarrTableFactory {})); 387 | 388 | let query = format!( 389 | "CREATE EXTERNAL TABLE zarr_table STORED AS ZARR_STORE LOCATION '{}'", 390 | table_path, 391 | ); 392 | 393 | let session = SessionContext::new_with_state(state.clone()); 394 | session.sql(&query).await.unwrap(); 395 | 396 | // select the 2D data and only one of the 1D coordinates. This should get 397 | // resolved to the lon being brodacasted to match the 2D data. 398 | let query = " 399 | SELECT lat, lon, data 400 | FROM zarr_table 401 | WHERE lat < 38.1 402 | AND lon > -116.9 403 | "; 404 | let df = session.sql(query).await.unwrap(); 405 | let batches = df.collect().await.unwrap(); 406 | 407 | // this tests for the actual WHERE clause, which is a combination 408 | // of the filter pushdown and some filtering provided by datafusion, 409 | // out-of-the-box, so the condition in the test matches the WHERE 410 | // clause exactly. 411 | for batch in batches { 412 | let lat_values = extract_col::("lat", &batch); 413 | let lon_values = extract_col::("lon", &batch); 414 | assert!(lat_values 415 | .iter() 416 | .zip(lon_values.iter()) 417 | .all(|(lat, lon)| *lat < 38.1 && *lon > -116.9)); 418 | } 419 | } 420 | 421 | #[tokio::test] 422 | async fn table_factory_error_test() { 423 | let (wrapper, _) = get_local_zarr_store(true, 0.0, "lat_lon_data_for_factory_error").await; 424 | let mut state = SessionStateBuilder::new().build(); 425 | let table_path = wrapper.get_store_path(); 426 | state 427 | .table_factories_mut() 428 | .insert("ZARR_STORE".into(), Arc::new(ZarrTableFactory {})); 429 | 430 | // create a table with 2 explicitly selected columns, but the names 431 | // are wrong so it should error out. 432 | let query = format!( 433 | "CREATE EXTERNAL TABLE zarr_table(latitude double, longitude double) STORED AS ZARR_STORE LOCATION '{}'", 434 | table_path, 435 | ); 436 | 437 | let session = SessionContext::new_with_state(state.clone()); 438 | let res = session.sql(&query).await; 439 | match res { 440 | Ok(_) => panic!(), 441 | Err(e) => { 442 | assert_eq!( 443 | e.to_string(), 444 | "Arrow error: Schema error: Unable to get field named \"latitude\". Valid fields: [\"data\", \"lat\", \"lon\"]" 445 | ); 446 | } 447 | } 448 | 449 | // create a table with 2 explicitly selected columns, but the type for the 450 | // columns are wrong so it should error out. 
451 | let query = format!( 452 | "CREATE EXTERNAL TABLE zarr_table(lat int, lon int) STORED AS ZARR_STORE LOCATION '{}'", 453 | table_path, 454 | ); 455 | 456 | let session = SessionContext::new_with_state(state.clone()); 457 | let res = session.sql(&query).await; 458 | match res { 459 | Ok(_) => panic!(), 460 | Err(e) => { 461 | assert_eq!( 462 | e.to_string(), 463 | "Execution error: Requested column lat's type does not match data from store" 464 | ); 465 | } 466 | } 467 | } 468 | } 469 | -------------------------------------------------------------------------------- /src/zarr_store_opener/zarr_data_stream.rs: -------------------------------------------------------------------------------- 1 | use std::borrow::Cow; 2 | use std::cmp::min; 3 | use std::collections::{HashMap, VecDeque}; 4 | use std::path::PathBuf; 5 | use std::pin::Pin; 6 | use std::sync::Arc; 7 | use std::task::{Context, Poll}; 8 | 9 | use arrow::array::*; 10 | use arrow::datatypes::*; 11 | use arrow::record_batch::RecordBatch; 12 | use arrow_schema::ArrowError; 13 | use async_stream::try_stream; 14 | use bytes::Bytes; 15 | use futures::stream::{BoxStream, Stream}; 16 | use itertools::iproduct; 17 | use tokio::sync::mpsc::Receiver; 18 | use tokio::task::JoinSet; 19 | use zarrs::array::codec::{ArrayToBytesCodecTraits, CodecOptions}; 20 | use zarrs::array::{Array, ArrayBytes, ArraySize, DataType as zDataType, ElementOwned}; 21 | use zarrs::array_subset::ArraySubset; 22 | use zarrs_storage::AsyncReadableListableStorageTraits; 23 | 24 | use super::filter::ZarrChunkFilter; 25 | use super::io_runtime::IoRuntime; 26 | use crate::errors::zarr_errors::{ZarrQueryError, ZarrQueryResult}; 27 | 28 | /// this function handles having multiple values for a given vector, 29 | /// one per array, including some arrays that might be lower 30 | /// dimension coordinates. 31 | fn resolve_vector( 32 | coords: &ZarrCoordinates, 33 | vecs: HashMap>, 34 | ) -> ZarrQueryResult> { 35 | let mut final_vec: Option> = None; 36 | for (k, vec) in vecs.iter() { 37 | if let Some(final_vec) = &final_vec { 38 | if let Some(pos) = coords.get_coord_position(k) { 39 | // if we have a the final vector (from a previous non 40 | // coordinate array), and this current array is a coordinate 41 | // array, its one vector element must match the array element 42 | // in the final vector at the position of the cooridnate. 43 | if final_vec[pos] != vec[0] { 44 | return Err(ZarrQueryError::InvalidMetadata( 45 | "Mismatch between vectors for different arrays".into(), 46 | )); 47 | } 48 | // if the current array is not a coordinate, it must match the 49 | // final vector we have extracted from a previous array. 50 | } else if final_vec != vec { 51 | return Err(ZarrQueryError::InvalidMetadata( 52 | "Mismatch between vectors for different arrays".into(), 53 | )); 54 | } 55 | } else if !coords.is_coordinate(k) { 56 | final_vec = Some(vec.clone()); 57 | } 58 | } 59 | 60 | if let Some(final_vec) = final_vec { 61 | Ok(final_vec) 62 | // the else branch here would happen if all the arrays are coordinates. 63 | } else { 64 | let mut final_vec: Vec = vec![0; coords.coord_positions.len()]; 65 | for (k, p) in coords.coord_positions.iter() { 66 | final_vec[*p] = vecs.get(k).ok_or(ZarrQueryError::InvalidMetadata( 67 | "Array is missing from array map".into(), 68 | ))?[0]; 69 | } 70 | Ok(final_vec) 71 | } 72 | } 73 | 74 | /// A struct to handle coordinate variables, and "broadcasting" them when reading multidimensional 75 | /// data. 
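///
/// For example, with a 2D `data` array whose dimension names are
/// `["lat", "lon"]` plus 1D `lat` and `lon` coordinate arrays, `lat` sits at
/// position 0 and `lon` at position 1 of the chunk shape; when a 3x3 chunk is
/// read, the 3 `lat` values are each repeated 3 times and the 3 `lon` values
/// are tiled 3 times, so every column in the resulting record batch has 9 rows.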
76 | #[derive(Debug)] 77 | struct ZarrCoordinates { 78 | /// the position of each coordinate in the overall chunk shape. 79 | /// the coordinates are arrays that contain data that characterises 80 | /// a dimension, such as time, or a longitude or latitude. 81 | coord_positions: HashMap, 82 | } 83 | 84 | impl ZarrCoordinates { 85 | fn new( 86 | arrays: &HashMap>, 87 | schema_ref: SchemaRef, 88 | ) -> ZarrQueryResult { 89 | // the goal of these "coordinates" is to determine what needs 90 | // to be broadcasted from 1D to ND depending on what columns 91 | // were selected. based on what is a broadcastable coordinate 92 | // and its position in the overall chunk dimensionality, we 93 | // can combine a 1D array with ND arrays later on. 94 | let mut coord_positions: HashMap = HashMap::new(); 95 | 96 | // this is pretty messy, but essentially for each array we 97 | // extract it's dimentionality and its dimension. at this stage, 98 | // we allow for an array to not have dimensions, but not to have 99 | // dimensions without a name. 100 | let arr_dims = arrays 101 | .iter() 102 | .map(|(k, v)| { 103 | ( 104 | k, 105 | v.dimensionality(), 106 | v.dimension_names() 107 | .clone() 108 | .map(|vec| { 109 | vec.into_iter().collect::>>().ok_or( 110 | ZarrQueryError::InvalidMetadata( 111 | "Null dimension names not supported".into(), 112 | ), 113 | ) 114 | }) 115 | .transpose(), 116 | ) 117 | }) 118 | .map(|(k, d, res)| res.map(|names| (k, d, names))) 119 | .collect::>)>>>()?; 120 | 121 | // first case to check, do all the arrays have the same 122 | // dimensionality. 123 | let mut ordered_dim_names: Option> = None; 124 | if arr_dims.windows(2).all(|w| w[0].1 == w[1].1) { 125 | // this is the case where all the arrays are coordinates, 126 | // so we determine the broadcasting order from the schema. 127 | if arr_dims.iter().all(|d| Some(vec![d.0.clone()]) == d.2) { 128 | ordered_dim_names = Some( 129 | schema_ref 130 | .fields 131 | .into_iter() 132 | .map(|f| f.name().to_string()) 133 | .collect(), 134 | ); 135 | // this is the case where there is a mix of data and 136 | // coordinates, but all the coordinates are already 137 | // stored as broadcasted ararys, so there is no need 138 | // to do anything later on. 139 | } else { 140 | return Ok(Self { coord_positions }); 141 | } 142 | } 143 | 144 | // if we didn't hit the above conditions, then the arrays 145 | // have mixed dimensionality. we extract the chunk dimension 146 | // names, which must be consistent across all arrays (that 147 | // are not broadcastable coordinates). 148 | if ordered_dim_names.is_none() { 149 | for d in &arr_dims { 150 | if d.1 != 1 { 151 | let d = d.2.clone(); 152 | let arr_dim_names: Vec<_> = d.ok_or(ZarrQueryError::InvalidMetadata( 153 | "With mixed array dimensionality, dimension names are required".into(), 154 | ))?; 155 | 156 | if let Some(ordered_dim_names) = &ordered_dim_names { 157 | if *ordered_dim_names != arr_dim_names { 158 | return Err(ZarrQueryError::InvalidMetadata( 159 | "Dimension names must be consistent across arrays".into(), 160 | )); 161 | } 162 | } else { 163 | ordered_dim_names = Some(arr_dim_names); 164 | } 165 | } 166 | } 167 | } 168 | 169 | // for each 1D array, we check that it is a coordinate, it has 170 | // to be at this point in the function, and find its position 171 | // in the chunk dimension names. 
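        // e.g. with arrays {"lat": 1D ["lat"], "lon": 1D ["lon"],
        // "data": 2D ["lat", "lon"]}, ordered_dim_names is ["lat", "lon"]
        // (taken from "data") and coord_positions ends up as
        // {"lat": 0, "lon": 1}.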
172 | let ordered_dim_names = ordered_dim_names.ok_or(ZarrQueryError::InvalidMetadata( 173 | "With mixed array dimensionality, dimension names are required".into(), 174 | ))?; 175 | for d in arr_dims { 176 | if d.1 == 1 { 177 | if Some(vec![d.0.clone()]) != d.2 { 178 | return Err(ZarrQueryError::InvalidMetadata( 179 | "With mixed array dimensionality, 1D arrays must be coordinates".into(), 180 | )); 181 | } 182 | let pos = ordered_dim_names.iter().position(|dim| dim == d.0).ok_or( 183 | ZarrQueryError::InvalidMetadata( 184 | "Could not find coordinate in dimension names".into(), 185 | ), 186 | )?; 187 | coord_positions.insert(d.0.clone(), pos); 188 | } 189 | } 190 | 191 | Ok(Self { coord_positions }) 192 | } 193 | 194 | /// checks if a column name corresponds to a coordinate. 195 | fn is_coordinate(&self, col: &str) -> bool { 196 | self.coord_positions.contains_key(col) 197 | } 198 | 199 | /// returns the position of a coordinate within the chunk 200 | /// dimensionality if the column is a coordinate, if not 201 | /// returns None. 202 | fn get_coord_position(&self, col: &str) -> Option { 203 | self.coord_positions.get(col).cloned() 204 | } 205 | 206 | /// return the vector element that corresponds to a coordinate's 207 | /// position within the dimensionality (if the variable is a coordinate). 208 | fn reduce_if_coord(&self, vec: Vec, col: &str) -> Vec { 209 | if let Some(pos) = self.coord_positions.get(col) { 210 | return vec![vec[*pos]]; 211 | } 212 | 213 | vec 214 | } 215 | 216 | /// broadacast a 1D array to a nD array if the variable is a coordinate. 217 | /// note that we return a 1D vector, but this is just because we map all 218 | /// the chunk to columnar data, so a m x n array gets mapped to a 1D 219 | /// vector of length m x n. 220 | fn broadcast_if_coord( 221 | &self, 222 | coord_name: &str, 223 | data: Vec, 224 | full_chunk_shape: &[u64], 225 | ) -> ZarrQueryResult> { 226 | let dim_idx = self.get_coord_position(coord_name); 227 | if dim_idx.is_none() || full_chunk_shape.len() == 1 { 228 | return Ok(data); 229 | } 230 | let dim_idx = dim_idx.unwrap(); 231 | 232 | match (full_chunk_shape.len(), dim_idx) { 233 | (2, 0) => Ok(data 234 | .into_iter() 235 | .flat_map(|v| std::iter::repeat_n(v, full_chunk_shape[1] as usize)) 236 | .collect()), 237 | (2, 1) => Ok(vec![&data[..]; full_chunk_shape[0] as usize].concat()), 238 | (3, 0) => Ok(data 239 | .into_iter() 240 | .flat_map(|v| { 241 | std::iter::repeat_n(v, (full_chunk_shape[1] * full_chunk_shape[2]) as usize) 242 | }) 243 | .collect()), 244 | (3, 1) => { 245 | let v: Vec<_> = data 246 | .into_iter() 247 | .flat_map(|v| std::iter::repeat_n(v, full_chunk_shape[2] as usize)) 248 | .collect(); 249 | Ok(vec![&v[..]; full_chunk_shape[0] as usize].concat()) 250 | } 251 | (3, 2) => { 252 | Ok(vec![&data[..]; (full_chunk_shape[0] * full_chunk_shape[1]) as usize].concat()) 253 | } 254 | _ => Err(ZarrQueryError::InvalidCompute( 255 | "Invalid dimensionality when trying to broadcast dimension".into(), 256 | )), 257 | } 258 | } 259 | } 260 | 261 | /// An interface to a zarr array that can be used to retrieve 262 | /// data and then decode it. 263 | /// 264 | /// the chunk index corresponds to the chunk that is being read, the 265 | /// coords to the coordinates for the full chunk (which can be made up 266 | /// of one or more arrays) and the full chunk shape is relevant when 267 | /// the chunk has some coordinate arrays, which are 1 dimensional, while 268 | /// the non coordinate arrays can be multi dimensional. 
the full chunk 269 | /// size is used to broadcast the coordinates to the full size. 270 | struct ArrayInterface { 271 | name: String, 272 | arr: Arc>, 273 | coords: Arc, 274 | full_chunk_shape: Vec, 275 | chk_index: Vec, 276 | } 277 | 278 | // T doesn't need to be Clone, but deriving apparently requires 279 | // that, so I have implement manually. 280 | impl Clone for ArrayInterface { 281 | fn clone(&self) -> Self { 282 | Self { 283 | name: self.name.to_string(), 284 | arr: self.arr.clone(), 285 | coords: self.coords.clone(), 286 | full_chunk_shape: self.full_chunk_shape.clone(), 287 | chk_index: self.chk_index.clone(), 288 | } 289 | } 290 | } 291 | 292 | /// in most cases, we will read encoded bytes and decode them after, 293 | /// but in the case of a missing chunk the result of the read operation 294 | /// will be done.vin a few cases though we will read pre-decoded bytes, 295 | /// hence why we have this enum. 296 | enum BytesFromArray { 297 | Decoded(Bytes), 298 | Encoded(Option), 299 | } 300 | 301 | impl ArrayInterface { 302 | fn new( 303 | name: String, 304 | arr: Arc>, 305 | coords: Arc, 306 | full_chunk_shape: Vec, 307 | mut chk_index: Vec, 308 | ) -> Self { 309 | chk_index = coords.reduce_if_coord(chk_index, &name); 310 | Self { 311 | name, 312 | arr, 313 | coords, 314 | full_chunk_shape, 315 | chk_index, 316 | } 317 | } 318 | 319 | /// read the bytes from the chunk the interface was built for. 320 | async fn read_bytes(&self) -> ZarrQueryResult { 321 | let chunk_grid = self.arr.chunk_grid_shape(); 322 | let is_edge_grid = self 323 | .chk_index 324 | .iter() 325 | .zip(chunk_grid.iter()) 326 | .any(|(i, g)| i == &(g - 1)); 327 | // handling edges is easier if we just read a subset of the array 328 | // from the start, but to do that we need to read decoded bytes. 329 | if is_edge_grid { 330 | let arr_shape = self.arr.shape(); 331 | let chunk_shape = self.arr.chunk_shape(&self.chk_index)?.to_array_shape(); 332 | 333 | // determine the real size for each of the dimensions (at least 334 | // one of which will be at the edge of the array.) 335 | let ranges: Vec<_> = self 336 | .chk_index 337 | .iter() 338 | .zip(arr_shape.iter()) 339 | .zip(chunk_shape.iter()) 340 | .map(|((i, a), c)| 0..(std::cmp::min(a - i * c, *c))) 341 | .collect(); 342 | 343 | let array_subset = ArraySubset::new_with_ranges(&ranges); 344 | let data = self 345 | .arr 346 | .async_retrieve_chunk_subset(&self.chk_index, &array_subset) 347 | .await?; 348 | let data = data.into_fixed()?; 349 | Ok(BytesFromArray::Decoded(data.into_owned().into())) 350 | // this will be the more common case, everything except edge chunks. 351 | } else { 352 | let data = self 353 | .arr 354 | .async_retrieve_encoded_chunk(&self.chk_index) 355 | .await?; 356 | Ok(BytesFromArray::Encoded(data)) 357 | } 358 | } 359 | 360 | /// decode the chunk that was read previously read from this interface. 361 | /// the reason the 2 functionalities are separated is that we want to 362 | /// interleave the async part (reading data) with the compute part 363 | /// (decoding the data, creating the record batch) so that we can make 364 | /// progress on the latter while the former is running. 365 | fn decode_data(&self, bytes: BytesFromArray) -> ZarrQueryResult { 366 | let decoded_bytes = match bytes { 367 | BytesFromArray::Encoded(bytes) => { 368 | if let Some(bytes) = bytes { 369 | self.arr.codecs().decode( 370 | Cow::Owned(bytes.into()), 371 | &self.arr.chunk_array_representation(&self.chk_index)?, 372 | &CodecOptions::default(), 373 | )? 
374 | } else { 375 | let chk_shp = self 376 | .coords 377 | .reduce_if_coord(self.full_chunk_shape.clone(), &self.name); 378 | let num_elems = chk_shp.iter().fold(1, |mut acc, x| { 379 | acc *= x; 380 | acc 381 | }); 382 | let array_size = ArraySize::new(self.arr.data_type().size(), num_elems); 383 | ArrayBytes::new_fill_value(array_size, self.arr.fill_value()) 384 | } 385 | } 386 | BytesFromArray::Decoded(bytes) => ArrayBytes::Fixed(Cow::Owned(bytes.into())), 387 | }; 388 | 389 | let t = self.arr.data_type(); 390 | macro_rules! return_array_ref { 391 | ($array_t: ty, $prim_type: ty) => {{ 392 | let arr_ref: $array_t = self 393 | .coords 394 | .broadcast_if_coord( 395 | &self.name, 396 | <$prim_type>::from_array_bytes(t, decoded_bytes)?, 397 | &self.full_chunk_shape, 398 | )? 399 | .into(); 400 | return Ok(Arc::new(arr_ref) as ArrayRef); 401 | }}; 402 | } 403 | 404 | match t { 405 | zDataType::Bool => return_array_ref!(BooleanArray, bool), 406 | zDataType::UInt8 => return_array_ref!(PrimitiveArray, u8), 407 | zDataType::UInt16 => return_array_ref!(PrimitiveArray, u16), 408 | zDataType::UInt32 => return_array_ref!(PrimitiveArray, u32), 409 | zDataType::UInt64 => return_array_ref!(PrimitiveArray, u64), 410 | zDataType::Int8 => return_array_ref!(PrimitiveArray, i8), 411 | zDataType::Int16 => return_array_ref!(PrimitiveArray, i16), 412 | zDataType::Int32 => return_array_ref!(PrimitiveArray, i32), 413 | zDataType::Int64 => return_array_ref!(PrimitiveArray, i64), 414 | zDataType::Float32 => return_array_ref!(PrimitiveArray, f32), 415 | zDataType::Float64 => return_array_ref!(PrimitiveArray, f64), 416 | zDataType::String => return_array_ref!(StringArray, String), 417 | _ => Err(ZarrQueryError::InvalidType(format!( 418 | "Unsupported type {t} from zarr metadata" 419 | ))), 420 | } 421 | } 422 | } 423 | 424 | /// A structure to accumulate zarr array data until we can output 425 | /// the whole chunk as a record batch. 426 | struct ZarrInMemoryChunk { 427 | data: HashMap, 428 | } 429 | 430 | impl ZarrInMemoryChunk { 431 | fn new() -> Self { 432 | Self { 433 | data: HashMap::new(), 434 | } 435 | } 436 | 437 | fn add_data(&mut self, arr_name: String, data: ArrayRef) { 438 | self.data.insert(arr_name, data); 439 | } 440 | 441 | fn combine(&mut self, other: ZarrInMemoryChunk) { 442 | self.data.extend(other.data); 443 | } 444 | 445 | fn check_filter(&self, filter: &ZarrChunkFilter) -> Result { 446 | let array_refs: Vec<(String, ArrayRef)> = filter 447 | .schema_ref() 448 | .fields() 449 | .iter() 450 | .map(|f| self.data.get(f.name()).cloned()) 451 | .collect::>>() 452 | .ok_or(ZarrQueryError::InvalidProjection( 453 | "Array missing from array map".into(), 454 | ))? 455 | .into_iter() 456 | .zip(filter.schema_ref().fields.iter()) 457 | .map(|(ar, f)| (f.name().to_string(), ar)) 458 | .collect(); 459 | 460 | let rec_batch = RecordBatch::try_from_iter(array_refs)?; 461 | filter.evaluate(&rec_batch) 462 | } 463 | 464 | /// the columns in the record batch will be ordered following 465 | /// the field names in the schema. 466 | fn into_record_batch(mut self, schema: &SchemaRef) -> ZarrQueryResult { 467 | let array_refs: Vec<(String, ArrayRef)> = schema 468 | .fields() 469 | .iter() 470 | .map(|f| self.data.remove(f.name())) 471 | .collect::>>() 472 | .ok_or(ZarrQueryError::InvalidProjection( 473 | "Array missing from array map".into(), 474 | ))? 
475 | .into_iter() 476 | .zip(schema.fields.iter()) 477 | .map(|(ar, f)| (f.name().to_string(), ar)) 478 | .collect(); 479 | 480 | RecordBatch::try_from_iter(array_refs) 481 | .map_err(|e| ZarrQueryError::RecordBatchError(Box::new(e))) 482 | } 483 | } 484 | 485 | /// A wrapper for a map of arrays, which will handle interleaving 486 | /// reading and decoding data from zarr storage. 487 | type ZarrReceiver = Receiver<(ZarrQueryResult, ArrayInterface)>; 488 | struct ZarrStore { 489 | arrays: HashMap>>, 490 | coordinates: Arc, 491 | chunk_shape: Vec, 492 | chunk_grid_shape: Vec, 493 | array_shape: Vec, 494 | io_runtime: IoRuntime, 495 | join_set: JoinSet<()>, 496 | state: Option<(ZarrReceiver, Vec, Vec)>, 497 | } 498 | 499 | impl ZarrStore { 500 | fn new(arrays: HashMap>, schema_ref: SchemaRef) -> ZarrQueryResult { 501 | let coordinates = ZarrCoordinates::new(&arrays, schema_ref)?; 502 | 503 | // technically getting the chunk shape requires a chunk 504 | // index, but it seems the zarrs library doesn't actually 505 | // return a chunk size that depends on the index, at least 506 | // for regular grids (it ignores edges in other words). 507 | // so here we just retrive "chunk 0", store that, and adjust 508 | // for array edges in a separate function. 509 | let mut chk_shapes: HashMap> = HashMap::new(); 510 | for (k, arr) in arrays.iter() { 511 | let chk_idx = vec![0; arr.shape().len()]; 512 | chk_shapes.insert(k.to_owned(), arr.chunk_shape(&chk_idx)?.to_array_shape()); 513 | } 514 | let chunk_shape = resolve_vector(&coordinates, chk_shapes)?; 515 | 516 | let mut chk_grid_shapes: HashMap> = HashMap::new(); 517 | for (k, arr) in arrays.iter() { 518 | chk_grid_shapes.insert(k.to_owned(), arr.chunk_grid_shape().clone()); 519 | } 520 | let chunk_grid_shape = resolve_vector(&coordinates, chk_grid_shapes)?; 521 | 522 | let mut arr_shapes: HashMap> = HashMap::new(); 523 | for (k, arr) in arrays.iter() { 524 | arr_shapes.insert(k.to_owned(), arr.shape().to_vec()); 525 | } 526 | let array_shape = resolve_vector(&coordinates, arr_shapes)?; 527 | 528 | // this runtime will handle the i/o. i/o tasks spawned in 529 | // that runtime will not share a thead pool with out (probably 530 | // compute heavy, blocking) tasks. 531 | let io_runtime = IoRuntime::try_new()?; 532 | 533 | Ok(Self { 534 | arrays: arrays.into_iter().map(|(k, a)| (k, Arc::new(a))).collect(), 535 | coordinates: Arc::new(coordinates), 536 | chunk_shape, 537 | chunk_grid_shape, 538 | array_shape, 539 | io_runtime, 540 | join_set: JoinSet::new(), 541 | state: None, 542 | }) 543 | } 544 | 545 | /// return the chunk shape for a given index, taking into account 546 | /// the array edges where the "real" chunk is smaller than the 547 | /// chunk size in the metadata. 
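    ///
    /// For example, for an 8x8 array with 3x3 chunks the chunk grid is 3x3;
    /// the edge chunk at index [2, 2] only covers rows and columns 6..8, so
    /// its real shape is [min(8 - 2 * 3, 3), min(8 - 2 * 3, 3)] = [2, 2].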
548 | fn get_chunk_shape(&self, chk_idx: &[u64]) -> ZarrQueryResult> { 549 | let is_edge_grid = chk_idx 550 | .iter() 551 | .zip(self.chunk_grid_shape.iter()) 552 | .any(|(i, g)| i == &(g - 1)); 553 | 554 | let mut chunk_shape = self.chunk_shape.clone(); 555 | if is_edge_grid { 556 | chunk_shape = chk_idx 557 | .iter() 558 | .zip(self.array_shape.iter()) 559 | .zip(chunk_shape.iter()) 560 | .map(|((i, a), c)| std::cmp::min(a - i * c, *c)) 561 | .collect(); 562 | } 563 | 564 | Ok(chunk_shape) 565 | } 566 | 567 | fn get_array_interfaces( 568 | &self, 569 | cols: Vec, 570 | chk_idx: Vec, 571 | ) -> ZarrQueryResult>> { 572 | let full_chunk_shape = self.get_chunk_shape(&chk_idx)?; 573 | let arr_interfaces = cols 574 | .iter() 575 | .map(|col| { 576 | let arr = self 577 | .arrays 578 | .get(col) 579 | .ok_or_else(|| ZarrQueryError::InvalidCompute("".into()))? 580 | .clone(); 581 | Ok(ArrayInterface::new( 582 | col.to_string(), 583 | arr, 584 | self.coordinates.clone(), 585 | full_chunk_shape.clone(), 586 | chk_idx.clone(), 587 | )) 588 | }) 589 | .collect::>>() 590 | .unwrap(); 591 | Ok(arr_interfaces) 592 | } 593 | 594 | /// this is the main function that does the heavy lifting, getting 595 | /// the data from the zarr store and decoding it. 596 | async fn get_chunk( 597 | &mut self, 598 | cols: Vec, 599 | chk_idx: Vec, 600 | use_cached_value: bool, 601 | next_chunk_idx: Option>, 602 | ) -> ZarrQueryResult { 603 | if cols.is_empty() { 604 | return Err(ZarrQueryError::InvalidProjection( 605 | "No columns when polling zarr store for chunks".into(), 606 | )); 607 | } 608 | let mut chk_data = ZarrInMemoryChunk::new(); 609 | 610 | // if there is a cached zarr chunk that was triggered from the 611 | // previous call, we use that. 612 | if use_cached_value && self.state.is_some() { 613 | let (mut rx, cached_idx, cached_cols) = self 614 | .state 615 | .take() 616 | .expect("Cached zarr receiver unexpectedly unavailable"); 617 | if chk_idx != cached_idx { 618 | return Err(ZarrQueryError::InvalidCompute( 619 | "Cached zarr chunk index doesn't match requested chunk index".into(), 620 | )); 621 | } 622 | 623 | if cols != cached_cols { 624 | return Err(ZarrQueryError::InvalidCompute( 625 | "Cached zarr chunk columns don't match requested columns".into(), 626 | )); 627 | } 628 | 629 | while let Some((data, arr_interface)) = rx.recv().await { 630 | let data = data?; 631 | let data = arr_interface.decode_data(data)?; 632 | chk_data.add_data(arr_interface.name, data); 633 | } 634 | } 635 | // if we are either not pre-reading data, or we are but 636 | // this is the first call and there is no cached data yet, 637 | // we read the data now and wait for it to be ready.
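// (note: the mpsc channel below is created with capacity equal to the number of
// array interfaces, so each spawned i/o task can send its result without blocking;
// one result is awaited per loop iteration, and because every message carries its
// ArrayInterface the completion order does not matter.)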
638 | else { 639 | let arr_interfaces = self.get_array_interfaces(cols.clone(), chk_idx)?; 640 | let (tx, mut rx) = tokio::sync::mpsc::channel(arr_interfaces.len()); 641 | for arr_interface in arr_interfaces { 642 | let tx_copy = tx.clone(); 643 | let io_task = async move { 644 | let b = arr_interface.read_bytes().await; 645 | let _ = tx_copy.send((b, arr_interface)).await; 646 | }; 647 | self.join_set.spawn_on(io_task, self.io_runtime.handle()); 648 | 649 | if let Some((Ok(d), arr_int)) = rx.recv().await { 650 | let data = arr_int.decode_data(d)?; 651 | chk_data.add_data(arr_int.name, data); 652 | } else { 653 | return Err(ZarrQueryError::InvalidCompute( 654 | "Unable to retrieve decoded chunk".into(), 655 | )); 656 | } 657 | } 658 | }; 659 | 660 | // if the call was made with an index for the next chunk, we 661 | // submit a job to read that next chunk before returning, 662 | // so that we can fetch the data while other operations run 663 | // between now and the next call to this function. 664 | if let Some(next_chunk_idx) = next_chunk_idx { 665 | let arr_interfaces = self.get_array_interfaces(cols.clone(), next_chunk_idx.clone())?; 666 | let (tx, rx) = tokio::sync::mpsc::channel(arr_interfaces.len()); 667 | for arr_interface in arr_interfaces { 668 | let tx_copy = tx.clone(); 669 | let io_task = async move { 670 | let b = arr_interface.read_bytes().await; 671 | let _ = tx_copy.send((b, arr_interface)).await; 672 | }; 673 | self.join_set.spawn_on(io_task, self.io_runtime.handle()); 674 | } 675 | self.state = Some((rx, next_chunk_idx, cols)); 676 | }; 677 | 678 | Ok(chk_data) 679 | } 680 | } 681 | 682 | /// A stream of RecordBatches read from a Zarr store. 683 | /// 684 | /// This struct is separate from `ZarrRecordBatchStream`, so that we can avoid manually 685 | /// implementing [`Stream`]. Instead, we use the `async-stream` crate to convert an async iterable 686 | /// into a stream. 687 | struct ZarrRecordBatchStreamInner { 688 | zarr_store: Arc>, 689 | projected_schema_ref: SchemaRef, 690 | schema_without_filter_cols: Option, 691 | filter: Option, 692 | chunk_indices: VecDeque>, 693 | } 694 | 695 | impl ZarrRecordBatchStreamInner { 696 | /// Create a new ZarrRecordBatchStreamInner. 697 | /// 698 | /// This function is intentionally private, as all users should call 699 | /// [`ZarrRecordBatchStream::try_new`] instead. 700 | async fn new( 701 | store: Arc, 702 | schema_ref: SchemaRef, 703 | prefix: Option, 704 | projection: Option>, 705 | n_partitions: usize, 706 | partition: usize, 707 | ) -> ZarrQueryResult { 708 | // quick check to make sure the partition we're reading from does 709 | // not exceed the number of partitions. 710 | if partition >= n_partitions { 711 | return Err(ZarrQueryError::InvalidCompute( 712 | "Partition number exceeds number of partitions in zarr stream".into(), 713 | )); 714 | } 715 | 716 | // if there is a projection provided, modify the schema. 717 | let projected_schema_ref = match projection { 718 | Some(proj) => Arc::new(schema_ref.project(&proj)?), 719 | None => schema_ref.clone(), 720 | }; 721 | 722 | // the prefix is necessary when reading from some remote 723 | // stores that don't work off of the url and require a 724 | // prefix. for example the aws s3 object store doesn't seem 725 | // to use the url, just the bucket, so the path to the 726 | // actual zarr store needs to be provided separately.
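// for example (hypothetical values): with prefix = Some("data/store.zarr".to_string()),
// the arrays below are opened at paths like "/data/store.zarr/<column>"; with no
// prefix they are opened at "/<column>".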
727 | let prefix = if let Some(prefix) = prefix { 728 | ["/".into(), prefix].join("") 729 | } else { 730 | "/".to_string() 731 | }; 732 | 733 | // this will extract column (i.e. array) names based on the (possibly 734 | // projected) schema. 735 | let cols: Vec<_> = projected_schema_ref 736 | .fields() 737 | .iter() 738 | .map(|f| f.name()) 739 | .collect(); 740 | 741 | // open all the arrays based on the column names. 742 | let mut arrays: HashMap> = HashMap::new(); 743 | for col in &cols { 744 | let path = PathBuf::from(&prefix) 745 | .join(col) 746 | .into_os_string() 747 | .to_str() 748 | .ok_or(ZarrQueryError::InvalidMetadata( 749 | "could not form path from group and column name".into(), 750 | ))? 751 | .to_string(); 752 | let arr = Array::async_open(store.clone(), &path).await?; 753 | arrays.insert(col.to_string(), arr); 754 | } 755 | 756 | // store all the zarr arrays in a struct that we can use 757 | // to access them later. 758 | let zarr_store = Arc::new(ZarrStore::new(arrays, projected_schema_ref.clone())?); 759 | 760 | // this creates all the chunk indices we will be reading from. 761 | let chk_grid_shape = &zarr_store.chunk_grid_shape; 762 | let mut chunk_indices: Vec<_> = match chk_grid_shape.len() { 763 | 1 => (0..chk_grid_shape[0]).map(|i| vec![i]).collect(), 764 | 2 => { 765 | let d0: Vec<_> = (0..chk_grid_shape[0]).collect(); 766 | let d1: Vec<_> = (0..chk_grid_shape[1]).collect(); 767 | iproduct!(d0, d1).map(|(x, y)| vec![x, y]).collect() 768 | } 769 | 3 => { 770 | let d0: Vec<_> = (0..chk_grid_shape[0]).collect(); 771 | let d1: Vec<_> = (0..chk_grid_shape[1]).collect(); 772 | let d2: Vec<_> = (0..chk_grid_shape[2]).collect(); 773 | iproduct!(d0, d1, d2) 774 | .map(|(x, y, z)| vec![x, y, z]) 775 | .collect() 776 | } 777 | _ => { 778 | return Err(ZarrQueryError::InvalidMetadata( 779 | "Only 1, 2 or 3D arrays supported".into(), 780 | )) 781 | } 782 | }; 783 | let chunks_per_partitions = chunk_indices.len().div_ceil(n_partitions); 784 | let max_idx = chunk_indices.len(); 785 | let start = chunks_per_partitions * partition; 786 | let end = min(chunks_per_partitions * (partition + 1), max_idx); 787 | 788 | // this is to handle cases where more partitions than there are 789 | // chunks to read were requested. 790 | if end <= start { 791 | chunk_indices = Vec::new(); 792 | } else { 793 | chunk_indices = chunk_indices[start..end].to_vec(); 794 | } 795 | let chunk_indices = VecDeque::from(chunk_indices); 796 | 797 | Ok(Self { 798 | zarr_store, 799 | projected_schema_ref, 800 | filter: None, 801 | chunk_indices, 802 | schema_without_filter_cols: None, 803 | }) 804 | } 805 | 806 | /// Fetch the next chunk, returning None if there are no more chunks. 807 | pub(crate) async fn next_chunk(&mut self) -> Result, ArrowError> { 808 | // the logic here is not trivial so it warrants a few explanations. 809 | // if there is a filter to apply, we read whatever data is needed to 810 | // evaluate it. we do pre-fetch chunks here, when calling get_chunk, 811 | // and we keep going through the chunk indices until we find a chunk 812 | // where the filter condition is satisfied. 813 | // 814 | // when we do find such a chunk, we move on to the next stage, which 815 | // is to read the data for the actual query. we do save the data we 816 | // read to evaluate the filter, because some of it might also be 817 | // requested in the query.
we don't request columns if they are already 818 | // present in the filter data, and then combine the filter data with 819 | // the data for the chunk for the main query. if there is a filter, 820 | // we can't pre-fetch the data when reading the data for the main 821 | // query, because the next time we read some data it would be for the 822 | // filter, not for the main query. 823 | 824 | let mut chunk_index: Option> = None; 825 | let mut filter_zarr_chunk: Option = None; 826 | 827 | let filter = self.filter.take(); 828 | if let Some(filter) = filter { 829 | let mut filter_passed = false; 830 | while !filter_passed { 831 | chunk_index = self.pop_chunk_idx(); 832 | if let Some(chunk_index) = &chunk_index { 833 | let next_chnk_idx = self.see_chunk_idx(); 834 | let column_names: Vec<_> = filter 835 | .schema_ref() 836 | .fields() 837 | .iter() 838 | .map(|f| f.name().to_owned()) 839 | .collect(); 840 | let zarr_chunk = Arc::get_mut(&mut self.zarr_store) 841 | .expect("Zarr store pointer unexpectedly not unique") 842 | .get_chunk(column_names, chunk_index.clone(), true, next_chnk_idx) 843 | .await?; 844 | filter_passed = zarr_chunk.check_filter(&filter)?; 845 | filter_zarr_chunk = Some(zarr_chunk); 846 | } else { 847 | filter_passed = true; 848 | } 849 | } 850 | self.filter = Some(filter); 851 | } else { 852 | chunk_index = self.pop_chunk_idx(); 853 | } 854 | 855 | if let Some(chunk_index) = chunk_index { 856 | let mut zarr_chunk; 857 | if self.filter.is_some() { 858 | let filter_zarr_chunk = filter_zarr_chunk.expect("Filter zarr chunk missing."); 859 | let schema = self 860 | .schema_without_filter_cols 861 | .as_ref() 862 | .expect("Schema without filter columns is missing."); 863 | 864 | let column_names: Vec<_> = schema 865 | .fields() 866 | .iter() 867 | .map(|f| f.name().to_owned()) 868 | .collect(); 869 | zarr_chunk = Arc::get_mut(&mut self.zarr_store) 870 | .expect("Zarr store pointer unexpectedly not unique") 871 | .get_chunk(column_names, chunk_index, false, None) 872 | .await?; 873 | zarr_chunk.combine(filter_zarr_chunk); 874 | } else { 875 | let next_chnk_idx = self.see_chunk_idx(); 876 | let column_names: Vec<_> = self 877 | .projected_schema_ref 878 | .fields() 879 | .iter() 880 | .map(|f| f.name().to_owned()) 881 | .collect(); 882 | 883 | zarr_chunk = Arc::get_mut(&mut self.zarr_store) 884 | .expect("Zarr store pointer unexpectedly not unique") 885 | .get_chunk(column_names, chunk_index, true, next_chnk_idx) 886 | .await?; 887 | } 888 | 889 | let record_batch = zarr_chunk.into_record_batch(&self.projected_schema_ref)?; 890 | Ok(Some(record_batch)) 891 | } else { 892 | Ok(None) 893 | } 894 | } 895 | 896 | /// Convert this into a `ZarrRecordBatchStream`, using the `async-stream` crate to handle the 897 | /// low-level specifics of stream polling. 898 | fn into_stream(mut self) -> ZarrRecordBatchStream { 899 | let schema = self.projected_schema_ref.clone(); 900 | let stream = Box::pin(try_stream! { 901 | while let Some(batch) = self.next_chunk().await? { 902 | yield batch; 903 | } 904 | }); 905 | ZarrRecordBatchStream { stream, schema } 906 | } 907 | 908 | fn pop_chunk_idx(&mut self) -> Option> { 909 | self.chunk_indices.pop_front() 910 | } 911 | 912 | fn see_chunk_idx(&self) -> Option> { 913 | self.chunk_indices.front().cloned() 914 | } 915 | 916 | /// adds a filter to avoid reading whole chunks if no values 917 | /// in the corresponding arrays pass the check.
this is not to 918 | /// filter out values within a chunk, we rely on datafusion's 919 | /// default filtering for that. basically this here is to handle 920 | /// filter pushdowns. 921 | fn with_filter(mut self, filter: ZarrChunkFilter) -> ZarrQueryResult { 922 | // because we'll need to read the filter data first, evaluate 923 | // the filter, then read the data for the main query, we want 924 | // to re-use the filter data if it's also requested in the 925 | // query, so here we build the schema for the columns that 926 | // are requested in the query, but not in the filter predicate. 927 | let fields: Vec<_> = self 928 | .projected_schema_ref 929 | .fields() 930 | .iter() 931 | .filter(|f| filter.schema_ref().index_of(f.name()).is_err()) 932 | .cloned() 933 | .collect(); 934 | let schema = Schema::new(fields); 935 | self.schema_without_filter_cols = Some(Arc::new(schema)); 936 | 937 | // set the filter on the inner stream. 938 | self.filter = Some(filter); 939 | 940 | Ok(self) 941 | } 942 | } 943 | 944 | /// An async stream of record batches read from the Zarr store. 945 | /// 946 | /// This implementation is modeled to be used with the DataFusion [`RecordBatchStream`] trait. 947 | /// 948 | /// [`RecordBatchStream`]: https://docs.rs/datafusion/latest/datafusion/execution/trait.RecordBatchStream.html 949 | pub struct ZarrRecordBatchStream { 950 | stream: BoxStream<'static, Result>, 951 | schema: SchemaRef, 952 | } 953 | 954 | impl ZarrRecordBatchStream { 955 | /// Create a new ZarrRecordBatchStream. 956 | pub async fn try_new( 957 | store: Arc, 958 | schema_ref: SchemaRef, 959 | prefix: Option, 960 | projection: Option>, 961 | n_partitions: usize, 962 | partition: usize, 963 | filter: Option, 964 | ) -> ZarrQueryResult { 965 | let mut inner = ZarrRecordBatchStreamInner::new( 966 | store, 967 | schema_ref, 968 | prefix, 969 | projection, 970 | n_partitions, 971 | partition, 972 | ) 973 | .await?; 974 | 975 | if let Some(filter) = filter { 976 | inner = inner.with_filter(filter)?; 977 | } 978 | 979 | Ok(Self { 980 | schema: inner.projected_schema_ref.clone(), 981 | stream: inner.into_stream().stream, 982 | }) 983 | } 984 | 985 | /// A reference to the schema of the record batches produced by this stream. 986 | pub fn schema_ref(&self) -> &SchemaRef { 987 | &self.schema 988 | } 989 | 990 | /// The schema of the record batches produced by this stream. 
991 | pub fn schema(&self) -> SchemaRef { 992 | self.schema.clone() 993 | } 994 | } 995 | 996 | impl Stream for ZarrRecordBatchStream { 997 | type Item = Result; 998 | 999 | fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 1000 | Pin::new(&mut self.stream).poll_next(cx) 1001 | } 1002 | } 1003 | 1004 | #[cfg(test)] 1005 | mod zarr_stream_tests { 1006 | use datafusion::logical_expr::Operator; 1007 | use datafusion::physical_expr::expressions::{col, lit}; 1008 | use datafusion::physical_plan::expressions::binary; 1009 | use futures_util::TryStreamExt; 1010 | 1011 | use super::*; 1012 | use crate::test_utils::{ 1013 | extract_col, get_local_zarr_store, get_local_zarr_store_mix_dims, validate_names_and_types, 1014 | validate_primitive_column, 1015 | }; 1016 | 1017 | #[tokio::test] 1018 | async fn read_data_test() { 1019 | let (wrapper, schema) = get_local_zarr_store(true, 0.0, "lat_lon_data").await; 1020 | let store = wrapper.get_store(); 1021 | 1022 | let stream = ZarrRecordBatchStream::try_new(store, schema, None, None, 1, 0, None) 1023 | .await 1024 | .unwrap(); 1025 | let records: Vec<_> = stream.try_collect().await.unwrap(); 1026 | 1027 | let target_types = HashMap::from([ 1028 | ("lat".to_string(), DataType::Float64), 1029 | ("lon".to_string(), DataType::Float64), 1030 | ("data".to_string(), DataType::Float64), 1031 | ]); 1032 | validate_names_and_types(&target_types, &records[0]); 1033 | assert_eq!(records.len(), 9); 1034 | 1035 | // the top left chunk, full 3x3 1036 | validate_primitive_column::( 1037 | "lat", 1038 | &records[0], 1039 | &[35., 35., 35., 36., 36., 36., 37., 37., 37.], 1040 | ); 1041 | validate_primitive_column::( 1042 | "lon", 1043 | &records[0], 1044 | &[ 1045 | -120.0, -119.0, -118.0, -120.0, -119.0, -118.0, -120.0, -119.0, -118.0, 1046 | ], 1047 | ); 1048 | validate_primitive_column::( 1049 | "data", 1050 | &records[0], 1051 | &[0.0, 1.0, 2.0, 8.0, 9.0, 10.0, 16.0, 17.0, 18.0], 1052 | ); 1053 | 1054 | // the top right chunk, 3 x 2 1055 | validate_primitive_column::( 1056 | "lat", 1057 | &records[2], 1058 | &[35., 35., 36., 36., 37., 37.], 1059 | ); 1060 | validate_primitive_column::( 1061 | "lon", 1062 | &records[2], 1063 | &[-114.0, -113.0, -114.0, -113.0, -114.0, -113.0], 1064 | ); 1065 | validate_primitive_column::( 1066 | "data", 1067 | &records[2], 1068 | &[6.0, 7.0, 14.0, 15.0, 22.0, 23.0], 1069 | ); 1070 | 1071 | // the bottom right chunk, 2 x 2 1072 | validate_primitive_column::( 1073 | "lat", 1074 | &records[8], 1075 | &[41.0, 41.0, 42.0, 42.0], 1076 | ); 1077 | validate_primitive_column::( 1078 | "lon", 1079 | &records[8], 1080 | &[-114.0, -113.0, -114.0, -113.0], 1081 | ); 1082 | validate_primitive_column::( 1083 | "data", 1084 | &records[8], 1085 | &[54.0, 55.0, 62.0, 63.0], 1086 | ); 1087 | } 1088 | 1089 | #[tokio::test] 1090 | async fn filter_test() { 1091 | let (wrapper, schema) = get_local_zarr_store(true, 0.0, "lat_lon_data_with_filter").await; 1092 | let store = wrapper.get_store(); 1093 | 1094 | let expr = binary( 1095 | binary( 1096 | col("lat", &schema).unwrap(), 1097 | Operator::Lt, 1098 | lit(38.1), 1099 | &schema, 1100 | ) 1101 | .unwrap(), 1102 | Operator::And, 1103 | binary( 1104 | col("lon", &schema).unwrap(), 1105 | Operator::Gt, 1106 | lit(-116.9), 1107 | &schema, 1108 | ) 1109 | .unwrap(), 1110 | &schema, 1111 | ) 1112 | .unwrap(); 1113 | let filter = Some(ZarrChunkFilter::new(&expr, schema.clone()).unwrap()); 1114 | 1115 | let stream = ZarrRecordBatchStream::try_new(store, schema, None, None, 1, 0, filter) 1116 | 
.await 1117 | .unwrap(); 1118 | let records: Vec<_> = stream.try_collect().await.unwrap(); 1119 | 1120 | let target_types = HashMap::from([ 1121 | ("lat".to_string(), DataType::Float64), 1122 | ("lon".to_string(), DataType::Float64), 1123 | ("data".to_string(), DataType::Float64), 1124 | ]); 1125 | validate_names_and_types(&target_types, &records[0]); 1126 | 1127 | // this tests for the filter push down, which doesn't completely 1128 | // filter out the results; it only drops chunks of data where 1129 | // not a single "row" passes the filter, so the condition we 1130 | // are checking lines up with the data in the chunks, and is 1131 | // a bit different from the WHERE clause. 1132 | assert_eq!(records.len(), 4); 1133 | for batch in records { 1134 | let lat_values = extract_col::("lat", &batch); 1135 | let lon_values = extract_col::("lon", &batch); 1136 | assert!(lat_values 1137 | .iter() 1138 | .zip(lon_values.iter()) 1139 | .all(|(lat, lon)| *lat < 41.0 && *lon > -118.0)); 1140 | } 1141 | } 1142 | 1143 | #[tokio::test] 1144 | async fn dimension_tests() { 1145 | // this store will have 2d lat coordinates and 1d lon coordinates. 1146 | // that should effectively give the same result as 1d and 1d. 1147 | let (wrapper, schema) = get_local_zarr_store_mix_dims(0.0, "lat_lon_mixed_dims_data").await; 1148 | let store = wrapper.get_store(); 1149 | 1150 | let stream = ZarrRecordBatchStream::try_new(store, schema, None, None, 1, 0, None) 1151 | .await 1152 | .unwrap(); 1153 | let records: Vec<_> = stream.try_collect().await.unwrap(); 1154 | 1155 | let target_types = HashMap::from([ 1156 | ("lat".to_string(), DataType::Float64), 1157 | ("lon".to_string(), DataType::Float64), 1158 | ("data".to_string(), DataType::Float64), 1159 | ]); 1160 | validate_names_and_types(&target_types, &records[0]); 1161 | assert_eq!(records.len(), 9); 1162 | 1163 | // the top left chunk, full 3x3 1164 | validate_primitive_column::( 1165 | "lat", 1166 | &records[0], 1167 | &[35., 35., 35., 36., 36., 36., 37., 37., 37.], 1168 | ); 1169 | validate_primitive_column::( 1170 | "lon", 1171 | &records[0], 1172 | &[ 1173 | -120.0, -119.0, -118.0, -120.0, -119.0, -118.0, -120.0, -119.0, -118.0, 1174 | ], 1175 | ); 1176 | validate_primitive_column::( 1177 | "data", 1178 | &records[0], 1179 | &[0.0, 1.0, 2.0, 8.0, 9.0, 10.0, 16.0, 17.0, 18.0], 1180 | ); 1181 | } 1182 | 1183 | #[tokio::test] 1184 | async fn read_missing_chunks_test() { 1185 | let fillvalue = 1234.0; 1186 | let (wrapper, schema) = get_local_zarr_store(false, fillvalue, "lat_lon_empty_data").await; 1187 | let store = wrapper.get_store(); 1188 | 1189 | let stream = ZarrRecordBatchStream::try_new(store, schema, None, None, 1, 0, None) 1190 | .await 1191 | .unwrap(); 1192 | let records: Vec<_> = stream.try_collect().await.unwrap(); 1193 | 1194 | let target_types = HashMap::from([ 1195 | ("lat".to_string(), DataType::Float64), 1196 | ("lon".to_string(), DataType::Float64), 1197 | ("data".to_string(), DataType::Float64), 1198 | ]); 1199 | validate_names_and_types(&target_types, &records[0]); 1200 | assert_eq!(records.len(), 9); 1201 | 1202 | // the top left chunk, full 3x3, but "data" is missing.
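// (these values come from the fill value handling earlier in this file, where
// ArrayBytes::new_fill_value is used when no stored bytes are available for a chunk.)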
1203 | validate_primitive_column::( 1204 | "lat", 1205 | &records[0], 1206 | &[35., 35., 35., 36., 36., 36., 37., 37., 37.], 1207 | ); 1208 | validate_primitive_column::( 1209 | "lon", 1210 | &records[0], 1211 | &[ 1212 | -120.0, -119.0, -118.0, -120.0, -119.0, -118.0, -120.0, -119.0, -118.0, 1213 | ], 1214 | ); 1215 | validate_primitive_column::("data", &records[0], &[fillvalue; 9]); 1216 | } 1217 | 1218 | #[tokio::test] 1219 | async fn read_with_partition_test() { 1220 | let (wrapper, schema) = 1221 | get_local_zarr_store(true, 0.0, "lat_lon_data_with_partition").await; 1222 | let store = wrapper.get_store(); 1223 | 1224 | let target_types = HashMap::from([ 1225 | ("lat".to_string(), DataType::Float64), 1226 | ("lon".to_string(), DataType::Float64), 1227 | ("data".to_string(), DataType::Float64), 1228 | ]); 1229 | 1230 | let stream = 1231 | ZarrRecordBatchStream::try_new(store.clone(), schema.clone(), None, None, 2, 0, None) 1232 | .await 1233 | .unwrap(); 1234 | let records: Vec<_> = stream.try_collect().await.unwrap(); 1235 | validate_names_and_types(&target_types, &records[0]); 1236 | assert_eq!(records.len(), 5); 1237 | 1238 | let stream = ZarrRecordBatchStream::try_new(store, schema, None, None, 2, 1, None) 1239 | .await 1240 | .unwrap(); 1241 | let records: Vec<_> = stream.try_collect().await.unwrap(); 1242 | validate_names_and_types(&target_types, &records[0]); 1243 | assert_eq!(records.len(), 4); 1244 | 1245 | // the full data has 3x3 chunks, the first partition would 1246 | // read the first 5, the second one the last 4, so the first 1247 | // chunk of the second stream would effectively be the middle 1248 | // right chunk of the full data. 1249 | validate_primitive_column::( 1250 | "lat", 1251 | &records[0], 1252 | &[38., 38., 39., 39., 40., 40.], 1253 | ); 1254 | validate_primitive_column::( 1255 | "lon", 1256 | &records[0], 1257 | &[-114.0, -113.0, -114.0, -113.0, -114.0, -113.0], 1258 | ); 1259 | validate_primitive_column::( 1260 | "data", 1261 | &records[0], 1262 | &[30.0, 31.0, 38.0, 39.0, 46.0, 47.0], 1263 | ); 1264 | } 1265 | 1266 | #[tokio::test] 1267 | async fn read_too_many_partitions_test() { 1268 | let (wrapper, schema) = 1269 | get_local_zarr_store(true, 0.0, "lat_lon_data_too_many_partition").await; 1270 | let store = wrapper.get_store(); 1271 | 1272 | // there are only 9 chunks, asking for 20 partitions, so each partition up to 1273 | // the 9th partition should have one batch in it; after that there should be 1274 | // no data returned by the streams.
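// (from the partitioning logic in ZarrRecordBatchStreamInner::new: chunks_per_partitions
// = ceil(9 / 20) = 1, so partition p covers chunk index p when p < 9 and an empty
// range otherwise, which is what the assertions below check.)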
1275 | let stream = 1276 | ZarrRecordBatchStream::try_new(store.clone(), schema.clone(), None, None, 20, 0, None) 1277 | .await 1278 | .unwrap(); 1279 | let records: Vec<_> = stream.try_collect().await.unwrap(); 1280 | assert_eq!(records.len(), 1); 1281 | 1282 | let stream = 1283 | ZarrRecordBatchStream::try_new(store.clone(), schema.clone(), None, None, 20, 8, None) 1284 | .await 1285 | .unwrap(); 1286 | let records: Vec<_> = stream.try_collect().await.unwrap(); 1287 | assert_eq!(records.len(), 1); 1288 | 1289 | let stream = 1290 | ZarrRecordBatchStream::try_new(store.clone(), schema.clone(), None, None, 20, 10, None) 1291 | .await 1292 | .unwrap(); 1293 | let records: Vec<_> = stream.try_collect().await.unwrap(); 1294 | assert_eq!(records.len(), 0); 1295 | 1296 | let stream = ZarrRecordBatchStream::try_new(store, schema, None, None, 20, 19, None) 1297 | .await 1298 | .unwrap(); 1299 | let records: Vec<_> = stream.try_collect().await.unwrap(); 1300 | assert_eq!(records.len(), 0); 1301 | } 1302 | } 1303 | --------------------------------------------------------------------------------