├── .idea
│   ├── .gitignore
│   ├── aws.xml
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── other.xml
│   ├── vcs.xml
│   └── xarray-async.iml
├── README.md
├── main.py
├── requirements.txt
└── src
    ├── __init__.py
    ├── fsspec
    │   ├── __init__.py
    │   └── mapping
    │       ├── __init__.py
    │       └── mapper.py
    ├── xarray
    │   ├── backends
    │   │   └── zarr.py
    │   ├── conventions.py
    │   ├── core
    │   │   └── variable.py
    │   └── dataset.py
    └── zarr
        ├── __init__.py
        ├── convenience.py
        ├── core.py
        ├── indexing.py
        └── storage.py

--------------------------------------------------------------------------------
README.md:
--------------------------------------------------------------------------------
# MVP for Async Zarr with Xarray
## Why?
Opening zarr stores through fsspec is already well optimized, because fsspec mappers delegate all of their calls to a background thread that runs them asynchronously, wrapped behind a synchronous interface.

A challenge remains, however, in trying to open multiple stores from within an event loop: a user would need to spin up a thread pool and open/retrieve data across multiple threads. A typical use case is an ASGI application (such as FastAPI) that needs to serve requests from many concurrent virtual users.

This solution therefore offers a way to access stores with zero context switching between threads, making opening and reading zarr stores completely non-blocking while waiting on IO.
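For contrast, the status-quo workaround is to push each blocking `xr.open_zarr` call onto a thread pool. A minimal sketch of that pattern (the store URLs are hypothetical):

```python
import asyncio

import xarray as xr


async def open_stores_with_threads(paths):
    # Each xr.open_zarr call blocks its thread, so every store needs a
    # worker thread to avoid stalling the event loop.
    loop = asyncio.get_running_loop()
    return await asyncio.gather(
        *[loop.run_in_executor(None, xr.open_zarr, path) for path in paths]
    )

# datasets = asyncio.run(open_stores_with_threads(["s3://bucket/a.zarr", "s3://bucket/b.zarr"]))
```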
## What was changed?
- Introduced a new fsspec mapper that has async methods, built on `fsspec.asyn.AsyncFileSystem`
- Changed zarr reading to be done with async functions, in a very brittle monkeypatch way
- Added a backend entrypoint for xarray which reads the zarr store asynchronously
- Added async `_isel` and `_sel` methods to `xarray.Dataset` to allow non-blocking data filtering

## Notes:
The purpose of this project is to spur discussion about how to let datasets be accessed with a minimal number of context switches in an async framework.
There is already async functionality in the fsspec project, via the `getitems` method on mapper objects, but this project takes it a step further and exposes an async API through the entire fsspec-zarr-xarray chain (see the comparison sketch below).

This intentionally does not use `dask`, because I'm not comfortable enough with dask to extend the needed `dask.array` components to make this project single-threaded.
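To make that distinction concrete — a sketch, with a hypothetical store path: the stock mapper's `getitems` already batches fetches asynchronously, but the call itself blocks the calling thread, while `AsyncFSMap.getitems` is awaited on the caller's own event loop:

```python
import s3fs

from src.fsspec.mapping.mapper import AsyncFSMap

# Stock fsspec: async batching happens on a hidden background loop,
# but this call blocks until it finishes.
sync_map = s3fs.S3FileSystem().get_mapper("s3://bucket/store.zarr")
meta = sync_map.getitems([".zmetadata"])


# This project: the same batched fetch, non-blocking for the caller.
async def fetch_meta():
    fs = s3fs.S3FileSystem(asynchronous=True)
    amap = AsyncFSMap("s3://bucket/store.zarr", fs)
    return await amap.getitems([".zmetadata"])
```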
11 | """ 12 | import asyncio 13 | 14 | import s3fs 15 | import xarray as xr 16 | 17 | from src.fsspec.mapping.mapper import AsyncFSMap 18 | from src.xarray.backends.zarr import AsyncZarrBackendEntrypint 19 | 20 | entry_point = AsyncZarrBackendEntrypint() 21 | 22 | 23 | async def get_ds(path: str, fs) -> xr.Dataset: 24 | mapper = AsyncFSMap(path, fs) 25 | return await entry_point.open_dataset(mapper) 26 | 27 | 28 | async def main(): 29 | fs = s3fs.S3FileSystem(asynchronous=True) 30 | ds = await get_ds("s3://", fs) 31 | return await ds._sel(lat=39.5, lon=-104.99, method="nearest") 32 | 33 | if __name__ == "__main__": 34 | ds = asyncio.run(main()) 35 | print(ds) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | xarray==2022.9.0 2 | zarr==2.13.0 3 | s3fs==2021.10.0 4 | fsspec==2021.10.0 -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeliashi/xarray-async/86ebe27adc5151f761414c4cf58165201a13186d/src/__init__.py -------------------------------------------------------------------------------- /src/fsspec/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeliashi/xarray-async/86ebe27adc5151f761414c4cf58165201a13186d/src/fsspec/__init__.py -------------------------------------------------------------------------------- /src/fsspec/mapping/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeliashi/xarray-async/86ebe27adc5151f761414c4cf58165201a13186d/src/fsspec/mapping/__init__.py -------------------------------------------------------------------------------- /src/fsspec/mapping/mapper.py: -------------------------------------------------------------------------------- 1 | from fsspec.asyn import AsyncFileSystem 2 | from fsspec.mapping import FSMap, maybe_convert 3 | 4 | 5 | class AsyncFSMap(FSMap): 6 | def __init__(self, root, fs, check=False, create=False, missing_exceptions=None): 7 | assert isinstance(fs, AsyncFileSystem) 8 | super().__init__(root, fs, check, create, missing_exceptions) 9 | 10 | async def clear(self): 11 | try: 12 | await self.fs._rm(self.root, True) 13 | await self.fs._mkdir(self.root) 14 | except Exception: 15 | pass 16 | 17 | async def getitems(self, keys, on_error="raise"): 18 | keys2 = [self._key_to_str(k) for k in keys] 19 | oe = on_error if on_error == "raise" else "return" 20 | try: 21 | out = await self.fs._cat(keys2, on_error=oe) 22 | if isinstance(out, bytes): 23 | out = {keys2[0]: out} 24 | except self.missing_exceptions as e: 25 | raise KeyError from e 26 | out = { 27 | k: (KeyError() if isinstance(v, self.missing_exceptions) else v) 28 | for k, v in out.items() 29 | } 30 | return { 31 | key: out[k2] 32 | for key, k2 in zip(keys, keys2) 33 | if on_error == "return" or not isinstance(out[k2], BaseException) 34 | } 35 | 36 | async def setitems(self, values_dict): 37 | values = {self._key_to_str(k): maybe_convert(v) for k, v in values_dict.items()} 38 | await self.fs._pipe(values) 39 | 40 | async def delitems(self, keys): 41 | """Remove multiple keys from the store""" 42 | await self.fs._rm([self._key_to_str(k) for k in keys]) 43 | 44 | async def __getitem__(self, key, default=None): 45 | """Retrieve data""" 46 | k = 
--------------------------------------------------------------------------------
requirements.txt:
--------------------------------------------------------------------------------
xarray==2022.9.0
zarr==2.13.0
s3fs==2021.10.0
fsspec==2021.10.0
--------------------------------------------------------------------------------
src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeliashi/xarray-async/86ebe27adc5151f761414c4cf58165201a13186d/src/__init__.py
--------------------------------------------------------------------------------
src/fsspec/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeliashi/xarray-async/86ebe27adc5151f761414c4cf58165201a13186d/src/fsspec/__init__.py
--------------------------------------------------------------------------------
src/fsspec/mapping/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeliashi/xarray-async/86ebe27adc5151f761414c4cf58165201a13186d/src/fsspec/mapping/__init__.py
--------------------------------------------------------------------------------
src/fsspec/mapping/mapper.py:
--------------------------------------------------------------------------------
from fsspec.asyn import AsyncFileSystem
from fsspec.mapping import FSMap, maybe_convert


class AsyncFSMap(FSMap):
    """FSMap whose methods are coroutines driven by an AsyncFileSystem."""

    def __init__(self, root, fs, check=False, create=False, missing_exceptions=None):
        assert isinstance(fs, AsyncFileSystem)
        super().__init__(root, fs, check, create, missing_exceptions)

    async def clear(self):
        try:
            await self.fs._rm(self.root, recursive=True)
            await self.fs._mkdir(self.root)
        except Exception:
            pass

    async def getitems(self, keys, on_error="raise"):
        keys2 = [self._key_to_str(k) for k in keys]
        oe = on_error if on_error == "raise" else "return"
        try:
            out = await self.fs._cat(keys2, on_error=oe)
            if isinstance(out, bytes):
                out = {keys2[0]: out}
        except self.missing_exceptions as e:
            raise KeyError from e
        out = {
            k: (KeyError() if isinstance(v, self.missing_exceptions) else v)
            for k, v in out.items()
        }
        return {
            key: out[k2]
            for key, k2 in zip(keys, keys2)
            if on_error == "return" or not isinstance(out[k2], BaseException)
        }

    async def setitems(self, values_dict):
        values = {self._key_to_str(k): maybe_convert(v) for k, v in values_dict.items()}
        await self.fs._pipe(values)

    async def delitems(self, keys):
        """Remove multiple keys from the store"""
        await self.fs._rm([self._key_to_str(k) for k in keys])

    async def __getitem__(self, key, default=None):
        """Retrieve data"""
        k = self._key_to_str(key)
        try:
            result = await self.fs._cat(k)
        except self.missing_exceptions:
            if default is not None:
                return default
            raise KeyError(key)
        return result

    async def pop(self, key, default=None):
        result = await self.__getitem__(key, default)
        try:
            # `del self[key]` would not await the async __delitem__,
            # so call it explicitly.
            await self.__delitem__(key)
        except KeyError:
            pass
        return result

    async def __setitem__(self, key, value):
        """Store value in key"""
        key = self._key_to_str(key)
        await self.fs._makedirs(self.fs._parent(key), exist_ok=True)
        await self.fs._pipe_file(key, maybe_convert(value))

    async def __iter__(self):
        return (self._str_to_key(x) for x in await self.fs._find(self.root))

    async def __len__(self):
        return len(await self.fs._find(self.root))

    async def __delitem__(self, key):
        """Remove key"""
        try:
            await self.fs._rm(self._key_to_str(key))
        except Exception:
            raise KeyError(key)

    async def __contains__(self, key):
        """Does key exist in mapping?"""
        path = self._key_to_str(key)
        return await self.fs._exists(path) and await self.fs._isfile(path)
--------------------------------------------------------------------------------
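One caveat with async dunder methods, worth noting for anyone extending this (my observation, not documented project behavior): Python's operator syntax never awaits, so some of these methods only work when called explicitly:

```python
async def demo(amap):
    data = await amap["foo/.zarray"]     # hypothetical key; works because the
                                         # subscription returns a coroutine to await
    n = await amap.__len__()             # works, but len(amap) would fail on the coroutine
    ok = await amap.__contains__("foo")  # "foo" in amap would coerce the un-awaited
                                         # coroutine to bool and always be True
```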
--------------------------------------------------------------------------------
src/xarray/backends/zarr.py:
--------------------------------------------------------------------------------
import asyncio
import contextlib
import os
import sys

import numpy as np
from xarray import conventions
from xarray.backends.common import _decode_variable_name, _normalize_path
from xarray.backends.store import StoreBackendEntrypoint
from xarray.backends.zarr import (
    DIMENSION_KEY,
    ZarrArrayWrapper,
    ZarrBackendEntrypoint,
    ZarrStore,
    _get_zarr_dims_and_attrs,
)
from xarray.core import indexing
from xarray.core.utils import FrozenDict

from ..conventions import decode_cf_variable
from ..core.variable import Variable
from ..dataset import Dataset
from ...zarr.convenience import open_consolidated

# Monkeypatch: every later caller of xarray.conventions.decode_cf_variable
# gets the async-aware version from this package.
sys.modules["xarray.conventions"].decode_cf_variable = decode_cf_variable


class AsyncArrayWrapper(ZarrArrayWrapper):
    async def __array__(self, dtype=None):
        key = indexing.BasicIndexer((slice(None),) * self.ndim)
        return np.asarray(await self[key], dtype=dtype)

    async def __getitem__(self, key):
        array = self.get_array()
        if isinstance(key, indexing.BasicIndexer):
            return await array[key.tuple]
        elif isinstance(key, indexing.VectorizedIndexer):
            return await array.vindex[
                indexing._arrayize_vectorized_indexer(key, self.shape).tuple
            ]
        else:
            assert isinstance(key, indexing.OuterIndexer)
            return await array.oindex[key.tuple]


class AsyncStore(ZarrStore):
    @classmethod
    async def open_group(
        cls,
        store,
        mode="r",
        synchronizer=None,
        group=None,
        consolidated=False,
        consolidate_on_close=False,
        chunk_store=None,
        storage_options=None,
        append_dim=None,
        write_region=None,
        safe_chunks=True,
        stacklevel=2,
    ):
        if isinstance(store, os.PathLike):
            raise NotImplementedError("cannot do local storage zarr")

        open_kwargs = dict(
            mode=mode,
            synchronizer=synchronizer,
            path=group,
        )
        open_kwargs["storage_options"] = storage_options

        if chunk_store:
            open_kwargs["chunk_store"] = chunk_store

        if consolidated is None:
            # default: attempt the consolidated-metadata path
            try:
                zarr_group = await open_consolidated(store, **open_kwargs)
            except KeyError:
                raise NotImplementedError("not ready for non-consolidated stores")
        elif consolidated:
            # TODO: an option to pass the metadata_key keyword
            zarr_group = await open_consolidated(store, **open_kwargs)
        else:
            raise NotImplementedError("not ready for non-consolidated stores")
        return cls(
            zarr_group,
            mode,
            consolidate_on_close,
            append_dim,
            write_region,
            safe_chunks,
        )

    async def load(self):
        variables = FrozenDict(
            (_decode_variable_name(k), v)
            for k, v in (await self.get_variables()).items()
        )
        attributes = FrozenDict(self.get_attrs())
        return variables, attributes

    async def get_variables(self):
        return FrozenDict(
            await asyncio.gather(
                *[
                    self.open_store_variable_with_key(k, v, k)
                    for k, v in self.zarr_group.arrays()
                ]
            )
        )

    async def open_store_variable_with_key(self, name, zarr_array, label):
        return label, await self.open_store_variable(name, zarr_array)

    async def open_store_variable(self, name, zarr_array):
        data = AsyncArrayWrapper(name, self)
        try_nczarr = self._mode == "r"
        dimensions, attributes = _get_zarr_dims_and_attrs(
            zarr_array, DIMENSION_KEY, try_nczarr
        )
        attributes = dict(attributes)
        encoding = {
            "chunks": zarr_array.chunks,
            "preferred_chunks": dict(zip(dimensions, zarr_array.chunks)),
            "compressor": zarr_array.compressor,
            "filters": zarr_array.filters,
        }
        # _FillValue needs to be in attributes, not encoding, so it will get
        # picked up by decode_cf
        if zarr_array.fill_value is not None:
            attributes["_FillValue"] = zarr_array.fill_value

        variable = Variable(dimensions, data, attributes, encoding)
        await variable.maybe_preload()
        return variable


class AsyncStoreBackendEntrypoint(StoreBackendEntrypoint):
    async def open_dataset(
        self,
        store,
        *,
        mask_and_scale=True,
        decode_times=True,
        concat_characters=True,
        decode_coords=True,
        drop_variables=None,
        use_cftime=None,
        decode_timedelta=None,
    ):
        vars, attrs = await store.load()
        encoding = store.get_encoding()

        vars, attrs, coord_names = conventions.decode_cf_variables(
            vars,
            attrs,
            mask_and_scale=mask_and_scale,
            decode_times=decode_times,
            concat_characters=concat_characters,
            decode_coords=decode_coords,
            drop_variables=drop_variables,
            use_cftime=use_cftime,
            decode_timedelta=decode_timedelta,
        )

        ds = Dataset(vars, attrs=attrs)
        ds = ds.set_coords(coord_names.intersection(vars))
        ds.set_close(store.close)
        ds.encoding = encoding

        return ds


class AsyncZarrBackendEntrypint(ZarrBackendEntrypoint):
    async def open_dataset(
        self,
        filename_or_obj,
        mask_and_scale=True,
        decode_times=True,
        concat_characters=True,
        decode_coords=True,
        drop_variables=None,
        use_cftime=None,
        decode_timedelta=None,
        group=None,
        mode="r",
        synchronizer=None,
        consolidated=None,
        chunk_store=None,
        storage_options=None,
        stacklevel=3,
    ):
        filename_or_obj = _normalize_path(filename_or_obj)
        store = await AsyncStore.open_group(
            filename_or_obj,
            group=group,
            mode=mode,
            synchronizer=synchronizer,
            consolidated=consolidated,
            consolidate_on_close=False,
            chunk_store=chunk_store,
            storage_options=storage_options,
            stacklevel=stacklevel + 1,
        )

        store_entrypoint = AsyncStoreBackendEntrypoint()
        with close_on_error(store):
            ds = await store_entrypoint.open_dataset(
                store,
                mask_and_scale=mask_and_scale,
                decode_times=decode_times,
                concat_characters=concat_characters,
                decode_coords=decode_coords,
                drop_variables=drop_variables,
                use_cftime=use_cftime,
                decode_timedelta=decode_timedelta,
            )
        return ds


@contextlib.contextmanager
def close_on_error(f):
    """Context manager to ensure that a file opened by xarray is closed if an
    exception is raised before the user sees the file object.
    """
    try:
        yield
    except Exception:
        f.close()
        raise
--------------------------------------------------------------------------------
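The `sys.modules[...]` assignments here (and in `src/zarr/*`) are the "brittle monkeypatch" the README mentions — a standalone toy, unrelated to the project's real modules, showing the mechanism and its main hazard:

```python
import sys
import types

toy = types.ModuleType("toy")
toy.f = lambda: "original"
sys.modules["toy"] = toy

from toy import f           # binds the *current* function object
sys.modules["toy"].f = lambda: "patched"

import toy                  # attribute lookups through the module see the patch...
print(toy.f())              # -> "patched"
print(f())                  # ...but earlier `from` imports keep the old object -> "original"
```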
--------------------------------------------------------------------------------
src/xarray/conventions.py:
--------------------------------------------------------------------------------
from xarray import conventions

from .core.variable import Variable


def decode_cf_variable(
    name,
    var,
    concat_characters=True,
    mask_and_scale=True,
    decode_times=True,
    decode_endianness=True,
    stack_char_dim=True,
    use_cftime=None,
    decode_timedelta=None,
):
    # Ensure datetime-like Variables are passed through unmodified (GH 6453)
    if conventions._contains_datetime_like_objects(var):
        return var

    original_dtype = var.dtype

    if decode_timedelta is None:
        decode_timedelta = decode_times

    if concat_characters:
        if stack_char_dim:
            var = conventions.strings.CharacterArrayCoder().decode(var, name=name)
        var = conventions.strings.EncodedStringCoder().decode(var)

    if mask_and_scale:
        for coder in [
            conventions.variables.UnsignedIntegerCoder(),
            conventions.variables.CFMaskCoder(),
            conventions.variables.CFScaleOffsetCoder(),
        ]:
            var = coder.decode(var, name=name)

    if decode_timedelta:
        var = conventions.times.CFTimedeltaCoder().decode(var, name=name)
    if decode_times:
        var = conventions.times.CFDatetimeCoder(use_cftime=use_cftime).decode(
            var, name=name
        )

    dimensions, data, attributes, encoding = conventions.variables.unpack_for_decoding(
        var
    )
    # TODO(shoyer): convert everything below to use coders

    if decode_endianness and not data.dtype.isnative:
        # do this last, so it's only done if we didn't already unmask/scale
        data = conventions.NativeEndiannessArray(data)
        original_dtype = data.dtype

    encoding.setdefault("dtype", original_dtype)

    if "dtype" in attributes and attributes["dtype"] == "bool":
        del attributes["dtype"]
        data = conventions.BoolTypeArray(data)

    return Variable(dimensions, data, attributes, encoding=encoding)
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
src/xarray/core/variable.py:
--------------------------------------------------------------------------------
from asyncio import iscoroutinefunction
from typing import Any

import numpy as np
from xarray.core import variable


class Variable(variable.Variable):
    async def maybe_preload(self):
        # Eagerly load small arrays (typically coordinates) so that later
        # label lookups don't have to await storage again.
        if self.size < 1e5 and iscoroutinefunction(getattr(self._data, "__array__")):
            self._data = await self._data.__array__()

    @property
    def data(self) -> Any:
        return self._data

    @data.setter
    def data(self, data):
        data = variable.as_compatible_data(data)
        if data.shape != self.shape:
            raise ValueError(
                f"replacement data must match the Variable's shape. "
                f"replacement data has shape {data.shape}; Variable has shape {self.shape}"
            )
        self._data = data

    async def _isel(
        self,
        indexers: variable.Mapping[Any, Any] = None,
        missing_dims="raise",
        **indexers_kwargs: Any,
    ):
        indexers = variable.either_dict_or_kwargs(indexers, indexers_kwargs, "isel")
        indexers = variable.drop_dims_from_indexers(indexers, self.dims, missing_dims)

        key = tuple(indexers.get(dim, slice(None)) for dim in self.dims)
        return await self.__agetitem__(key)

    async def __agetitem__(self, key):
        """Return a new Variable object whose contents are consistent with
        getting the provided key from the underlying data.

        NB. __getitem__ and __setitem__ implement xarray-style indexing,
        where if keys are unlabeled arrays, we index the array orthogonally
        with them. If keys are labeled arrays (such as Variables), they are
        broadcast with our usual scheme and then the array is indexed with
        the broadcast key, like numpy's fancy indexing.

        If you really want to do indexing like `x[x > 0]`, manipulate the numpy
        array `x.values` directly.
        """
        dims, indexer, new_order = self._broadcast_indexes(key)
        data = await self._data[indexer]
        if new_order:
            data = np.moveaxis(data, range(len(new_order)), new_order)
        return self._finalize_indexing_result(dims, data)
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
src/xarray/dataset.py:
--------------------------------------------------------------------------------
import asyncio
from typing import Any, Hashable, Iterable, Mapping

from xarray import Dataset as XDs
from xarray.core.indexes import isel_indexes
from xarray.core.indexing import is_fancy_indexer, map_index_queries
from xarray.core.utils import drop_dims_from_indexers, either_dict_or_kwargs


class Dataset(XDs):
    async def _isel(
        self,
        indexers: Mapping[Any, Any] | None = None,
        drop: bool = False,
        missing_dims="raise",
        **indexers_kwargs: Any,
    ):
        indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel")
        if any(is_fancy_indexer(idx) for idx in indexers.values()):
            return self._isel_fancy(indexers, drop=drop, missing_dims=missing_dims)

        # Much faster algorithm for when all indexers are ints, slices, one-dimensional
        # lists, or zero or one-dimensional np.ndarray's
        indexers = drop_dims_from_indexers(indexers, self.dims, missing_dims)

        variables = {}
        dims: dict[Hashable, int] = {}
        coord_names = self._coord_names.copy()

        indexes, index_variables = isel_indexes(self.xindexes, indexers)

        async def place_var(name, var):
            if name in index_variables:
                var = index_variables[name]
            else:
                var_indexers = {k: v for k, v in indexers.items() if k in var.dims}
                if var_indexers:
                    if hasattr(var, "_isel"):
                        var = await var._isel(var_indexers)
                    else:
                        var = var.isel(var_indexers)
                    if drop and var.ndim == 0 and name in coord_names:
                        coord_names.remove(name)
                        return
            variables[name] = var
            dims.update(zip(var.dims, var.shape))

        # Unlike upstream xarray's sequential loop, the gather below fills
        # `variables` in completion order, so the original variable order is
        # not strictly preserved.
        await asyncio.gather(
            *[place_var(name, var) for name, var in self._variables.items()]
        )

        return self._construct_direct(
            variables=variables,
            coord_names=coord_names,
            dims=dims,
            attrs=self._attrs,
            indexes=indexes,
            encoding=self._encoding,
            close=self._close,
        )

    async def _sel(
        self,
        indexers: Mapping[Any, Any] = None,
        method: str = None,
        tolerance: int | float | Iterable[int | float] | None = None,
        drop: bool = False,
        **indexers_kwargs: Any,
    ):
        indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "sel")
        query_results = map_index_queries(
            self, indexers=indexers, method=method, tolerance=tolerance
        )

        if drop:
            no_scalar_variables = {}
            for k, v in query_results.variables.items():
                if v.dims:
                    no_scalar_variables[k] = v
                else:
                    if k in self._coord_names:
                        query_results.drop_coords.append(k)
            query_results.variables = no_scalar_variables

        result = await self._isel(indexers=query_results.dim_indexers, drop=drop)
        return result._overwrite_indexes(*query_results.as_tuple()[1:])
--------------------------------------------------------------------------------
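Because `_sel` and `_isel` are coroutines, independent selections can interleave on one event loop — a sketch of the fan-out pattern (coordinate names follow main.py):

```python
import asyncio


async def lookup_points(ds, points):
    # points: iterable of (lat, lon) pairs. The selections run concurrently,
    # so their chunk fetches overlap on a single thread.
    return await asyncio.gather(
        *[ds._sel(lat=lat, lon=lon, method="nearest") for lat, lon in points]
    )
```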
--------------------------------------------------------------------------------
src/zarr/convenience.py:
--------------------------------------------------------------------------------
import sys

from zarr.convenience import StoreLike, open
from zarr.storage import normalize_store_arg

from .core import Array
from .storage import ConsolidatedMetadataStore

sys.modules["zarr.core"].Array = Array
sys.modules["zarr.hierarchy"].Array = Array


async def open_consolidated(
    store: StoreLike, metadata_key=".zmetadata", mode="r+", **kwargs
):
    zarr_version = kwargs.get("zarr_version")
    store = normalize_store_arg(
        store,
        storage_options=kwargs.get("storage_options"),
        mode=mode,
        zarr_version=zarr_version,
    )
    if mode not in {"r", "r+"}:
        raise ValueError(
            "invalid mode, expected either 'r' or 'r+'; found {!r}".format(mode)
        )

    path = kwargs.pop("path", None)
    if store._store_version == 2:
        ConsolidatedStoreClass = ConsolidatedMetadataStore
    else:
        raise NotImplementedError()
        # assert_zarr_v3_api_available()
        # ConsolidatedStoreClass = ConsolidatedMetadataStoreV3
        # # default is to store within 'consolidated' group on v3
        # if not metadata_key.startswith("meta/root/"):
        #     metadata_key = "meta/root/consolidated/" + metadata_key

    # setup metadata store: __init__ cannot await, so the async part of
    # construction is exposed as a coroutine and awaited here
    meta_store = ConsolidatedStoreClass(store, metadata_key=metadata_key)
    await meta_store.init_coro

    # pass through; zarr's synchronous `open` is safe here because every
    # metadata key it reads is now served from memory
    chunk_store = kwargs.pop("chunk_store", None) or store
    return open(
        store=meta_store, chunk_store=chunk_store, mode=mode, path=path, **kwargs
    )
--------------------------------------------------------------------------------
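The `init_coro` dance above is a general workaround for `__init__` not being awaitable — a minimal standalone sketch of the two-phase pattern (toy names, not project code):

```python
class AsyncInit:
    def __init__(self, source):
        # __init__ cannot await, so stash the setup coroutine for the caller.
        self.init_coro = self._ainit(source)

    async def _ainit(self, source):
        # `source` is assumed to be an async callable providing the payload.
        self.data = await source()


async def make(source):
    obj = AsyncInit(source)
    await obj.init_coro  # the caller completes construction
    return obj
```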
--------------------------------------------------------------------------------
src/zarr/core.py:
--------------------------------------------------------------------------------
import asyncio
import sys

import numpy as np
from numcodecs.compat import ensure_ndarray_like
from zarr.core import Array as ZA
from zarr.errors import err_too_many_indices
from zarr.indexing import (
    BasicIndexer,
    CoordinateIndexer,
    MaskIndexer,
    OrthogonalIndexer,
    check_fields,
    ensure_tuple,
    is_pure_fancy_indexing,
    pop_fields,
)
from zarr.util import check_array_shape

from .indexing import OIndex, VIndex

sys.modules["zarr.core"].OIndex = OIndex
sys.modules["zarr.core"].VIndex = VIndex


class Array(ZA):
    async def __array__(self, *args):
        a = await self[...]
        if args:
            a = a.astype(args[0])
        return a

    async def islice(self, start=None, end=None):
        if len(self.shape) == 0:
            # Same error as numpy
            raise TypeError("iteration over a 0-d array")
        if start is None:
            start = 0
        if end is None or end > self.shape[0]:
            end = self.shape[0]

        if not isinstance(start, int) or start < 0:
            raise ValueError("start must be a nonnegative integer")

        if not isinstance(end, int) or end < 0:
            raise ValueError("end must be a nonnegative integer")

        # Avoid repeatedly decompressing chunks by iterating over the chunks
        # in the first dimension.
        chunk_size = self.chunks[0]
        chunk = None
        for j in range(start, end):
            if j % chunk_size == 0:
                chunk = await self[j : j + chunk_size]
            # init chunk if we start offset of chunk borders
            elif chunk is None:
                chunk_start = j - j % chunk_size
                chunk_end = chunk_start + chunk_size
                chunk = await self[chunk_start:chunk_end]
            # `chunk` is already an in-memory numpy array; no await needed
            yield chunk[j % chunk_size]

    def __iter__(self):
        raise TypeError("synchronous iteration is not supported; use `async for`")

    def __aiter__(self):
        # islice() is an async generator, so return it directly;
        # awaiting it would raise a TypeError.
        return self.islice()

    async def __getitem__(self, selection):
        fields, pure_selection = pop_fields(selection)
        if is_pure_fancy_indexing(pure_selection, self.ndim):
            result = await self.vindex[selection]
        else:
            result = await self.get_basic_selection(pure_selection, fields=fields)
        return result

    async def get_basic_selection(self, selection=Ellipsis, out=None, fields=None):
        if not self._cache_metadata:
            self._load_metadata()
        check_fields(fields, self._dtype)
        if self._shape == ():
            return await self._get_basic_selection_zd(
                selection=selection, out=out, fields=fields
            )
        else:
            return await self._get_basic_selection_nd(
                selection=selection, out=out, fields=fields
            )

    async def _get_basic_selection_zd(self, selection, out=None, fields=None):
        selection = ensure_tuple(selection)
        if selection not in ((), (Ellipsis,)):
            err_too_many_indices(selection, ())

        try:
            # obtain encoded data for chunk
            ckey = self._chunk_key((0,))
            cdata = await self.chunk_store[ckey]

        except KeyError:
            # chunk not initialized
            chunk = np.zeros_like(self._meta_array, shape=(), dtype=self._dtype)
            if self._fill_value is not None:
                chunk.fill(self._fill_value)

        else:
            chunk = self._decode_chunk(cdata)

        # handle fields
        if fields:
            chunk = chunk[fields]

        # handle selection of the scalar value via empty tuple
        if out is None:
            out = chunk[selection]
        else:
            out[selection] = chunk[selection]

        return out

    async def _get_basic_selection_nd(self, selection, out=None, fields=None):
        indexer = BasicIndexer(selection, self)

        return await self._get_selection(indexer=indexer, out=out, fields=fields)

    async def _get_selection(self, indexer, out=None, fields=None):
        out_dtype = check_fields(fields, self._dtype)
        out_shape = indexer.shape
        if out is None:
            out = np.empty_like(
                self._meta_array, shape=out_shape, dtype=out_dtype, order=self._order
            )
        else:
            check_array_shape("out", out, out_shape)
        # fetch and decode all required chunks concurrently
        await asyncio.gather(
            *[
                self._chunk_getitem(
                    chunk_coords,
                    chunk_selection,
                    out,
                    out_selection,
                    drop_axes=indexer.drop_axes,
                    fields=fields,
                )
                for chunk_coords, chunk_selection, out_selection in indexer
            ]
        )

        if out.shape:
            return out
        else:
            return out[()]

    async def _chunk_getitem(
        self,
        chunk_coords,
        chunk_selection,
        out,
        out_selection,
        drop_axes=None,
        fields=None,
    ):
        out_is_ndarray = True
        try:
            out = ensure_ndarray_like(out)
        except TypeError:
            out_is_ndarray = False

        assert len(chunk_coords) == len(self._cdata_shape)
        ckey = self._chunk_key(chunk_coords)
        try:
            # obtain compressed data for chunk
            cdata = await self.chunk_store[ckey]
        except KeyError:
            # chunk not initialized
            if self._fill_value is not None:
                if fields:
                    fill_value = self._fill_value[fields]
                else:
                    fill_value = self._fill_value
                out[out_selection] = fill_value

        else:
            self._process_chunk(
                out,
                cdata,
                chunk_selection,
                drop_axes,
                out_is_ndarray,
                fields,
                out_selection,
            )

    async def get_orthogonal_selection(self, selection, out=None, fields=None):
        if not self._cache_metadata:
            self._load_metadata()
        check_fields(fields, self._dtype)
        indexer = OrthogonalIndexer(selection, self)
        return await self._get_selection(indexer=indexer, out=out, fields=fields)

    async def set_orthogonal_selection(self, selection, value, fields=None):
        raise NotImplementedError()

    async def get_coordinate_selection(self, selection, out=None, fields=None):
        if not self._cache_metadata:
            self._load_metadata()

        # check args
        check_fields(fields, self._dtype)

        # setup indexer
        indexer = CoordinateIndexer(selection, self)

        # handle output - need to flatten
        if out is not None:
            out = out.reshape(-1)

        out = await self._get_selection(indexer=indexer, out=out, fields=fields)

        # restore shape
        out = out.reshape(indexer.sel_shape)

        return out

    async def get_mask_selection(self, selection, out=None, fields=None):
        if not self._cache_metadata:
            self._load_metadata()

        # check args
        check_fields(fields, self._dtype)

        # setup indexer
        indexer = MaskIndexer(selection, self)

        return await self._get_selection(indexer=indexer, out=out, fields=fields)
--------------------------------------------------------------------------------
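A usage sketch for the chunk-aware iterator above (assuming `arr` is an instance of this patched `Array`):

```python
async def head(arr, n=10):
    # Pulls whole chunks and yields scalars from them, so a run of n
    # consecutive values decompresses each chunk only once.
    out = []
    async for value in arr.islice(0, n):
        out.append(value)
    return out
```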
--------------------------------------------------------------------------------
src/zarr/indexing.py:
--------------------------------------------------------------------------------
from zarr.errors import VindexInvalidSelectionError
from zarr.indexing import OIndex as ZO
from zarr.indexing import VIndex as ZV
from zarr.indexing import (
    ensure_tuple,
    is_coordinate_selection,
    is_mask_selection,
    pop_fields,
    replace_lists,
)


class OIndex(ZO):
    async def __getitem__(self, selection):
        fields, selection = pop_fields(selection)
        selection = ensure_tuple(selection)
        selection = replace_lists(selection)
        return await self.array.get_orthogonal_selection(selection, fields=fields)

    async def __setitem__(self, selection, value):
        fields, selection = pop_fields(selection)
        selection = ensure_tuple(selection)
        selection = replace_lists(selection)
        return await self.array.set_orthogonal_selection(
            selection, value, fields=fields
        )


class VIndex(ZV):
    async def __getitem__(self, selection):
        fields, selection = pop_fields(selection)
        selection = ensure_tuple(selection)
        selection = replace_lists(selection)
        if is_coordinate_selection(selection, self.array):
            return await self.array.get_coordinate_selection(selection, fields=fields)
        elif is_mask_selection(selection, self.array):
            return await self.array.get_mask_selection(selection, fields=fields)
        else:
            raise VindexInvalidSelectionError(selection)
--------------------------------------------------------------------------------
src/zarr/storage.py:
--------------------------------------------------------------------------------
from zarr.errors import MetadataError
from zarr.storage import ConsolidatedMetadataStore as zCMS
from zarr.storage import KVStore, Store, StoreLike
from zarr.util import json_loads


class ConsolidatedMetadataStore(zCMS):
    def __init__(self, store: StoreLike, metadata_key=".zmetadata"):
        # Deliberately skip the superclass __init__ (which reads the store
        # synchronously); the async half of construction is stashed as a
        # coroutine for the caller to await.
        self.init_coro = self.ainit(store, metadata_key)

    async def ainit(self, store: StoreLike, metadata_key=".zmetadata"):
        self.store = Store._ensure_store(store)

        # retrieve consolidated metadata
        meta = json_loads(await self.store[metadata_key])

        # check format of consolidated metadata
        consolidated_format = meta.get("zarr_consolidated_format", None)
        if consolidated_format != 1:
            raise MetadataError(
                "unsupported zarr consolidated metadata format: %s"
                % consolidated_format
            )

        # decode metadata: from here on, all metadata lookups are answered
        # from this in-memory KVStore
        self.meta_store: Store = KVStore(meta["metadata"])
--------------------------------------------------------------------------------