├── .idea
│   ├── .gitignore
│   ├── aws.xml
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── other.xml
│   ├── vcs.xml
│   └── xarray-async.iml
├── README.md
├── main.py
├── requirements.txt
└── src
    ├── __init__.py
    ├── fsspec
    │   ├── __init__.py
    │   └── mapping
    │       ├── __init__.py
    │       └── mapper.py
    ├── xarray
    │   ├── backends
    │   │   └── zarr.py
    │   ├── conventions.py
    │   ├── core
    │   │   └── variable.py
    │   └── dataset.py
    └── zarr
        ├── __init__.py
        ├── convenience.py
        ├── core.py
        ├── indexing.py
        └── storage.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MVP for Async Zarr with Xarray
2 | ## Why?
3 | Opening zarr stores using fsspec is already well optimized, because fsspec mappers utilize a background thread that performs all calls asynchronously; this is then wrapped in a synchronous interface.
4 |
5 | There remains a challenge, however, in opening multiple stores within an event loop: a user would need to spin up a threadpool and open/retrieve data across multiple threads. A typical use case is an ASGI application (like FastAPI) that needs to serve many concurrent users.
6 |
7 | This solution therefore offers a way to access stores with zero context switching between threads, making opening and reading zarr stores completely non-blocking while waiting on IO.
8 |
9 | ## What was changed?
10 | - Introduced a new fsspec mapper with async methods built on `fsspec.asyn.AsyncFileSystem`
11 | - Changed zarr reading to use async functions, via a very brittle monkeypatch
12 | - Added a backend entrypoint for xarray that reads the zarr store asynchronously
13 | - Added async `_isel` and `_sel` methods to `xarray.Dataset` to allow non-blocking data filtering
14 |
15 | ## Notes:
16 | The purpose of this project is to spur discussion about how to allow datasets to be accessed with minimal context switching in an async framework.
17 | There is already async functionality in the fsspec project via the `getitems` method on mapper objects, but this project takes it a step further and exposes an async API through the entire fsspec-zarr-xarray chain.
18 |
19 | This intentionally does not use `dask` because I'm not comfortable enough with dask to extend the needed `dask.array` components to make this project single-threaded.
20 |
21 | ## Example
22 | ```python
23 |
24 | import asyncio
25 |
26 | import s3fs
27 | import xarray as xr
28 |
29 | from src.fsspec.mapping.mapper import AsyncFSMap
30 | from src.xarray.backends.zarr import AsyncZarrBackendEntrypint
31 |
32 | entry_point = AsyncZarrBackendEntrypint()
33 |
34 |
35 | async def get_ds(path: str, fs) -> xr.Dataset:
36 | mapper = AsyncFSMap(path, fs)
37 | return await entry_point.open_dataset(mapper)
38 |
39 |
40 | # selectors matching the example in main.py
41 | coordinate_selectors = dict(lat=39.5, lon=-104.99, method="nearest")
42 |
43 |
44 | async def main():
45 |     fs = s3fs.S3FileSystem(asynchronous=True)
46 |     ds = await get_ds("s3://", fs)
47 |     return await ds._sel(**coordinate_selectors)
48 |
49 | if __name__ == "__main__":
50 |     ds = asyncio.run(main())
51 |     print(ds)
52 | ```
53 |
--------------------------------------------------------------------------------
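
The "Why?" section above mentions the threadpool workaround. For contrast, here is a minimal sketch of that conventional approach, using only `asyncio.to_thread` and the ordinary synchronous `xr.open_zarr` (function and variable names are illustrative):

```python
import asyncio

import xarray as xr


async def open_with_threads(paths: list[str]) -> list[xr.Dataset]:
    # each blocking open is shipped to a worker thread
    return await asyncio.gather(
        *[asyncio.to_thread(xr.open_zarr, p) for p in paths]
    )
```

Every open still occupies a thread for its full duration; the code in this repo instead keeps everything on the event loop.
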
/main.py:
--------------------------------------------------------------------------------
1 | """
2 | This MVP is designed to showcase how to read consolidated
3 | zarr stores with xarray on a single thread.
4 | This is not meant to be ready for use and should be used at your own risk.
5 | The point is to use the fsspec filesystems on a single thread rather than a separate one.
6 |
7 | Another important note: this doesn't utilize dask but rather the event loop of
8 | the Python runtime. The usefulness here is for a service that needs to open many
9 | different zarr stores concurrently without needing a threadpool, which is
10 | typical of ASGI applications.
11 | """
12 | import asyncio
13 |
14 | import s3fs
15 | import xarray as xr
16 |
17 | from src.fsspec.mapping.mapper import AsyncFSMap
18 | from src.xarray.backends.zarr import AsyncZarrBackendEntrypint
19 |
20 | entry_point = AsyncZarrBackendEntrypint()
21 |
22 |
23 | async def get_ds(path: str, fs) -> xr.Dataset:
24 | mapper = AsyncFSMap(path, fs)
25 | return await entry_point.open_dataset(mapper)
26 |
27 |
28 | async def main():
29 | fs = s3fs.S3FileSystem(asynchronous=True)
30 | ds = await get_ds("s3://", fs)
31 | return await ds._sel(lat=39.5, lon=-104.99, method="nearest")
32 |
33 | if __name__ == "__main__":
34 | ds = asyncio.run(main())
35 | print(ds)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | xarray==2022.9.0
2 | zarr==2.13.0
3 | s3fs==2021.10.0
4 | fsspec==2021.10.0
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeliashi/xarray-async/86ebe27adc5151f761414c4cf58165201a13186d/src/__init__.py
--------------------------------------------------------------------------------
/src/fsspec/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeliashi/xarray-async/86ebe27adc5151f761414c4cf58165201a13186d/src/fsspec/__init__.py
--------------------------------------------------------------------------------
/src/fsspec/mapping/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeliashi/xarray-async/86ebe27adc5151f761414c4cf58165201a13186d/src/fsspec/mapping/__init__.py
--------------------------------------------------------------------------------
/src/fsspec/mapping/mapper.py:
--------------------------------------------------------------------------------
1 | from fsspec.asyn import AsyncFileSystem
2 | from fsspec.mapping import FSMap, maybe_convert
3 |
4 |
5 | class AsyncFSMap(FSMap):
6 | def __init__(self, root, fs, check=False, create=False, missing_exceptions=None):
7 | assert isinstance(fs, AsyncFileSystem)
8 | super().__init__(root, fs, check, create, missing_exceptions)
9 |
10 | async def clear(self):
11 | try:
12 | await self.fs._rm(self.root, True)
13 | await self.fs._mkdir(self.root)
14 | except Exception:
15 | pass
16 |
17 | async def getitems(self, keys, on_error="raise"):
18 | keys2 = [self._key_to_str(k) for k in keys]
19 | oe = on_error if on_error == "raise" else "return"
20 | try:
21 | out = await self.fs._cat(keys2, on_error=oe)
22 | if isinstance(out, bytes):
23 | out = {keys2[0]: out}
24 | except self.missing_exceptions as e:
25 | raise KeyError from e
26 | out = {
27 | k: (KeyError() if isinstance(v, self.missing_exceptions) else v)
28 | for k, v in out.items()
29 | }
30 | return {
31 | key: out[k2]
32 | for key, k2 in zip(keys, keys2)
33 | if on_error == "return" or not isinstance(out[k2], BaseException)
34 | }
35 |
36 | async def setitems(self, values_dict):
37 | values = {self._key_to_str(k): maybe_convert(v) for k, v in values_dict.items()}
38 | await self.fs._pipe(values)
39 |
40 | async def delitems(self, keys):
41 | """Remove multiple keys from the store"""
42 | await self.fs._rm([self._key_to_str(k) for k in keys])
43 |
44 | async def __getitem__(self, key, default=None):
45 | """Retrieve data"""
46 | k = self._key_to_str(key)
47 | try:
48 | result = await self.fs._cat(k)
49 | except self.missing_exceptions:
50 | if default is not None:
51 | return default
52 | raise KeyError(key)
53 | return result
54 |
55 |     async def pop(self, key, default=None):
56 |         result = await self.__getitem__(key, default)
57 |         try:
58 |             await self.__delitem__(key)  # `del` would not await the coroutine
59 |         except KeyError:
60 |             pass
61 |         return result
62 |
63 | async def __setitem__(self, key, value):
64 | """Store value in key"""
65 | key = self._key_to_str(key)
66 | await self.fs._makedirs(self.fs._parent(key), exist_ok=True)
67 | await self.fs._pipe_file(key, maybe_convert(value))
68 |
69 | async def __iter__(self):
70 | return (self._str_to_key(x) for x in await self.fs._find(self.root))
71 |
72 |     async def __len__(self):
73 |         return len(await self.fs._find(self.root))
74 |
75 | async def __delitem__(self, key):
76 | """Remove key"""
77 | try:
78 | await self.fs._rm(self._key_to_str(key))
79 | except: # noqa: E722
80 | raise KeyError
81 |
82 | async def __contains__(self, key):
83 | """Does key exist in mapping?"""
84 | path = self._key_to_str(key)
85 | return await self.fs._exists(path) and await self.fs._isfile(path)
86 |
--------------------------------------------------------------------------------
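
A minimal usage sketch for the mapper above. The bucket path is a hypothetical placeholder; `.zmetadata` and `.zgroup` are the standard keys a consolidated zarr store exposes:

```python
import asyncio

import s3fs

from src.fsspec.mapping.mapper import AsyncFSMap


async def fetch_metadata() -> dict:
    fs = s3fs.S3FileSystem(asynchronous=True)
    mapper = AsyncFSMap("s3://some-bucket/store.zarr", fs)  # hypothetical path
    # getitems fans all key reads out through a single concurrent _cat call
    return await mapper.getitems([".zmetadata", ".zgroup"])


metadata = asyncio.run(fetch_metadata())
```
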
/src/xarray/backends/zarr.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import contextlib
3 | import os
4 | import sys
5 |
6 | import numpy as np
7 | from xarray import conventions
8 | from xarray.backends.common import _decode_variable_name, _normalize_path
9 | from xarray.backends.store import StoreBackendEntrypoint
10 | from xarray.backends.zarr import (
11 | DIMENSION_KEY,
12 | ZarrArrayWrapper,
13 | ZarrBackendEntrypoint,
14 | ZarrStore,
15 | _get_zarr_dims_and_attrs,
16 | )
17 | from xarray.core import indexing
18 | from xarray.core.utils import FrozenDict
19 |
20 | from ..conventions import decode_cf_variable
21 | from ..core.variable import Variable
22 | from ..dataset import Dataset
23 | from ...zarr.convenience import open_consolidated
24 |
25 | sys.modules["xarray.conventions"].decode_cf_variable = decode_cf_variable
26 |
27 |
28 | class AsyncArrayWrapper(ZarrArrayWrapper):
29 | async def __array__(self, dtype=None):
30 | key = indexing.BasicIndexer((slice(None),) * self.ndim)
31 | return np.asarray(await self[key], dtype=dtype)
32 |
33 | async def __getitem__(self, key):
34 | array = self.get_array()
35 | if isinstance(key, indexing.BasicIndexer):
36 | return await array[key.tuple]
37 | elif isinstance(key, indexing.VectorizedIndexer):
38 | return await array.vindex[
39 | indexing._arrayize_vectorized_indexer(key, self.shape).tuple
40 | ]
41 | else:
42 | assert isinstance(key, indexing.OuterIndexer)
43 | return await array.oindex[key.tuple]
44 |
45 |
46 | class AsyncStore(ZarrStore):
47 | @classmethod
48 | async def open_group(
49 | cls,
50 | store,
51 | mode="r",
52 | synchronizer=None,
53 | group=None,
54 | consolidated=False,
55 | consolidate_on_close=False,
56 | chunk_store=None,
57 | storage_options=None,
58 | append_dim=None,
59 | write_region=None,
60 | safe_chunks=True,
61 | stacklevel=2,
62 | ):
63 | if isinstance(store, os.PathLike):
64 | raise NotImplementedError("cannot do local storage zarr")
65 |
66 | open_kwargs = dict(
67 | mode=mode,
68 | synchronizer=synchronizer,
69 | path=group,
70 | )
71 | open_kwargs["storage_options"] = storage_options
72 |
73 | if chunk_store:
74 | open_kwargs["chunk_store"] = chunk_store
77 |
78 | if consolidated is None:
79 | try:
80 | zarr_group = await open_consolidated(store, **open_kwargs)
81 | except KeyError:
82 | raise NotImplementedError("not ready for non-consolidated stores")
83 | elif consolidated:
84 | # TODO: an option to pass the metadata_key keyword
85 | zarr_group = await open_consolidated(store, **open_kwargs)
86 | else:
87 | raise NotImplementedError("not ready for non-consolidated stores")
88 | return cls(
89 | zarr_group,
90 | mode,
91 | consolidate_on_close,
92 | append_dim,
93 | write_region,
94 | safe_chunks,
95 | )
96 |
97 | async def load(self):
98 | variables = FrozenDict(
99 | (_decode_variable_name(k), v)
100 | for k, v in (await self.get_variables()).items()
101 | )
102 | attributes = FrozenDict(self.get_attrs())
103 | return variables, attributes
104 |
105 | async def get_variables(self):
106 | return FrozenDict(
107 | await asyncio.gather(
108 | *[
109 | self.open_store_variable_with_key(k, v, k)
110 | for k, v in self.zarr_group.arrays()
111 | ]
112 | )
113 | )
114 |
115 | async def open_store_variable_with_key(self, name, zarr_array, label):
116 | return label, await self.open_store_variable(name, zarr_array)
117 |
118 | async def open_store_variable(self, name, zarr_array):
119 | data = AsyncArrayWrapper(name, self)
120 | try_nczarr = self._mode == "r"
121 | dimensions, attributes = _get_zarr_dims_and_attrs(
122 | zarr_array, DIMENSION_KEY, try_nczarr
123 | )
124 | attributes = dict(attributes)
125 | encoding = {
126 | "chunks": zarr_array.chunks,
127 | "preferred_chunks": dict(zip(dimensions, zarr_array.chunks)),
128 | "compressor": zarr_array.compressor,
129 | "filters": zarr_array.filters,
130 | }
131 | # _FillValue needs to be in attributes, not encoding, so it will get
132 | # picked up by decode_cf
133 | if getattr(zarr_array, "fill_value") is not None:
134 | attributes["_FillValue"] = zarr_array.fill_value
135 |
136 | variable = Variable(dimensions, data, attributes, encoding)
137 | await variable.maybe_preload()
138 | return variable
139 |
140 |
141 | class AsyncStoreBackendEntrypoint(StoreBackendEntrypoint):
142 | async def open_dataset(
143 | self,
144 | store,
145 | *,
146 | mask_and_scale=True,
147 | decode_times=True,
148 | concat_characters=True,
149 | decode_coords=True,
150 | drop_variables=None,
151 | use_cftime=None,
152 | decode_timedelta=None,
153 | ):
154 | vars, attrs = await store.load()
155 | encoding = store.get_encoding()
156 |
157 | vars, attrs, coord_names = conventions.decode_cf_variables(
158 | vars,
159 | attrs,
160 | mask_and_scale=mask_and_scale,
161 | decode_times=decode_times,
162 | concat_characters=concat_characters,
163 | decode_coords=decode_coords,
164 | drop_variables=drop_variables,
165 | use_cftime=use_cftime,
166 | decode_timedelta=decode_timedelta,
167 | )
168 |
169 | ds = Dataset(vars, attrs=attrs)
170 | ds = ds.set_coords(coord_names.intersection(vars))
171 | ds.set_close(store.close)
172 | ds.encoding = encoding
173 |
174 | return ds
175 |
176 |
177 | class AsyncZarrBackendEntrypint(ZarrBackendEntrypoint):
178 | async def open_dataset(
179 | self,
180 | filename_or_obj,
181 | mask_and_scale=True,
182 | decode_times=True,
183 | concat_characters=True,
184 | decode_coords=True,
185 | drop_variables=None,
186 | use_cftime=None,
187 | decode_timedelta=None,
188 | group=None,
189 | mode="r",
190 | synchronizer=None,
191 | consolidated=None,
192 | chunk_store=None,
193 | storage_options=None,
194 | stacklevel=3,
195 | ):
196 | filename_or_obj = _normalize_path(filename_or_obj)
197 | store = await AsyncStore.open_group(
198 | filename_or_obj,
199 | group=group,
200 | mode=mode,
201 | synchronizer=synchronizer,
202 | consolidated=consolidated,
203 | consolidate_on_close=False,
204 | chunk_store=chunk_store,
205 | storage_options=storage_options,
206 | stacklevel=stacklevel + 1,
207 | )
208 |
209 | store_entrypoint = AsyncStoreBackendEntrypoint()
210 | with close_on_error(store):
211 | ds = await store_entrypoint.open_dataset(
212 | store,
213 | mask_and_scale=mask_and_scale,
214 | decode_times=decode_times,
215 | concat_characters=concat_characters,
216 | decode_coords=decode_coords,
217 | drop_variables=drop_variables,
218 | use_cftime=use_cftime,
219 | decode_timedelta=decode_timedelta,
220 | )
221 | return ds
222 |
223 |
224 | @contextlib.contextmanager
225 | def close_on_error(f):
226 | """Context manager to ensure that a file opened by xarray is closed if an
227 | exception is raised before the user sees the file object.
228 | """
229 | try:
230 | yield
231 | except Exception:
232 | f.close()
233 | raise
234 |
--------------------------------------------------------------------------------
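
The point of the entrypoint above is that many stores can be opened concurrently on a single thread. A sketch under the same assumptions as main.py (the store paths are placeholders):

```python
import asyncio

import s3fs

from src.fsspec.mapping.mapper import AsyncFSMap
from src.xarray.backends.zarr import AsyncZarrBackendEntrypint

entry_point = AsyncZarrBackendEntrypint()


async def open_many(paths: list[str]):
    fs = s3fs.S3FileSystem(asynchronous=True)
    # all opens interleave on the event loop; no threadpool is involved
    return await asyncio.gather(
        *[entry_point.open_dataset(AsyncFSMap(p, fs)) for p in paths]
    )
```
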
/src/xarray/conventions.py:
--------------------------------------------------------------------------------
1 | from xarray import conventions
2 |
3 | from .core.variable import Variable
4 |
5 |
6 | def decode_cf_variable(
7 | name,
8 | var,
9 | concat_characters=True,
10 | mask_and_scale=True,
11 | decode_times=True,
12 | decode_endianness=True,
13 | stack_char_dim=True,
14 | use_cftime=None,
15 | decode_timedelta=None,
16 | ):
17 | # Ensure datetime-like Variables are passed through unmodified (GH 6453)
18 | if conventions._contains_datetime_like_objects(var):
19 | return var
20 |
21 | original_dtype = var.dtype
22 |
23 | if decode_timedelta is None:
24 | decode_timedelta = decode_times
25 |
26 | if concat_characters:
27 | if stack_char_dim:
28 | var = conventions.strings.CharacterArrayCoder().decode(var, name=name)
29 | var = conventions.strings.EncodedStringCoder().decode(var)
30 |
31 | if mask_and_scale:
32 | for coder in [
33 | conventions.variables.UnsignedIntegerCoder(),
34 | conventions.variables.CFMaskCoder(),
35 | conventions.variables.CFScaleOffsetCoder(),
36 | ]:
37 | var = coder.decode(var, name=name)
38 |
39 | if decode_timedelta:
40 | var = conventions.times.CFTimedeltaCoder().decode(var, name=name)
41 | if decode_times:
42 | var = conventions.times.CFDatetimeCoder(use_cftime=use_cftime).decode(
43 | var, name=name
44 | )
45 |
46 | dimensions, data, attributes, encoding = conventions.variables.unpack_for_decoding(
47 | var
48 | )
49 | # TODO(shoyer): convert everything below to use coders
50 |
51 | if decode_endianness and not data.dtype.isnative:
52 | # do this last, so it's only done if we didn't already unmask/scale
53 | data = conventions.NativeEndiannessArray(data)
54 | original_dtype = data.dtype
55 |
56 | encoding.setdefault("dtype", original_dtype)
57 |
58 | if "dtype" in attributes and attributes["dtype"] == "bool":
59 | del attributes["dtype"]
60 | data = conventions.BoolTypeArray(data)
61 |
62 | return Variable(dimensions, data, attributes, encoding=encoding)
63 |
--------------------------------------------------------------------------------
/src/xarray/core/variable.py:
--------------------------------------------------------------------------------
1 | from asyncio import iscoroutinefunction
2 | from typing import Any
3 |
4 | import numpy as np
5 | from xarray.core import variable
6 |
7 |
8 | class Variable(variable.Variable):
9 | async def maybe_preload(self):
10 |         if self.size < 1e5 and iscoroutinefunction(getattr(self._data, "__array__", None)):
11 | self._data = await self._data.__array__()
12 |
13 | @property
14 | def data(self) -> Any:
15 | return self._data
16 |
17 | @data.setter
18 | def data(self, data):
19 | data = variable.as_compatible_data(data)
20 | if data.shape != self.shape:
21 | raise ValueError(
22 | f"replacement data must match the Variable's shape. "
23 | f"replacement data has shape {data.shape}; Variable has shape {self.shape}"
24 | )
25 | self._data = data
26 |
27 | async def _isel(
28 | self,
29 | indexers: variable.Mapping[Any, Any] = None,
30 | missing_dims="raise",
31 | **indexers_kwargs: Any,
32 | ):
33 | indexers = variable.either_dict_or_kwargs(indexers, indexers_kwargs, "isel")
34 | indexers = variable.drop_dims_from_indexers(indexers, self.dims, missing_dims)
35 |
36 | key = tuple(indexers.get(dim, slice(None)) for dim in self.dims)
37 | return await self.__agetitem__(key)
38 |
39 | async def __agetitem__(self, key):
40 | """Return a new Variable object whose contents are consistent with
41 | getting the provided key from the underlying data.
42 |
43 | NB. __getitem__ and __setitem__ implement xarray-style indexing,
44 | where if keys are unlabeled arrays, we index the array orthogonally
45 | with them. If keys are labeled array (such as Variables), they are
46 | broadcasted with our usual scheme and then the array is indexed with
47 | the broadcasted key, like numpy's fancy indexing.
48 |
49 | If you really want to do indexing like `x[x > 0]`, manipulate the numpy
50 | array `x.values` directly.
51 | """
52 | dims, indexer, new_order = self._broadcast_indexes(key)
53 | data = await self._data[indexer]
54 | if new_order:
55 | data = np.moveaxis(data, range(len(new_order)), new_order)
56 | return self._finalize_indexing_result(dims, data)
57 |
--------------------------------------------------------------------------------
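
A small sketch of the variable-level hook above: `_isel` takes the same dimension keywords as `Variable.isel` but awaits the underlying async array. The `time` dimension name is an assumption for illustration:

```python
async def head(var, n: int = 10):
    # positional slice along an assumed `time` dimension
    return await var._isel(time=slice(0, n))
```
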
/src/xarray/dataset.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from typing import Any, Hashable, Iterable, Mapping
3 |
4 | from xarray import Dataset as XDs
5 | from xarray.core.indexes import isel_indexes
6 | from xarray.core.indexing import is_fancy_indexer, map_index_queries
7 | from xarray.core.utils import drop_dims_from_indexers, either_dict_or_kwargs
8 |
9 |
10 | class Dataset(XDs):
11 | async def _isel(
12 | self,
13 | indexers: Mapping[Any, Any] | None = None,
14 | drop: bool = False,
15 | missing_dims="raise",
16 | **indexers_kwargs: Any,
17 | ):
18 | indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel")
19 | if any(is_fancy_indexer(idx) for idx in indexers.values()):
20 | return self._isel_fancy(indexers, drop=drop, missing_dims=missing_dims)
21 |
22 | # Much faster algorithm for when all indexers are ints, slices, one-dimensional
23 | # lists, or zero or one-dimensional np.ndarray's
24 | indexers = drop_dims_from_indexers(indexers, self.dims, missing_dims)
25 |
26 | variables = {}
27 | dims: dict[Hashable, int] = {}
28 | coord_names = self._coord_names.copy()
29 |
30 | indexes, index_variables = isel_indexes(self.xindexes, indexers)
31 |
32 |         async def place_var(name, var):
33 |             if name in index_variables:
34 |                 var = index_variables[name]
35 |             else:
36 |                 var_indexers = {k: v for k, v in indexers.items() if k in var.dims}
37 |                 if var_indexers:
38 |                     if hasattr(var, "_isel"):
39 |                         var = await var._isel(var_indexers)
40 |                     else:
41 |                         var = var.isel(var_indexers)
42 |                     if drop and var.ndim == 0 and name in coord_names:
43 |                         coord_names.remove(name)
44 |                         return None
45 |             return name, var
46 |
47 |         # gather preserves input order, so the original variable order is kept
48 |         results = await asyncio.gather(
49 |             *[place_var(name, var) for name, var in self._variables.items()]
50 |         )
51 |         for result in results:
52 |             if result is None:
53 |                 continue
54 |             name, var = result
55 |             variables[name] = var
56 |             dims.update(zip(var.dims, var.shape))
63 |
64 | return self._construct_direct(
65 | variables=variables,
66 | coord_names=coord_names,
67 | dims=dims,
68 | attrs=self._attrs,
69 | indexes=indexes,
70 | encoding=self._encoding,
71 | close=self._close,
72 | )
73 |
74 | async def _sel(
75 | self,
76 | indexers: Mapping[Any, Any] = None,
77 | method: str = None,
78 | tolerance: int | float | Iterable[int | float] | None = None,
79 | drop: bool = False,
80 | **indexers_kwargs: Any,
81 | ):
82 | indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "sel")
83 | query_results = map_index_queries(
84 | self, indexers=indexers, method=method, tolerance=tolerance
85 | )
86 |
87 | if drop:
88 | no_scalar_variables = {}
89 | for k, v in query_results.variables.items():
90 | if v.dims:
91 | no_scalar_variables[k] = v
92 | else:
93 | if k in self._coord_names:
94 | query_results.drop_coords.append(k)
95 | query_results.variables = no_scalar_variables
96 |
97 | result = await self._isel(indexers=query_results.dim_indexers, drop=drop)
98 | return result._overwrite_indexes(*query_results.as_tuple()[1:])
99 |
--------------------------------------------------------------------------------
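
Usage of the methods above mirrors `Dataset.sel`/`Dataset.isel`, just awaited. A sketch assuming `ds` came from the async backend entrypoint and has `lat`, `lon`, and `time` dimensions:

```python
async def select_point(ds):
    # label-based nearest-neighbour lookup, then positional slicing
    point = await ds._sel(lat=39.5, lon=-104.99, method="nearest")
    return await point._isel(time=slice(0, 24))
```
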
/src/zarr/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeliashi/xarray-async/86ebe27adc5151f761414c4cf58165201a13186d/src/zarr/__init__.py
--------------------------------------------------------------------------------
/src/zarr/convenience.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | from zarr.convenience import StoreLike, open
4 | from zarr.storage import normalize_store_arg
5 |
6 | from .core import Array
7 | from .storage import ConsolidatedMetadataStore
8 |
9 | sys.modules["zarr.core"].Array = Array
10 | sys.modules["zarr.hierarchy"].Array = Array
11 |
12 |
13 | async def open_consolidated(
14 | store: StoreLike, metadata_key=".zmetadata", mode="r+", **kwargs
15 | ):
16 | zarr_version = kwargs.get("zarr_version")
17 | store = normalize_store_arg(
18 | store,
19 | storage_options=kwargs.get("storage_options"),
20 | mode=mode,
21 | zarr_version=zarr_version,
22 | )
23 | if mode not in {"r", "r+"}:
24 | raise ValueError(
25 | "invalid mode, expected either 'r' or 'r+'; found {!r}".format(mode)
26 | )
27 |
28 | path = kwargs.pop("path", None)
29 | if store._store_version == 2:
30 | ConsolidatedStoreClass = ConsolidatedMetadataStore
31 | else:
32 | raise NotImplementedError()
33 | # assert_zarr_v3_api_available()
34 | # ConsolidatedStoreClass = ConsolidatedMetadataStoreV3
35 | # # default is to store within 'consolidated' group on v3
36 | # if not metadata_key.startswith("meta/root/"):
37 | # metadata_key = "meta/root/consolidated/" + metadata_key
38 |
39 | # setup metadata store
40 | meta_store = ConsolidatedStoreClass(store, metadata_key=metadata_key)
41 | await meta_store.init_coro
42 |
43 | # pass through
44 | chunk_store = kwargs.pop("chunk_store", None) or store
45 | return open(
46 | store=meta_store, chunk_store=chunk_store, mode=mode, path=path, **kwargs
47 | )
48 |
--------------------------------------------------------------------------------
/src/zarr/core.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import sys
3 |
4 | import numpy as np
5 | from numcodecs.compat import ensure_ndarray_like
6 | from zarr.core import Array as ZA
7 | from zarr.errors import err_too_many_indices
8 | from zarr.indexing import (
9 | BasicIndexer,
10 | CoordinateIndexer,
11 | MaskIndexer,
12 | OrthogonalIndexer,
13 | check_fields,
14 | ensure_tuple,
15 | is_pure_fancy_indexing,
16 | pop_fields,
17 | )
18 | from zarr.util import check_array_shape
19 |
20 | from .indexing import OIndex, VIndex
21 |
22 | sys.modules["zarr.core"].OIndex = OIndex
23 | sys.modules["zarr.core"].VIndex = VIndex
24 |
25 |
26 | class Array(ZA):
27 | async def __array__(self, *args):
28 | a = await self[...]
29 | if args:
30 | a = a.astype(args[0])
31 | return a
32 |
33 | async def islice(self, start=None, end=None):
34 | if len(self.shape) == 0:
35 | # Same error as numpy
36 | raise TypeError("iteration over a 0-d array")
37 | if start is None:
38 | start = 0
39 | if end is None or end > self.shape[0]:
40 | end = self.shape[0]
41 |
42 | if not isinstance(start, int) or start < 0:
43 | raise ValueError("start must be a nonnegative integer")
44 |
45 | if not isinstance(end, int) or end < 0:
46 | raise ValueError("end must be a nonnegative integer")
47 |
48 | # Avoid repeatedly decompressing chunks by iterating over the chunks
49 | # in the first dimension.
50 | chunk_size = self.chunks[0]
51 | chunk = None
52 | for j in range(start, end):
53 | if j % chunk_size == 0:
54 | chunk = await self[j : j + chunk_size]
55 |             # initialize the chunk if the start is offset from a chunk border
56 | elif chunk is None:
57 | chunk_start = j - j % chunk_size
58 | chunk_end = chunk_start + chunk_size
59 | chunk = await self[chunk_start:chunk_end]
60 |             yield chunk[j % chunk_size]
61 |
62 |     def __aiter__(self):
63 |         # __aiter__ must return the async iterator synchronously;
64 |         # islice() is an async generator, so it is not awaited here
65 |         return self.islice()
66 |
67 |
68 | async def __getitem__(self, selection):
69 | fields, pure_selection = pop_fields(selection)
70 | if is_pure_fancy_indexing(pure_selection, self.ndim):
71 | result = await self.vindex[selection]
72 | else:
73 | result = await self.get_basic_selection(pure_selection, fields=fields)
74 | return result
75 |
76 | async def get_basic_selection(self, selection=Ellipsis, out=None, fields=None):
77 | if not self._cache_metadata:
78 | self._load_metadata()
79 | check_fields(fields, self._dtype)
80 | if self._shape == ():
81 | return await self._get_basic_selection_zd(
82 | selection=selection, out=out, fields=fields
83 | )
84 | else:
85 | return await self._get_basic_selection_nd(
86 | selection=selection, out=out, fields=fields
87 | )
88 |
89 | async def _get_basic_selection_zd(self, selection, out=None, fields=None):
90 | selection = ensure_tuple(selection)
91 | if selection not in ((), (Ellipsis,)):
92 | err_too_many_indices(selection, ())
93 |
94 | try:
95 | # obtain encoded data for chunk
96 | ckey = self._chunk_key((0,))
97 | cdata = await self.chunk_store[ckey]
98 |
99 | except KeyError:
100 | # chunk not initialized
101 | chunk = np.zeros_like(self._meta_array, shape=(), dtype=self._dtype)
102 | if self._fill_value is not None:
103 | chunk.fill(self._fill_value)
104 |
105 | else:
106 | chunk = self._decode_chunk(cdata)
107 |
108 | # handle fields
109 | if fields:
110 | chunk = chunk[fields]
111 |
112 | # handle selection of the scalar value via empty tuple
113 | if out is None:
114 | out = chunk[selection]
115 | else:
116 | out[selection] = chunk[selection]
117 |
118 | return out
119 |
120 | async def _get_basic_selection_nd(self, selection, out=None, fields=None):
121 | indexer = BasicIndexer(selection, self)
122 |
123 | return await self._get_selection(indexer=indexer, out=out, fields=fields)
124 |
125 | async def _get_selection(self, indexer, out=None, fields=None):
126 | out_dtype = check_fields(fields, self._dtype)
127 | out_shape = indexer.shape
128 | if out is None:
129 | out = np.empty_like(
130 | self._meta_array, shape=out_shape, dtype=out_dtype, order=self._order
131 | )
132 | else:
133 | check_array_shape("out", out, out_shape)
134 |         # fetch all needed chunks concurrently from storage
135 | await asyncio.gather(
136 | *[
137 | self._chunk_getitem(
138 | chunk_coords,
139 | chunk_selection,
140 | out,
141 | out_selection,
142 | drop_axes=indexer.drop_axes,
143 | fields=fields,
144 | )
145 | for chunk_coords, chunk_selection, out_selection in indexer
146 | ]
147 | )
148 |
149 | if out.shape:
150 | return out
151 | else:
152 | return out[()]
153 |
154 | async def _chunk_getitem(
155 | self,
156 | chunk_coords,
157 | chunk_selection,
158 | out,
159 | out_selection,
160 | drop_axes=None,
161 | fields=None,
162 | ):
163 | out_is_ndarray = True
164 | try:
165 | out = ensure_ndarray_like(out)
166 | except TypeError:
167 | out_is_ndarray = False
168 |
169 | assert len(chunk_coords) == len(self._cdata_shape)
170 | ckey = self._chunk_key(chunk_coords)
171 | try:
172 | # obtain compressed data for chunk
173 | cdata = await self.chunk_store[ckey]
174 | except KeyError:
175 | # chunk not initialized
176 | if self._fill_value is not None:
177 | if fields:
178 | fill_value = self._fill_value[fields]
179 | else:
180 | fill_value = self._fill_value
181 | out[out_selection] = fill_value
182 |
183 | else:
184 | self._process_chunk(
185 | out,
186 | cdata,
187 | chunk_selection,
188 | drop_axes,
189 | out_is_ndarray,
190 | fields,
191 | out_selection,
192 | )
193 |
194 | async def get_orthogonal_selection(self, selection, out=None, fields=None):
195 | if not self._cache_metadata:
196 | self._load_metadata()
197 | check_fields(fields, self._dtype)
198 | indexer = OrthogonalIndexer(selection, self)
199 | return await self._get_selection(indexer=indexer, out=out, fields=fields)
200 |
201 | async def set_orthogonal_selection(self, selection, value, fields=None):
202 | raise NotImplementedError()
203 |
204 | async def get_coordinate_selection(self, selection, out=None, fields=None):
205 | if not self._cache_metadata:
206 | self._load_metadata()
207 |
208 | # check args
209 | check_fields(fields, self._dtype)
210 |
211 | # setup indexer
212 | indexer = CoordinateIndexer(selection, self)
213 |
214 | # handle output - need to flatten
215 | if out is not None:
216 | out = out.reshape(-1)
217 |
218 | out = await self._get_selection(indexer=indexer, out=out, fields=fields)
219 |
220 | # restore shape
221 | out = out.reshape(indexer.sel_shape)
222 |
223 | return out
224 |
225 | async def get_mask_selection(self, selection, out=None, fields=None):
226 | if not self._cache_metadata:
227 | self._load_metadata()
228 |
229 | # check args
230 | check_fields(fields, self._dtype)
231 |
232 | # setup indexer
233 | indexer = MaskIndexer(selection, self)
234 |
235 | return await self._get_selection(indexer=indexer, out=out, fields=fields)
236 |
--------------------------------------------------------------------------------
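
Because `islice` above is an async generator, iterating the patched Array uses `async for`. A sketch assuming `arr` is a 1-D instance of this Array class:

```python
async def first_values(arr, n: int = 5) -> list:
    values = []
    # islice decompresses each chunk once, yielding element by element
    async for value in arr.islice(0, n):
        values.append(value)
    return values
```
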
/src/zarr/indexing.py:
--------------------------------------------------------------------------------
1 | from zarr.errors import VindexInvalidSelectionError
2 | from zarr.indexing import OIndex as ZO
3 | from zarr.indexing import VIndex as ZV
4 | from zarr.indexing import (
5 | ensure_tuple,
6 | is_coordinate_selection,
7 | is_mask_selection,
8 | pop_fields,
9 | replace_lists,
10 | )
11 |
12 |
13 | class OIndex(ZO):
14 | async def __getitem__(self, selection):
15 | fields, selection = pop_fields(selection)
16 | selection = ensure_tuple(selection)
17 | selection = replace_lists(selection)
18 | return await self.array.get_orthogonal_selection(selection, fields=fields)
19 |
20 | async def __setitem__(self, selection, value):
21 | fields, selection = pop_fields(selection)
22 | selection = ensure_tuple(selection)
23 | selection = replace_lists(selection)
24 | return await self.array.set_orthogonal_selection(
25 | selection, value, fields=fields
26 | )
27 |
28 |
29 | class VIndex(ZV):
30 | async def __getitem__(self, selection):
31 | fields, selection = pop_fields(selection)
32 | selection = ensure_tuple(selection)
33 | selection = replace_lists(selection)
34 | if is_coordinate_selection(selection, self.array):
35 | return await self.array.get_coordinate_selection(selection, fields=fields)
36 | elif is_mask_selection(selection, self.array):
37 | return await self.array.get_mask_selection(selection, fields=fields)
38 | else:
39 | raise VindexInvalidSelectionError(selection)
40 |
--------------------------------------------------------------------------------
/src/zarr/storage.py:
--------------------------------------------------------------------------------
1 | from zarr.errors import MetadataError
2 | from zarr.storage import ConsolidatedMetadataStore as zCMS
3 | from zarr.storage import KVStore, Store, StoreLike
4 | from zarr.util import json_loads
5 |
6 |
7 | class ConsolidatedMetadataStore(zCMS):
8 | def __init__(self, store: StoreLike, metadata_key=".zmetadata"):
9 | self.init_coro = self.ainit(store, metadata_key)
10 |
11 | async def ainit(self, store: StoreLike, metadata_key=".zmetadata"):
12 | self.store = Store._ensure_store(store)
13 |
14 | # retrieve consolidated metadata
15 | meta = json_loads(await self.store[metadata_key])
16 |
17 | # check format of consolidated metadata
18 | consolidated_format = meta.get("zarr_consolidated_format", None)
19 | if consolidated_format != 1:
20 | raise MetadataError(
21 | "unsupported zarr consolidated metadata format: %s"
22 | % consolidated_format
23 | )
24 |
25 | # decode metadata
26 | self.meta_store: Store = KVStore(meta["metadata"])
27 |
--------------------------------------------------------------------------------
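
Since `__init__` cannot await, construction of the store above is split into two phases; callers must await `init_coro` before using the store, as `open_consolidated` in src/zarr/convenience.py does. A sketch:

```python
from src.zarr.storage import ConsolidatedMetadataStore


async def open_meta_store(mapper):
    store = ConsolidatedMetadataStore(mapper)  # phase 1: schedule the async init
    await store.init_coro  # phase 2: fetch and parse .zmetadata before any use
    return store
```
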