├── .gitignore ├── README.md ├── create_huge_random_zarr_and_tif.ipynb ├── create_random_test_zarr.ipynb ├── create_random_test_zarr.py ├── dask_future_loader_zarr.ipynb ├── data └── placeholder ├── example_ZARR.ipynb ├── example_ZARR_daskclient.ipynb └── ome_xarray.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # ignore folder data 2 | data/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-zarr-loader 2 | 3 | During the Biovision Hackathon 2023 in Zurich we wanted to find a way to load data from a zarr file into pytorch. The goal was to load the data in a way that it can be used for training a neural network. 4 | We tried several solutions, in particular with and without using dask to load the data in parallel. 5 | 6 | ## Dataset 7 | 8 | Original data used for testing: [Link](https://imagesc.zulipchat.com/user_uploads/16804/85qPFC9O85gLhNmF5KLdqtUx/bsd_val.zarr.zip) 9 | But you can also find examples to generate random data: 10 | - [Notebook for random data](create_random_test_zarr.ipynb) 11 | - [Notebook for huge random data zarr and tiff](create_huge_random_zarr_and_tif.ipynb) 12 | 13 | Proposed loaders: 14 | - [Fastest loader](example_ZARR.ipynb) 15 | - [Notebook for dask loader](dask_future_loader_zarr.ipynb) 16 | 17 | ## Installation 18 | 19 | create env 20 | ```bash 21 | mamba create -n pytorch-zarr-loader -c pytorch -c conda-forge python=3.11 ome-zarr pytorch cpuonly notebook napari matplotlib 22 | mamba activate pytorch-zarr-loader 23 | ``` 24 | 25 | create test data 26 | ```bash 27 | python create_random_test_zarr.py 28 | ``` 29 | ## Similar solutions 30 | 31 | Please consider other repositories that also load zarr for pytorch: 32 | - https://github.com/TheJacksonLaboratory/zarrdataset 33 | --------------------------------------------------------------------------------
/create_huge_random_zarr_and_tif.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import zarr\n", 11 | "import os\n", 12 | "\n", 13 | "from skimage.data import binary_blobs\n", 14 | "from ome_zarr.io import parse_url\n", 15 | "from ome_zarr.writer import write_image\n", 16 | "import dask.array as da" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "random_img = da.random.random((1000, 1000, 10000), chunks=(100, 100, 100))\n", 26 | "random_img" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "from pathlib import Path\n", 36 | "test_path = Path('.') / 'data' / 'huge.zarr'\n", 37 | "\n", 38 | "os.makedirs(test_path, exist_ok=True)" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "# write the image data\n", 48 | "store = parse_url(test_path, mode=\"w\").store\n", 49 | "root = zarr.group(store=store)\n", 50 | "write_image(image=random_img, group=root, axes=\"zyx\", storage_options=dict(chunks=(100, 100, 100)))\n", 51 | "# optional rendering settings\n", 52 | "root.attrs[\"omero\"] = {\n", 53 | " \"channels\": [{\n", 54 | " \"color\": \"00FFFF\",\n", 55 | " \"window\": {\"start\": 0, \"end\": 20, \"min\": 0, \"max\": 255},\n", 56 | " \"label\": \"random\",\n", 57 | " \"active\": True,\n", 58 | " }]\n", 59 | "}" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 3, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "# Do the same thing but saving as tif\n", 69 | "import tifffile\n", 70 | "from pathlib import Path\n", 71 | "import dask.array as da\n", 72 | 
"import os\n", 73 | "\n", 74 | "random_img = da.random.random((1000, 1000, 10000), chunks=(100, 100, 100))\n", 75 | "\n", 76 | "test_path = Path('.') / 'data' / 'many_tif'\n", 77 | "\n", 78 | "os.makedirs(test_path, exist_ok=True)\n", 79 | "\n", 80 | "for i in range(random_img.shape[2]):\n", 81 | " tifffile.imwrite(test_path / f'{i}.tif', random_img[:, :, i].compute())" 82 | ] 83 | } 84 | ], 85 | "metadata": { 86 | "kernelspec": { 87 | "display_name": "pytorch-2d-unet", 88 | "language": "python", 89 | "name": "python3" 90 | }, 91 | "language_info": { 92 | "codemirror_mode": { 93 | "name": "ipython", 94 | "version": 3 95 | }, 96 | "file_extension": ".py", 97 | "mimetype": "text/x-python", 98 | "name": "python", 99 | "nbconvert_exporter": "python", 100 | "pygments_lexer": "ipython3", 101 | "version": "3.11.6" 102 | } 103 | }, 104 | "nbformat": 4, 105 | "nbformat_minor": 2 106 | } 107 | -------------------------------------------------------------------------------- /create_random_test_zarr.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 18, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import zarr\n", 11 | "import os\n", 12 | "\n", 13 | "from skimage.data import binary_blobs\n", 14 | "from ome_zarr.io import parse_url\n", 15 | "from ome_zarr.writer import write_image" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 22, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "\n", 25 | "path = \"./data/aymanns/test_zarr/test_ngff_image.zarr\"\n", 26 | "os.makedirs(path, exist_ok=True)" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 23, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "mean_val=10\n", 36 | "size_xy = 512\n", 37 | "size_z = 100\n", 38 | "rng = np.random.default_rng(0)\n", 39 | "data = rng.poisson(mean_val, size=(size_z, size_xy, 
size_xy)).astype(np.uint8)" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 24, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "# write the image data\n", 49 | "store = parse_url(path, mode=\"w\").store\n", 50 | "root = zarr.group(store=store)\n", 51 | "write_image(image=data, group=root, axes=\"zyx\", storage_options=dict(chunks=(1, size_xy, size_xy)))\n", 52 | "# optional rendering settings\n", 53 | "root.attrs[\"omero\"] = {\n", 54 | " \"channels\": [{\n", 55 | " \"color\": \"00FFFF\",\n", 56 | " \"window\": {\"start\": 0, \"end\": 20, \"min\": 0, \"max\": 255},\n", 57 | " \"label\": \"random\",\n", 58 | " \"active\": True,\n", 59 | " }]\n", 60 | "}" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 25, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "# add labels...\n", 70 | "blobs = binary_blobs(length=size_xy, volume_fraction=0.1, n_dim=3).astype('int8')\n", 71 | "blobs2 = binary_blobs(length=size_xy, volume_fraction=0.1, n_dim=3).astype('int8')\n", 72 | "# blobs will contain values of 1, 2 and 0 (background)\n", 73 | "blobs += 2 * blobs2\n", 74 | "\n", 75 | "# label.shape is (size_xy, size_xy, size_xy), Slice to match the data\n", 76 | "label = blobs[:size_z, :, :]" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 26, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "# write the labels to /labels\n", 86 | "labels_grp = root.create_group(\"labels\")\n", 87 | "# the 'labels' .zattrs lists the named labels data\n", 88 | "label_name = \"blobs\"\n", 89 | "labels_grp.attrs[\"labels\"] = [label_name]\n", 90 | "label_grp = labels_grp.create_group(label_name)\n", 91 | "# need 'image-label' attr to be recognized as label\n", 92 | "label_grp.attrs[\"image-label\"] = {\n", 93 | " \"colors\": [\n", 94 | " {\"label-value\": 1, \"rgba\": [255, 0, 0, 255]},\n", 95 | " {\"label-value\": 2, \"rgba\": [0, 255, 0, 255]},\n", 96 | " {\"label-value\": 3, \"rgba\": 
[255, 255, 0, 255]}\n", 97 | " ]\n", 98 | "}" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 27, 104 | "metadata": {}, 105 | "outputs": [ 106 | { 107 | "data": { 108 | "text/plain": [ 109 | "[]" 110 | ] 111 | }, 112 | "execution_count": 27, 113 | "metadata": {}, 114 | "output_type": "execute_result" 115 | } 116 | ], 117 | "source": [ 118 | "write_image(label, label_grp, axes=\"zyx\")" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [] 127 | } 128 | ], 129 | "metadata": { 130 | "kernelspec": { 131 | "display_name": "pytorch-2d-unet", 132 | "language": "python", 133 | "name": "python3" 134 | }, 135 | "language_info": { 136 | "codemirror_mode": { 137 | "name": "ipython", 138 | "version": 3 139 | }, 140 | "file_extension": ".py", 141 | "mimetype": "text/x-python", 142 | "name": "python", 143 | "nbconvert_exporter": "python", 144 | "pygments_lexer": "ipython3", 145 | "version": "3.9.17" 146 | } 147 | }, 148 | "nbformat": 4, 149 | "nbformat_minor": 2 150 | } 151 | -------------------------------------------------------------------------------- /create_random_test_zarr.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import zarr 3 | import os 4 | 5 | from skimage.data import binary_blobs 6 | from ome_zarr.io import parse_url 7 | from ome_zarr.writer import write_image 8 | 9 | path = "data/test_ngff_image.zarr" 10 | os.mkdir(path) 11 | 12 | mean_val=10 13 | size_xy = 512 14 | size_z = 100 15 | rng = np.random.default_rng(0) 16 | data = rng.poisson(mean_val, size=(size_z, size_xy, size_xy)).astype(np.uint8) 17 | 18 | # write the image data 19 | store = parse_url(path, mode="w").store 20 | root = zarr.group(store=store) 21 | write_image(image=data, group=root, axes="zyx", storage_options=dict(chunks=(1, size_xy, size_xy))) 22 | # optional rendering settings 23 | root.attrs["omero"] = { 24 | 
"channels": [{ 25 | "color": "00FFFF", 26 | "window": {"start": 0, "end": 20, "min": 0, "max": 255}, 27 | "label": "random", 28 | "active": True, 29 | }] 30 | } 31 | 32 | 33 | # add labels... 34 | blobs = binary_blobs(length=size_xy, volume_fraction=0.1, n_dim=3).astype('int8') 35 | blobs2 = binary_blobs(length=size_xy, volume_fraction=0.1, n_dim=3).astype('int8') 36 | # blobs will contain values of 1, 2 and 0 (background) 37 | blobs += 2 * blobs2 38 | 39 | # label.shape is (size_xy, size_xy, size_xy), Slice to match the data 40 | label = blobs[:size_z, :, :] 41 | 42 | # write the labels to /labels 43 | labels_grp = root.create_group("labels") 44 | # the 'labels' .zattrs lists the named labels data 45 | label_name = "blobs" 46 | labels_grp.attrs["labels"] = [label_name] 47 | label_grp = labels_grp.create_group(label_name) 48 | # need 'image-label' attr to be recognized as label 49 | label_grp.attrs["image-label"] = { 50 | "colors": [ 51 | {"label-value": 1, "rgba": [255, 0, 0, 255]}, 52 | {"label-value": 2, "rgba": [0, 255, 0, 255]}, 53 | {"label-value": 3, "rgba": [255, 255, 0, 255]} 54 | ] 55 | } 56 | 57 | write_image(label, label_grp, axes="zyx") 58 | -------------------------------------------------------------------------------- /dask_future_loader_zarr.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# ZARR reading with Dask Client and future" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import dask.array as da\n", 17 | "\n", 18 | "online_path = \"https://uk1s3.embassy.ebi.ac.uk/idr/zarr/v0.4/idr0062A/6001240.zarr/0\"\n", 19 | "local_path = \"data/6001240.zarr\"" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": {}, 26 | "outputs": [ 27 | { 28 | "data": { 29 | "text/html": [ 30 | "\n", 31 | " 
\n", 32 | " \n", 65 | " \n", 140 | " \n", 141 | "
\n", 33 | " \n", 34 | " \n", 35 | " \n", 36 | " \n", 37 | " \n", 38 | " \n", 39 | " \n", 40 | " \n", 41 | " \n", 42 | " \n", 43 | " \n", 44 | " \n", 45 | " \n", 46 | " \n", 47 | " \n", 48 | " \n", 49 | " \n", 50 | " \n", 51 | " \n", 52 | " \n", 53 | " \n", 54 | " \n", 55 | " \n", 56 | " \n", 57 | " \n", 58 | " \n", 59 | " \n", 60 | " \n", 61 | " \n", 62 | " \n", 63 | "
Array Chunk
Bytes 67.09 MiB 0.95 MiB
Shape (2, 236, 275, 271) (1, 50, 100, 100)
Dask graph 90 chunks in 5 graph layers
Data type int16 numpy.ndarray
\n", 64 | "
\n", 66 | " \n", 67 | "\n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | "\n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | "\n", 77 | " \n", 78 | " \n", 79 | "\n", 80 | " \n", 81 | " 2\n", 82 | " 1\n", 83 | "\n", 84 | "\n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | "\n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | "\n", 99 | " \n", 100 | " \n", 101 | "\n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | "\n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | "\n", 116 | " \n", 117 | " \n", 118 | "\n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | "\n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | "\n", 131 | " \n", 132 | " \n", 133 | "\n", 134 | " \n", 135 | " 271\n", 136 | " 275\n", 137 | " 236\n", 138 | "\n", 139 | "
" 142 | ], 143 | "text/plain": [ 144 | "dask.array" 145 | ] 146 | }, 147 | "execution_count": 2, 148 | "metadata": {}, 149 | "output_type": "execute_result" 150 | } 151 | ], 152 | "source": [ 153 | "data = da.from_zarr(online_path).rechunk((1, 50, 100, 100)).astype('int16')\n", 154 | "data" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": 3, 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "data.to_zarr(local_path)" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 4, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "patch_size = (2, 10, 64, 64)\n", 173 | "small_slice = tuple([slice(0, i) for i in patch_size])" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 5, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "\n", 183 | "from itertools import islice\n", 184 | "from pathlib import Path\n", 185 | "from typing import List, Tuple, Union, Optional, Callable, Dict, Generator\n", 186 | "import time\n", 187 | "import numpy as np\n", 188 | "import zarr\n", 189 | "from torch.utils.data import DataLoader, IterableDataset, get_worker_info\n", 190 | "from dask.distributed import Client, get_client\n", 191 | "\n", 192 | "from timeit import timeit, time\n" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 6, 198 | "metadata": {}, 199 | "outputs": [], 200 | "source": [ 201 | "import numpy as np\n", 202 | "from typing import List, Tuple, Union\n", 203 | "import dask\n", 204 | "\n", 205 | "\n", 206 | "@dask.delayed\n", 207 | "def load_zarr(arr: np.ndarray,\n", 208 | " patch_positions,\n", 209 | " patch_size: Union[List[int], Tuple[int, ...]]\n", 210 | " ) -> np.ndarray:\n", 211 | "\n", 212 | "\n", 213 | " # create slices for each dimension\n", 214 | " slices = []\n", 215 | " for i, (center, dimension) in enumerate(zip(patch_positions, patch_size)):\n", 216 | " if center is None:\n", 217 | " 
slices.append(slice(None))\n", 218 | " else:\n", 219 | " slices.append(slice(center, center + dimension))\n", 220 | "\n", 221 | " # load patch\n", 222 | " patch = arr[tuple(slices)]\n", 223 | " return patch\n", 224 | "\n", 225 | "def extract_patches_random(arr: np.ndarray,\n", 226 | " patch_size: Union[List[int], Tuple[int, ...]],\n", 227 | " num_patches: int) -> List[np.ndarray]:\n", 228 | " \"\"\"\n", 229 | " Extract a specified number of patches from an array in a random manner.\n", 230 | "\n", 231 | " Parameters\n", 232 | " ----------\n", 233 | " arr : np.ndarray\n", 234 | " Input array from which to extract patches.\n", 235 | " patch_size : Tuple[int, ...]\n", 236 | " Patch sizes in each dimension.\n", 237 | " num_patches : int\n", 238 | " Number of patches to return.\n", 239 | "\n", 240 | " Returns\n", 241 | " -------\n", 242 | " List[np.ndarray]\n", 243 | " List of randomly selected patches.\n", 244 | " \"\"\"\n", 245 | "\n", 246 | " patch_centers = []\n", 247 | " for i, dimension in enumerate(patch_size):\n", 248 | " if dimension == arr.shape[i]:\n", 249 | " patch_centers.append([None]*num_patches)\n", 250 | " else:\n", 251 | " patch_centers.append(np.random.randint(low=0,\n", 252 | " high=arr.shape[i] - dimension,\n", 253 | " size=num_patches))\n", 254 | " patch_centers = np.array(patch_centers).T\n", 255 | "\n", 256 | " patches = []\n", 257 | " for patch in patch_centers:\n", 258 | " patch = load_zarr(arr, patch, patch_size)\n", 259 | " patches.append(patch)\n", 260 | "\n", 261 | " patches = dask.compute(*patches)\n", 262 | " return np.stack(patches)\n" 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": 7, 268 | "metadata": {}, 269 | "outputs": [ 270 | { 271 | "data": { 272 | "text/plain": [ 273 | "'http://127.0.0.1:8787/status'" 274 | ] 275 | }, 276 | "execution_count": 7, 277 | "metadata": {}, 278 | "output_type": "execute_result" 279 | } 280 | ], 281 | "source": [ 282 | "try:\n", 283 | " client = get_client()\n", 284 | "except 
ValueError:\n", 285 | " client = Client()\n", 286 | "client.dashboard_link" 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": 8, 292 | "metadata": {}, 293 | "outputs": [], 294 | "source": [ 295 | "class ZarrDataset(IterableDataset):\n", 296 | " \"\"\"Dataset to extract patches from a zarr storage.\"\"\"\n", 297 | "\n", 298 | " def __init__(\n", 299 | " self,\n", 300 | " data_path: Union[str, Path],\n", 301 | " patch_size: Optional[Union[List[int], Tuple[int]]] = None,\n", 302 | " num_patches: Optional[int] = None,\n", 303 | " num_load_at_once: int = 20,\n", 304 | " n_shuffle_coordinates: int = 20,\n", 305 | " ) -> None:\n", 306 | " self.patch_size = patch_size\n", 307 | " self.num_patches = num_patches\n", 308 | " self.num_load_at_once = num_load_at_once\n", 309 | " self.n_shuffle_coordinates = n_shuffle_coordinates\n", 310 | "\n", 311 | " self.sample = zarr.open(data_path, mode=\"r\")\n", 312 | "\n", 313 | " def __len__(self):\n", 314 | " return self.n_shuffle_coordinates * self.num_load_at_once\n", 315 | "\n", 316 | " def __iter__(self):\n", 317 | " \"\"\"\n", 318 | " Iterate over data source and yield single patch.\n", 319 | "\n", 320 | " Yields\n", 321 | " ------\n", 322 | " np.ndarray\n", 323 | " \"\"\"\n", 324 | " future = client.submit(extract_patches_random,\n", 325 | " self.sample,\n", 326 | " self.patch_size,\n", 327 | " self.num_load_at_once)\n", 328 | "\n", 329 | " data_in_memory = extract_patches_random(self.sample,\n", 330 | " self.patch_size,\n", 331 | " self.num_load_at_once)\n", 332 | "\n", 333 | " for _ in range(self.n_shuffle_coordinates):\n", 334 | " data_in_memory = future.result()\n", 335 | " future = client.submit(extract_patches_random,\n", 336 | " self.sample,\n", 337 | " self.patch_size,\n", 338 | " self.num_load_at_once)\n", 339 | "\n", 340 | " for j in range(len(data_in_memory)):\n", 341 | " # pop and yield single patch\n", 342 | " patch = data_in_memory[j]\n", 343 | " yield patch\n" 344 | ] 345 | }, 346 | { 
347 | "cell_type": "code", 348 | "execution_count": 13, 349 | "metadata": {}, 350 | "outputs": [ 351 | { 352 | "name": "stderr", 353 | "output_type": "stream", 354 | "text": [ 355 | "100%|██████████| 2000/2000 [00:05<00:00, 340.76it/s]\n" 356 | ] 357 | } 358 | ], 359 | "source": [ 360 | "# around 5 seconds\n", 361 | "\n", 362 | "from tqdm import tqdm\n", 363 | "\n", 364 | "dataset = ZarrDataset(\n", 365 | " data_path=local_path,\n", 366 | " patch_size=patch_size,\n", 367 | " num_load_at_once=20,\n", 368 | " n_shuffle_coordinates=100\n", 369 | ")\n", 370 | "\n", 371 | "dl = DataLoader(dataset, batch_size=1, num_workers=0, prefetch_factor=None)\n", 372 | "\n", 373 | "\n", 374 | "for X in tqdm(dl):\n", 375 | " X = np.array(X)" 376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "execution_count": 14, 381 | "metadata": {}, 382 | "outputs": [ 383 | { 384 | "name": "stderr", 385 | "output_type": "stream", 386 | "text": [ 387 | " 2%|▏ | 40/2000 [00:01<00:51, 38.28it/s]2023-11-08 16:19:27,045 - distributed.scheduler - ERROR - Couldn't gather keys: {('getitem-1e3ff9323dd8b129417285e108080b70', 1, 0, 0, 0): 'waiting', ('getitem-1e3ff9323dd8b129417285e108080b70', 0, 0, 0, 0): 'waiting'}\n", 388 | "2023-11-08 16:19:27,045 - distributed.client - WARNING - Couldn't gather 2 keys, rescheduling (('getitem-1e3ff9323dd8b129417285e108080b70', 1, 0, 0, 0), ('getitem-1e3ff9323dd8b129417285e108080b70', 0, 0, 0, 0))\n", 389 | " 48%|████▊ | 962/2000 [00:25<00:28, 35.96it/s]2023-11-08 16:19:51,733 - distributed.scheduler - ERROR - Couldn't gather keys: {('getitem-1e3ff9323dd8b129417285e108080b70', 1, 0, 0, 0): 'waiting', ('getitem-1e3ff9323dd8b129417285e108080b70', 0, 0, 0, 0): 'waiting'}\n", 390 | "2023-11-08 16:19:51,734 - distributed.client - WARNING - Couldn't gather 2 keys, rescheduling (('getitem-1e3ff9323dd8b129417285e108080b70', 1, 0, 0, 0), ('getitem-1e3ff9323dd8b129417285e108080b70', 0, 0, 0, 0))\n", 391 | "100%|██████████| 2000/2000 [00:52<00:00, 37.84it/s]" 392 | ] 393 | 
}, 394 | { 395 | "name": "stdout", 396 | "output_type": "stream", 397 | "text": [ 398 | "CPU times: user 15.8 s, sys: 1.19 s, total: 17 s\n", 399 | "Wall time: 52.9 s\n" 400 | ] 401 | }, 402 | { 403 | "name": "stderr", 404 | "output_type": "stream", 405 | "text": [ 406 | "\n" 407 | ] 408 | } 409 | ], 410 | "source": [ 411 | "%%time\n", 412 | "# around 1 minutes\n", 413 | "\n", 414 | "import dask.array as da\n", 415 | "complete_download = da.from_zarr(local_path)\n", 416 | "\n", 417 | "for i in tqdm(range(len(dl))):\n", 418 | " complete_download[small_slice].compute()" 419 | ] 420 | } 421 | ], 422 | "metadata": { 423 | "kernelspec": { 424 | "display_name": "Python 3.9.13 ('HDNn')", 425 | "language": "python", 426 | "name": "python3" 427 | }, 428 | "language_info": { 429 | "codemirror_mode": { 430 | "name": "ipython", 431 | "version": 3 432 | }, 433 | "file_extension": ".py", 434 | "mimetype": "text/x-python", 435 | "name": "python", 436 | "nbconvert_exporter": "python", 437 | "pygments_lexer": "ipython3", 438 | "version": "3.11.6" 439 | }, 440 | "vscode": { 441 | "interpreter": { 442 | "hash": "faf8b084d52efbff00ddf863c4fb0ca7a3b023f9f18590a5b65c31dc02d793e2" 443 | } 444 | } 445 | }, 446 | "nbformat": 4, 447 | "nbformat_minor": 2 448 | } 449 | -------------------------------------------------------------------------------- /data/placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ClementCaporal/pytorch-zarr-loader/f142b5e98b6d6f959ce231f9e6bc93c6f2d63eb7/data/placeholder -------------------------------------------------------------------------------- /example_ZARR.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# ZARR reading\n", 8 | "\n", 9 | "[Link to example 
dataset](https://imagesc.zulipchat.com/user_uploads/16804/85qPFC9O85gLhNmF5KLdqtUx/bsd_val.zarr.zip) - copy it under `./data/` and unzip it.\n" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "from itertools import islice\n", 19 | "from pathlib import Path\n", 20 | "from typing import List, Tuple, Union, Optional, Callable, Dict, Generator\n", 21 | "\n", 22 | "import numpy as np\n", 23 | "import zarr\n", 24 | "import time\n", 25 | "import sys\n", 26 | "\n", 27 | "from torch.utils.data import DataLoader, IterableDataset, get_worker_info" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 120, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "def extract_patches_random(arr: np.ndarray,\n", 37 | " patch_size: Union[List[int], Tuple[int]]\n", 38 | ") -> Generator[np.ndarray, None, None]:\n", 39 | " \"\"\"\n", 40 | " Generate patches from an array in a random manner.\n", 41 | "\n", 42 | " The method calculates how many patches the image can be divided into and then\n", 43 | " extracts an equal number of random patches.\n", 44 | "\n", 45 | " Parameters\n", 46 | " ----------\n", 47 | " arr : np.ndarray\n", 48 | " Input image array.\n", 49 | " patch_size : Tuple[int]\n", 50 | " Patch sizes in each dimension.\n", 51 | "\n", 52 | " Yields\n", 53 | " ------\n", 54 | " Generator[np.ndarray, None, None]\n", 55 | " Generator of patches.\n", 56 | " \"\"\"\n", 57 | "\n", 58 | " rng = np.random.default_rng()\n", 59 | "\n", 60 | " n_patches_per_slice = np.ceil(np.prod(arr.shape[1:]) / np.prod(patch_size)).astype(\n", 61 | " int\n", 62 | " )\n", 63 | " crop_coords = rng.integers(\n", 64 | " 0,\n", 65 | " np.array(arr.shape[-len(patch_size):]) - np.array(patch_size),\n", 66 | " size=(arr.shape[0], n_patches_per_slice, len(patch_size)),\n", 67 | " )\n", 68 | " for slice_idx in range(crop_coords.shape[0]):\n", 69 | " sample = arr[slice_idx]\n", 70 | " for 
patch_idx in range(crop_coords.shape[1]):\n", 71 | " patch = sample[\n", 72 | " crop_coords[slice_idx, patch_idx, 0]: crop_coords[\n", 73 | " slice_idx, patch_idx, 0\n", 74 | " ]\n", 75 | " + patch_size[0],\n", 76 | " crop_coords[slice_idx, patch_idx, 1]: crop_coords[\n", 77 | " slice_idx, patch_idx, 1\n", 78 | " ]\n", 79 | " + patch_size[1],\n", 80 | " ]\n", 81 | " yield patch" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 121, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "class ZarrDataset(IterableDataset):\n", 91 | " \"\"\"Dataset to extract patches from a zarr storage.\"\"\"\n", 92 | "\n", 93 | " def __init__(\n", 94 | " self,\n", 95 | " data_path: Union[str, Path],\n", 96 | " patch_extraction_method: str,\n", 97 | " patch_size: Optional[Union[List[int], Tuple[int]]] = None,\n", 98 | " num_patches: Optional[int] = None,\n", 99 | " mean: Optional[float] = None,\n", 100 | " std: Optional[float] = None,\n", 101 | " patch_transform: Optional[Callable] = None,\n", 102 | " patch_transform_params: Optional[Dict] = None,\n", 103 | " ) -> None:\n", 104 | " self.data_path = Path(data_path)\n", 105 | " self.patch_extraction_method = patch_extraction_method\n", 106 | " self.patch_size = patch_size\n", 107 | " self.num_patches = num_patches\n", 108 | " self.mean = mean\n", 109 | " self.std = std\n", 110 | " self.patch_transform = patch_transform\n", 111 | " self.patch_transform_params = patch_transform_params\n", 112 | "\n", 113 | " self.sample = zarr.open(data_path, mode=\"r\")\n", 114 | "\n", 115 | " def _generate_patches(self):\n", 116 | " patches = extract_patches_random(\n", 117 | " self.sample,\n", 118 | " self.patch_size,\n", 119 | " )\n", 120 | "\n", 121 | " for idx, patch in enumerate(patches):\n", 122 | "\n", 123 | " if isinstance(patch, tuple):\n", 124 | " patch = (patch, *patch[1:])\n", 125 | " else:\n", 126 | " patch = patch\n", 127 | "\n", 128 | " if self.patch_transform is not None:\n", 129 | " assert 
self.patch_transform_params is not None\n", 130 | " patch = self.patch_transform(patch, **self.patch_transform_params)\n", 131 | " if self.num_patches is not None and idx >= self.num_patches:\n", 132 | " return\n", 133 | " else:\n", 134 | " yield patch\n", 135 | "\n", 136 | " def __iter__(self):\n", 137 | " \"\"\"\n", 138 | " Iterate over data source and yield single patch.\n", 139 | "\n", 140 | " Yields\n", 141 | " ------\n", 142 | " np.ndarray\n", 143 | " \"\"\"\n", 144 | " worker_info = get_worker_info()\n", 145 | " worker_id = worker_info.id if worker_info is not None else 0\n", 146 | " num_workers = worker_info.num_workers if worker_info is not None else 1\n", 147 | " yield from islice(self._generate_patches(), 0, None, num_workers)\n", 148 | "\n", 149 | "def train_loop(dataloader: DataLoader):\n", 150 | " for i, batch in enumerate(dataloader):\n", 151 | " pass" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 123, 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "test_path = Path('.') / 'data' / 'bsd_val.zarr' \n", 161 | "train_path_fast = '/localscratch/bsd_train.zarr/'\n", 162 | "\n", 163 | "patch_size = (64, 64)\n", 164 | "\n", 165 | "dataset = ZarrDataset(\n", 166 | " data_path=test_path,\n", 167 | " patch_extraction_method='random',\n", 168 | " patch_size=patch_size,\n", 169 | ")\n", 170 | "\n", 171 | "dl = DataLoader(dataset, batch_size=128, num_workers=0)\n", 172 | "\n" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 124, 178 | "metadata": {}, 179 | "outputs": [ 180 | { 181 | "name": "stdout", 182 | "output_type": "stream", 183 | "text": [ 184 | "Average time: 0.000us/step\n" 185 | ] 186 | } 187 | ], 188 | "source": [ 189 | "times = []\n", 190 | "\n", 191 | "for i, batch in enumerate(dl):\n", 192 | " start = time.time()\n", 193 | " b = batch.shape\n", 194 | " cur_time = time.time() - start\n", 195 | " times.append(cur_time)\n", 196 | " info = f\" {cur_time * 
1e6:.3f}us/step\"\n", 197 | "\n", 198 | " print(info, end='\\r')\n", 199 | "\n", 200 | "print(f\"Average time: {np.mean(times) * 1e6:.3f}us/step\")" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": 127, 206 | "metadata": {}, 207 | "outputs": [], 208 | "source": [ 209 | "from timeit import timeit, time" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 128, 215 | "metadata": {}, 216 | "outputs": [ 217 | { 218 | "data": { 219 | "text/plain": [ 220 | "0.08461117744445801" 221 | ] 222 | }, 223 | "execution_count": 128, 224 | "metadata": {}, 225 | "output_type": "execute_result" 226 | } 227 | ], 228 | "source": [ 229 | "# turn previous for loop into function for timeit to work\n", 230 | "def iterate_dl(dl):\n", 231 | " timer = time.time()\n", 232 | " for i, batch in enumerate(dl):\n", 233 | " start = time.time()\n", 234 | " b = batch.shape\n", 235 | " return (time.time() - timer)/(i + 1)\n", 236 | "\n", 237 | "# timeit and add counter of iterations\n", 238 | "iterate_dl(dl)" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": null, 244 | "metadata": {}, 245 | "outputs": [], 246 | "source": [] 247 | } 248 | ], 249 | "metadata": { 250 | "kernelspec": { 251 | "display_name": "Python 3.9.13 ('HDNn')", 252 | "language": "python", 253 | "name": "python3" 254 | }, 255 | "language_info": { 256 | "codemirror_mode": { 257 | "name": "ipython", 258 | "version": 3 259 | }, 260 | "file_extension": ".py", 261 | "mimetype": "text/x-python", 262 | "name": "python", 263 | "nbconvert_exporter": "python", 264 | "pygments_lexer": "ipython3", 265 | "version": "3.9.17" 266 | }, 267 | "vscode": { 268 | "interpreter": { 269 | "hash": "faf8b084d52efbff00ddf863c4fb0ca7a3b023f9f18590a5b65c31dc02d793e2" 270 | } 271 | } 272 | }, 273 | "nbformat": 4, 274 | "nbformat_minor": 2 275 | } 276 | -------------------------------------------------------------------------------- /example_ZARR_daskclient.ipynb: 
def read_zarr(file_path: Path) -> Union[zarr.core.Array, zarr.storage.DirectoryStore, zarr.hierarchy.Group]:
    """Open a zarr store and return a lazy pointer to it.

    No array data is loaded eagerly; the returned object reads chunks on
    demand.

    Parameters
    ----------
    file_path : Path
        Path to a zarr store (directory or file).

    Returns
    -------
    Union[zarr.core.Array, zarr.storage.DirectoryStore, zarr.hierarchy.Group]
        Pointer to the zarr storage, opened read-only.

    Raises
    ------
    zarr.errors.PathNotFoundError
        If the path does not point to a readable zarr store.
    """
    # NOTE(review): validation of the opened object (rejecting Group /
    # DirectoryStore / object-dtype arrays) and of the dimensionality
    # (2, 3 or 4) was previously sketched here inside a dead string block;
    # re-enable it once the supported layouts are decided.
    return zarr.open(Path(file_path), mode="r")
def extract_patches_random(arr: np.ndarray,
                           patch_size: Union[List[int], Tuple[int, ...]],
                           num_patches: int) -> np.ndarray:
    """
    Extract a specified number of patches from an array at random locations.

    Parameters
    ----------
    arr : np.ndarray
        Input array of shape (S, Y, X) from which to extract patches; the
        first axis is sampled as the slice index.
    patch_size : Union[List[int], Tuple[int, ...]]
        Patch sizes as (size_x, size_y). Odd sizes are rounded down to the
        nearest even extent (the slice spans 2 * (size // 2)).
    num_patches : int
        Number of patches to return.

    Returns
    -------
    np.ndarray
        Stacked patches of shape (num_patches, patch_size[1], patch_size[0]).
    """
    # use the modern Generator API; the original created this rng but then
    # fell back to the legacy np.random.randint calls
    rng = np.random.default_rng()
    half_x = patch_size[0] // 2
    half_y = patch_size[1] // 2
    # `integers` treats `high` as exclusive; the max(..., low + 1) guard keeps
    # the degenerate case (patch spanning the whole axis) from raising
    # "low >= high" and pins the center so the patch covers the full axis.
    patch_centers_x = rng.integers(low=half_x,
                                   high=max(arr.shape[-1] - half_x, half_x + 1),
                                   size=num_patches)
    patch_centers_y = rng.integers(low=half_y,
                                   high=max(arr.shape[-2] - half_y, half_y + 1),
                                   size=num_patches)
    slice_indices = rng.integers(low=0, high=arr.shape[0], size=num_patches)

    patches = [
        arr[i,
            y - half_y: y + half_y,
            x - half_x: x + half_x]
        for i, x, y in zip(slice_indices, patch_centers_x, patch_centers_y)
    ]
    return np.stack(patches)
class ZarrDataset(IterableDataset):
    """Dataset to extract patches from a zarr storage.

    Patches are pulled in pools of ``num_load_at_once`` via
    ``extract_patches_random`` and the pool is re-yielded
    ``n_shuffle_coordinates`` times, so one pass over the iterator yields
    ``num_load_at_once * n_shuffle_coordinates`` patches.
    """

    def __init__(
        self,
        data_path: Union[str, Path],
        patch_extraction_method: str,
        patch_size: Optional[Union[List[int], Tuple[int]]] = None,
        num_patches: Optional[int] = None,
        mean: Optional[float] = None,
        std: Optional[float] = None,
        patch_transform: Optional[Callable] = None,
        patch_transform_params: Optional[Dict] = None,
        num_load_at_once: int = 20,
        n_shuffle_coordinates: int = 20,
    ) -> None:
        self.data_path = Path(data_path)
        self.patch_extraction_method = patch_extraction_method
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.mean = mean
        self.std = std
        self.patch_transform = patch_transform
        self.patch_transform_params = patch_transform_params
        self.num_load_at_once = num_load_at_once
        self.n_shuffle_coordinates = n_shuffle_coordinates

        # lazy pointer to the zarr storage; chunks are read on demand
        self.sample = read_zarr(self.data_path)

    def __iter__(self):
        """
        Iterate over data source and yield single patch.

        Yields
        ------
        np.ndarray
        """
        # NOTE(review): worker_id / num_workers are computed but never used —
        # presumably intended for sharding patches across DataLoader workers;
        # confirm before relying on num_workers > 0.
        worker_info = get_worker_info()
        worker_id = worker_info.id if worker_info is not None else 0
        num_workers = worker_info.num_workers if worker_info is not None else 1

        # The commented-out client.submit calls below are the dask-futures
        # variant of this loader (prefetch next pool while yielding current).
        # future = client.submit(extract_patches_random,
        #                        self.sample,
        #                        self.patch_size,
        #                        self.num_load_at_once)

        data_in_memory = extract_patches_random(self.sample,
                                                self.patch_size,
                                                self.num_load_at_once)

        for _ in range(self.n_shuffle_coordinates):
            #data_in_memory = future.result()
            # future = client.submit(extract_patches_random,
            #                        self.sample,
            #                        self.patch_size,
            #                        self.num_load_at_once)

            # NOTE(review): the same pool is re-yielded every outer iteration
            # because the refresh above is commented out.
            for j in range(len(data_in_memory)):
                yield data_in_memory[j]
| } 281 | ], 282 | "source": [ 283 | "# turn previous for loop into function for timeit to work\n", 284 | "def iterate_dl(dl):\n", 285 | " timer = time.time()\n", 286 | " for i, batch in enumerate(dl):\n", 287 | " start = time.time()\n", 288 | " b = batch.shape\n", 289 | " return (time.time() - timer)/(i + 1)\n", 290 | "\n", 291 | "# timeit and add counter of iterations\n", 292 | "iterate_dl(dl)" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": null, 298 | "metadata": {}, 299 | "outputs": [], 300 | "source": [] 301 | } 302 | ], 303 | "metadata": { 304 | "kernelspec": { 305 | "display_name": "Python 3.9.13 ('HDNn')", 306 | "language": "python", 307 | "name": "python3" 308 | }, 309 | "language_info": { 310 | "codemirror_mode": { 311 | "name": "ipython", 312 | "version": 3 313 | }, 314 | "file_extension": ".py", 315 | "mimetype": "text/x-python", 316 | "name": "python", 317 | "nbconvert_exporter": "python", 318 | "pygments_lexer": "ipython3", 319 | "version": "3.9.17" 320 | }, 321 | "vscode": { 322 | "interpreter": { 323 | "hash": "faf8b084d52efbff00ddf863c4fb0ca7a3b023f9f18590a5b65c31dc02d793e2" 324 | } 325 | } 326 | }, 327 | "nbformat": 4, 328 | "nbformat_minor": 2 329 | } 330 | -------------------------------------------------------------------------------- /ome_xarray.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# from https://discuss.pytorch.org/t/dataloader-parallelization-synchronization-with-zarr-xarray-dask/176149\n", 10 | "# and https://gist.github.com/d-v-b/f460c7f673819d431cc958a04acbab8a" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 3, 16 | "metadata": {}, 17 | "outputs": [ 18 | { 19 | "name": "stdout", 20 | "output_type": "stream", 21 | "text": [ 22 | "\n", 24 | "dask.array\n", 25 | "Coordinates:\n", 26 | " * c (c) 
def infer_coords(group: zarr.Group, array: zarr.Array):
    """Build xarray coordinates for ``array`` from the OME-NGFF ``multiscales``
    metadata stored on its parent ``group``.

    Supports NGFF versions "0.4" and "0.5-dev"; raises ValueError when the
    metadata is missing, empty, of an unknown version, or does not reference
    ``array``.
    """
    # these conditionals should be handled by a lower-level validation function
    if 'multiscales' in group.attrs:
        multiscales = group.attrs['multiscales']
        if len(multiscales) > 0:
            # note that technically the spec allows multiple references to the same zarr array
            # because multiscales is a list
            multiscale = multiscales[0]
            ngff_version = multiscale.get("version", None)
            # get the appropriate Multiscale model depending on the version
            if ngff_version == "0.4":
                from pydantic_ome_ngff.v04 import Multiscale
            elif ngff_version == "0.5-dev":
                from pydantic_ome_ngff.latest import Multiscale
            else:
                # NOTE(review): two arguments are passed here, so ValueError
                # carries a 2-tuple instead of one concatenated message —
                # the trailing comma was probably meant to be string
                # concatenation.
                raise ValueError(
                    "Could not resolve the version of the multiscales metadata ",
                    f"found in the group metadata {dict(group.attrs)}",
                )
        else:
            raise ValueError("Multiscales attribute was empty.")
    else:
        raise ValueError("Multiscales attribute not found.")
    xarray_adapters = get_adapters(ngff_version)
    multiscales_meta = [Multiscale(**entry) for entry in multiscales]
    transforms = []
    axes = []
    matched_multiscale, matched_dataset = None, None
    # find the correct element in multiscales.datasets for this array
    for multi in multiscales_meta:
        for dataset in multi.datasets:
            if dataset.path == array.basename:
                matched_multiscale = multi
                matched_dataset = dataset
    if matched_dataset is None or matched_multiscale is None:
        raise ValueError(
            f"""
            Could not find an entry referencing array {array.basename}
            in the `multiscales` metadata of the parent group.
            """
        )
    else:
        # dataset-level transforms are appended after multiscale-level ones,
        # so they are applied in that order by the adapter
        if matched_multiscale.coordinateTransformations is not None:
            transforms.extend(matched_multiscale.coordinateTransformations)
        transforms.extend(matched_dataset.coordinateTransformations)
        axes.extend(matched_multiscale.axes)
    coords = xarray_adapters.transforms_to_coords(axes, transforms, array.shape)
    return coords


def read_dataarray(group: zarr.Group, array: zarr.Array, use_dask: bool = True, **kwargs) -> DataArray:
    """Wrap a zarr array in an ``xarray.DataArray`` with NGFF-derived coords.

    When ``use_dask`` is True the data is wrapped lazily via
    ``dask.array.from_array`` (``kwargs`` are forwarded to it); otherwise the
    zarr array itself backs the DataArray.
    """
    coords = infer_coords(group, array)
    if use_dask:
        data = da.from_array(array, **kwargs)
    else:
        data = array
    return DataArray(data, coords)

def test_read_dataarray():
    """Smoke test against a public IDR OME-NGFF store (requires network)."""
    path = "https://uk1s3.embassy.ebi.ac.uk/idr/zarr/v0.4/idr0062A/6001240.zarr/"
    z_group = zarr.open(path, mode='r')
    z_array = zarr.open(store=z_group.store, path = '0')
    d_array = read_dataarray(z_group, z_array)
    print(d_array)
This will stall for num_workers > 0 and prefetch_factor > 0.\u001b[39;00m\n\u001b[1;32m 32\u001b[0m train_dataloader \u001b[39m=\u001b[39m DataLoader(train_data, batch_size\u001b[39m=\u001b[39m \u001b[39m32\u001b[39m, num_workers \u001b[39m=\u001b[39m \u001b[39m0\u001b[39m, prefetch_factor\u001b[39m=\u001b[39m\u001b[39mNone\u001b[39;00m)\n\u001b[0;32m---> 33\u001b[0m \u001b[39mfor\u001b[39;49;00m X \u001b[39min\u001b[39;49;00m tqdm(train_dataloader):\n\u001b[1;32m 34\u001b[0m \u001b[39mpass\u001b[39;49;00m\n\u001b[1;32m 35\u001b[0m \u001b[39m# np.matmul(X,X) # do something\u001b[39;00m\n", 151 | "File \u001b[0;32m~/miniforge3/envs/pytorch-zarr-loader/lib/python3.11/site-packages/tqdm/std.py:1182\u001b[0m, in \u001b[0;36mtqdm.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1179\u001b[0m time \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_time\n\u001b[1;32m 1181\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m-> 1182\u001b[0m \u001b[39mfor\u001b[39;49;00m obj \u001b[39min\u001b[39;49;00m iterable:\n\u001b[1;32m 1183\u001b[0m \u001b[39myield\u001b[39;49;00m obj\n\u001b[1;32m 1184\u001b[0m \u001b[39m# Update and possibly print the progressbar.\u001b[39;49;00m\n\u001b[1;32m 1185\u001b[0m \u001b[39m# Note: does not call self.update(1) for speed optimisation.\u001b[39;49;00m\n", 152 | "File \u001b[0;32m~/miniforge3/envs/pytorch-zarr-loader/lib/python3.11/site-packages/torch/utils/data/dataloader.py:634\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 631\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_sampler_iter \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 632\u001b[0m \u001b[39m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[1;32m 633\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_reset() \u001b[39m# type: ignore[call-arg]\u001b[39;00m\n\u001b[0;32m--> 634\u001b[0m data \u001b[39m=\u001b[39m 
\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_next_data()\n\u001b[1;32m 635\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m \u001b[39m1\u001b[39m\n\u001b[1;32m 636\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_dataset_kind \u001b[39m==\u001b[39m _DatasetKind\u001b[39m.\u001b[39mIterable \u001b[39mand\u001b[39;00m \\\n\u001b[1;32m 637\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m \\\n\u001b[1;32m 638\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m>\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called:\n", 153 | "File \u001b[0;32m~/miniforge3/envs/pytorch-zarr-loader/lib/python3.11/site-packages/torch/utils/data/dataloader.py:678\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 676\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_next_data\u001b[39m(\u001b[39mself\u001b[39m):\n\u001b[1;32m 677\u001b[0m index \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_next_index() \u001b[39m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m--> 678\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_dataset_fetcher\u001b[39m.\u001b[39;49mfetch(index) \u001b[39m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m 679\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_pin_memory:\n\u001b[1;32m 680\u001b[0m data \u001b[39m=\u001b[39m _utils\u001b[39m.\u001b[39mpin_memory\u001b[39m.\u001b[39mpin_memory(data, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_pin_memory_device)\n", 154 | "File \u001b[0;32m~/miniforge3/envs/pytorch-zarr-loader/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py:51\u001b[0m, in 
\u001b[0;36m_MapDatasetFetcher.fetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m 49\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset\u001b[39m.\u001b[39m__getitems__(possibly_batched_index)\n\u001b[1;32m 50\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m---> 51\u001b[0m data \u001b[39m=\u001b[39m [\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdataset[idx] \u001b[39mfor\u001b[39;49;00m idx \u001b[39min\u001b[39;49;00m possibly_batched_index]\n\u001b[1;32m 52\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 53\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[possibly_batched_index]\n", 155 | "File \u001b[0;32m~/miniforge3/envs/pytorch-zarr-loader/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py:51\u001b[0m, in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 49\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset\u001b[39m.\u001b[39m__getitems__(possibly_batched_index)\n\u001b[1;32m 50\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m---> 51\u001b[0m data \u001b[39m=\u001b[39m [\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdataset[idx] \u001b[39mfor\u001b[39;00m idx \u001b[39min\u001b[39;00m possibly_batched_index]\n\u001b[1;32m 52\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 53\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[possibly_batched_index]\n", 156 | "\u001b[1;32m/home/clement/Documents/pytorch-zarr-loader/ome_xarray.ipynb Cell 3\u001b[0m line \u001b[0;36m2\n\u001b[1;32m 20\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__getitem__\u001b[39m(\u001b[39mself\u001b[39m, idx):\n\u001b[0;32m---> 21\u001b[0m image_npy \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdata[\u001b[39m.\u001b[39;49m\u001b[39m.\u001b[39;49m\u001b[39m.\u001b[39;49m, idx]\u001b[39m.\u001b[39;49mto_numpy()\n\u001b[1;32m 22\u001b[0m image 
\u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mas_tensor(image_npy, dtype \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mfloat)\n\u001b[1;32m 23\u001b[0m \u001b[39m# return image\u001b[39;00m\n", 157 | "File \u001b[0;32m~/miniforge3/envs/pytorch-zarr-loader/lib/python3.11/site-packages/xarray/core/dataarray.py:778\u001b[0m, in \u001b[0;36mDataArray.to_numpy\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 767\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mto_numpy\u001b[39m(\u001b[39mself\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m np\u001b[39m.\u001b[39mndarray:\n\u001b[1;32m 768\u001b[0m \u001b[39m \u001b[39m\u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 769\u001b[0m \u001b[39m Coerces wrapped data to numpy and returns a numpy.ndarray.\u001b[39;00m\n\u001b[1;32m 770\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 776\u001b[0m \u001b[39m DataArray.data\u001b[39;00m\n\u001b[1;32m 777\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 778\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mvariable\u001b[39m.\u001b[39;49mto_numpy()\n", 158 | "File \u001b[0;32m~/miniforge3/envs/pytorch-zarr-loader/lib/python3.11/site-packages/xarray/core/variable.py:1096\u001b[0m, in \u001b[0;36mVariable.to_numpy\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1094\u001b[0m \u001b[39m\u001b[39m\u001b[39m\"\"\"Coerces wrapped data to numpy and returns a numpy.ndarray\"\"\"\u001b[39;00m\n\u001b[1;32m 1095\u001b[0m \u001b[39m# TODO an entrypoint so array libraries can choose coercion method?\u001b[39;00m\n\u001b[0;32m-> 1096\u001b[0m \u001b[39mreturn\u001b[39;00m to_numpy(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_data)\n", 159 | "File \u001b[0;32m~/miniforge3/envs/pytorch-zarr-loader/lib/python3.11/site-packages/xarray/core/pycompat.py:116\u001b[0m, in \u001b[0;36mto_numpy\u001b[0;34m(data)\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mhasattr\u001b[39m(data, 
\u001b[39m\"\u001b[39m\u001b[39mchunks\u001b[39m\u001b[39m\"\u001b[39m):\n\u001b[1;32m 115\u001b[0m chunkmanager \u001b[39m=\u001b[39m get_chunked_array_type(data)\n\u001b[0;32m--> 116\u001b[0m data, \u001b[39m*\u001b[39m_ \u001b[39m=\u001b[39m chunkmanager\u001b[39m.\u001b[39;49mcompute(data)\n\u001b[1;32m 117\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(data, array_type(\u001b[39m\"\u001b[39m\u001b[39mcupy\u001b[39m\u001b[39m\"\u001b[39m)):\n\u001b[1;32m 118\u001b[0m data \u001b[39m=\u001b[39m data\u001b[39m.\u001b[39mget()\n", 160 | "File \u001b[0;32m~/miniforge3/envs/pytorch-zarr-loader/lib/python3.11/site-packages/xarray/core/daskmanager.py:70\u001b[0m, in \u001b[0;36mDaskManager.compute\u001b[0;34m(self, *data, **kwargs)\u001b[0m\n\u001b[1;32m 67\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mcompute\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39m*\u001b[39mdata: DaskArray, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mtuple\u001b[39m[np\u001b[39m.\u001b[39mndarray, \u001b[39m.\u001b[39m\u001b[39m.\u001b[39m\u001b[39m.\u001b[39m]:\n\u001b[1;32m 68\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mdask\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39marray\u001b[39;00m \u001b[39mimport\u001b[39;00m compute\n\u001b[0;32m---> 70\u001b[0m \u001b[39mreturn\u001b[39;00m compute(\u001b[39m*\u001b[39;49mdata, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", 161 | "File \u001b[0;32m~/miniforge3/envs/pytorch-zarr-loader/lib/python3.11/site-packages/dask/base.py:628\u001b[0m, in \u001b[0;36mcompute\u001b[0;34m(traverse, optimize_graph, scheduler, get, *args, **kwargs)\u001b[0m\n\u001b[1;32m 625\u001b[0m postcomputes\u001b[39m.\u001b[39mappend(x\u001b[39m.\u001b[39m__dask_postcompute__())\n\u001b[1;32m 627\u001b[0m \u001b[39mwith\u001b[39;00m shorten_traceback():\n\u001b[0;32m--> 628\u001b[0m results \u001b[39m=\u001b[39m schedule(dsk, keys, 
\u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 630\u001b[0m \u001b[39mreturn\u001b[39;00m repack([f(r, \u001b[39m*\u001b[39ma) \u001b[39mfor\u001b[39;00m r, (f, a) \u001b[39min\u001b[39;00m \u001b[39mzip\u001b[39m(results, postcomputes)])\n", 162 | "File \u001b[0;32m~/miniforge3/envs/pytorch-zarr-loader/lib/python3.11/queue.py:171\u001b[0m, in \u001b[0;36mQueue.get\u001b[0;34m(self, block, timeout)\u001b[0m\n\u001b[1;32m 169\u001b[0m \u001b[39melif\u001b[39;00m timeout \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 170\u001b[0m \u001b[39mwhile\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_qsize():\n\u001b[0;32m--> 171\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mnot_empty\u001b[39m.\u001b[39;49mwait()\n\u001b[1;32m 172\u001b[0m \u001b[39melif\u001b[39;00m timeout \u001b[39m<\u001b[39m \u001b[39m0\u001b[39m:\n\u001b[1;32m 173\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39m'\u001b[39m\u001b[39mtimeout\u001b[39m\u001b[39m'\u001b[39m\u001b[39m must be a non-negative number\u001b[39m\u001b[39m\"\u001b[39m)\n", 163 | "File \u001b[0;32m~/miniforge3/envs/pytorch-zarr-loader/lib/python3.11/threading.py:327\u001b[0m, in \u001b[0;36mCondition.wait\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 325\u001b[0m \u001b[39mtry\u001b[39;00m: \u001b[39m# restore state no matter what (e.g., KeyboardInterrupt)\u001b[39;00m\n\u001b[1;32m 326\u001b[0m \u001b[39mif\u001b[39;00m timeout \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m--> 327\u001b[0m waiter\u001b[39m.\u001b[39;49macquire()\n\u001b[1;32m 328\u001b[0m gotit \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m\n\u001b[1;32m 329\u001b[0m \u001b[39melse\u001b[39;00m:\n", 164 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 165 | ] 166 | } 167 | ], 168 | "source": [ 169 | "%%time\n", 170 | "# around 25 minutes?\n", 171 | "\n", 172 | "import numpy as np\n", 173 
### Define the Dataset
class XRData(Dataset):
    """Map-style dataset yielding one slice along the last axis per item.

    Each ``__getitem__`` computes a single lazy slice of the NGFF-backed
    DataArray, which is the slow path benchmarked in this notebook.
    """

    def __init__(self, path):
        # path: zarr store whose resolution level '0' holds the image data
        # self.data = xr.open_zarr(path).to_array()

        z_group = zarr.open(path, mode='r')
        z_array = zarr.open(store=z_group.store, path = '0')
        d_array = read_dataarray(z_group, z_array)
        self.data = d_array

    def __len__(self):
        # number of slices along the last axis
        return self.data.shape[-1]

    def __getitem__(self, idx):
        # .to_numpy() forces dask to compute the slice here (per item, hence slow)
        image_npy = self.data[..., idx].to_numpy()
        image = torch.as_tensor(image_npy, dtype = torch.float)
        # return image
        return image
### WORKAROUND
class XRBatchData(Dataset):
    """Map-style dataset that serves a whole batch of slices per item.

    Computing one contiguous dask slice of ``batch_size`` planes per
    ``__getitem__`` (and using DataLoader batch_size=1) avoids the per-slice
    overhead of the single-item variant.
    """

    def __init__(self, path, batch_size):
        # open the root group, then resolution level '0' of the pyramid
        group = zarr.open(path, mode='r')
        level0 = zarr.open(store=group.store, path='0')
        self.data = read_dataarray(group, level0)

        self.batch_size = batch_size

    def __len__(self):
        # number of complete batches along the last axis (remainder dropped)
        return int(int(self.data.shape[-1]) / self.batch_size)
        # return int(len(self.data.global_id)/self.batch_size)

    def __getitem__(self, idx):
        lo = idx * self.batch_size
        hi = lo + self.batch_size
        # one dask compute per batch instead of one per slice
        block = self.data[..., slice(lo, hi)].to_numpy()
        # block = self.data.isel(global_id = slice(lo, hi)).to_numpy()
        return torch.as_tensor(block, dtype=torch.float)
"output_type": "stream", 262 | "text": [ 263 | "torch.Size([1000, 1000])\n" 264 | ] 265 | }, 266 | { 267 | "name": "stderr", 268 | "output_type": "stream", 269 | "text": [ 270 | " 20%|██ | 64/313 [07:05<27:34, 6.64s/it]\n" 271 | ] 272 | }, 273 | { 274 | "ename": "KeyboardInterrupt", 275 | "evalue": "", 276 | "output_type": "error", 277 | "traceback": [ 278 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 279 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 280 | "File \u001b[0;32m:29\u001b[0m\n", 281 | "File \u001b[0;32m~/miniforge3/envs/pytorch-zarr-loader/lib/python3.11/site-packages/tqdm/std.py:1182\u001b[0m, in \u001b[0;36mtqdm.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1179\u001b[0m time \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_time\n\u001b[1;32m 1181\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m-> 1182\u001b[0m \u001b[39mfor\u001b[39;49;00m obj \u001b[39min\u001b[39;49;00m iterable:\n\u001b[1;32m 1183\u001b[0m \u001b[39myield\u001b[39;49;00m obj\n\u001b[1;32m 1184\u001b[0m \u001b[39m# Update and possibly print the progressbar.\u001b[39;49;00m\n\u001b[1;32m 1185\u001b[0m \u001b[39m# Note: does not call self.update(1) for speed optimisation.\u001b[39;49;00m\n", 282 | "File \u001b[0;32m~/miniforge3/envs/pytorch-zarr-loader/lib/python3.11/site-packages/torch/utils/data/dataloader.py:634\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 631\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_sampler_iter \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 632\u001b[0m \u001b[39m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[1;32m 633\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_reset() \u001b[39m# type: ignore[call-arg]\u001b[39;00m\n\u001b[0;32m--> 634\u001b[0m data \u001b[39m=\u001b[39m 
\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_next_data()\n\u001b[1;32m 635\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m \u001b[39m1\u001b[39m\n\u001b[1;32m 636\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_dataset_kind \u001b[39m==\u001b[39m _DatasetKind\u001b[39m.\u001b[39mIterable \u001b[39mand\u001b[39;00m \\\n\u001b[1;32m 637\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m \\\n\u001b[1;32m 638\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m>\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called:\n", 283 | "File \u001b[0;32m~/miniforge3/envs/pytorch-zarr-loader/lib/python3.11/site-packages/torch/utils/data/dataloader.py:678\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 676\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_next_data\u001b[39m(\u001b[39mself\u001b[39m):\n\u001b[1;32m 677\u001b[0m index \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_next_index() \u001b[39m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m--> 678\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_dataset_fetcher\u001b[39m.\u001b[39;49mfetch(index) \u001b[39m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m 679\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_pin_memory:\n\u001b[1;32m 680\u001b[0m data \u001b[39m=\u001b[39m _utils\u001b[39m.\u001b[39mpin_memory\u001b[39m.\u001b[39mpin_memory(data, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_pin_memory_device)\n", 284 | "File \u001b[0;32m~/miniforge3/envs/pytorch-zarr-loader/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py:51\u001b[0m, in 
\u001b[0;36m_MapDatasetFetcher.fetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m 49\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset\u001b[39m.\u001b[39m__getitems__(possibly_batched_index)\n\u001b[1;32m 50\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m---> 51\u001b[0m data \u001b[39m=\u001b[39m [\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdataset[idx] \u001b[39mfor\u001b[39;49;00m idx \u001b[39min\u001b[39;49;00m possibly_batched_index]\n\u001b[1;32m 52\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 53\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[possibly_batched_index]\n", 285 | "File \u001b[0;32m~/miniforge3/envs/pytorch-zarr-loader/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py:51\u001b[0m, in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 49\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset\u001b[39m.\u001b[39m__getitems__(possibly_batched_index)\n\u001b[1;32m 50\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m---> 51\u001b[0m data \u001b[39m=\u001b[39m [\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdataset[idx] \u001b[39mfor\u001b[39;00m idx \u001b[39min\u001b[39;00m possibly_batched_index]\n\u001b[1;32m 52\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 53\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[possibly_batched_index]\n", 286 | "File \u001b[0;32m:18\u001b[0m, in \u001b[0;36m__getitem__\u001b[0;34m(self, idx)\u001b[0m\n", 287 | "File \u001b[0;32m~/miniforge3/envs/pytorch-zarr-loader/lib/python3.11/site-packages/zarr/core.py:826\u001b[0m, in \u001b[0;36mArray.__getitem__\u001b[0;34m(self, selection)\u001b[0m\n\u001b[1;32m 824\u001b[0m result \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mget_orthogonal_selection(pure_selection, fields\u001b[39m=\u001b[39mfields)\n\u001b[1;32m 825\u001b[0m 
\u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> 826\u001b[0m result \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mget_basic_selection(pure_selection, fields\u001b[39m=\u001b[39;49mfields)\n\u001b[1;32m 827\u001b[0m \u001b[39mreturn\u001b[39;00m result\n", 288 | "File \u001b[0;32m~/miniforge3/envs/pytorch-zarr-loader/lib/python3.11/site-packages/zarr/core.py:952\u001b[0m, in \u001b[0;36mArray.get_basic_selection\u001b[0;34m(self, selection, out, fields)\u001b[0m\n\u001b[1;32m 949\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_get_basic_selection_zd(selection\u001b[39m=\u001b[39mselection, out\u001b[39m=\u001b[39mout,\n\u001b[1;32m 950\u001b[0m fields\u001b[39m=\u001b[39mfields)\n\u001b[1;32m 951\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> 952\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_get_basic_selection_nd(selection\u001b[39m=\u001b[39;49mselection, out\u001b[39m=\u001b[39;49mout,\n\u001b[1;32m 953\u001b[0m fields\u001b[39m=\u001b[39;49mfields)\n", 289 | "File \u001b[0;32m~/miniforge3/envs/pytorch-zarr-loader/lib/python3.11/site-packages/zarr/core.py:995\u001b[0m, in \u001b[0;36mArray._get_basic_selection_nd\u001b[0;34m(self, selection, out, fields)\u001b[0m\n\u001b[1;32m 989\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_get_basic_selection_nd\u001b[39m(\u001b[39mself\u001b[39m, selection, out\u001b[39m=\u001b[39m\u001b[39mNone\u001b[39;00m, fields\u001b[39m=\u001b[39m\u001b[39mNone\u001b[39;00m):\n\u001b[1;32m 990\u001b[0m \u001b[39m# implementation of basic selection for array with at least one dimension\u001b[39;00m\n\u001b[1;32m 991\u001b[0m \n\u001b[1;32m 992\u001b[0m \u001b[39m# setup indexer\u001b[39;00m\n\u001b[1;32m 993\u001b[0m indexer \u001b[39m=\u001b[39m BasicIndexer(selection, \u001b[39mself\u001b[39m)\n\u001b[0;32m--> 995\u001b[0m \u001b[39mreturn\u001b[39;00m 
\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_get_selection(indexer\u001b[39m=\u001b[39;49mindexer, out\u001b[39m=\u001b[39;49mout, fields\u001b[39m=\u001b[39;49mfields)\n", 290 | "File \u001b[0;32m~/miniforge3/envs/pytorch-zarr-loader/lib/python3.11/site-packages/zarr/core.py:1284\u001b[0m, in \u001b[0;36mArray._get_selection\u001b[0;34m(self, indexer, out, fields)\u001b[0m\n\u001b[1;32m 1281\u001b[0m \u001b[39mif\u001b[39;00m math\u001b[39m.\u001b[39mprod(out_shape) \u001b[39m>\u001b[39m \u001b[39m0\u001b[39m:\n\u001b[1;32m 1282\u001b[0m \u001b[39m# allow storage to get multiple items at once\u001b[39;00m\n\u001b[1;32m 1283\u001b[0m lchunk_coords, lchunk_selection, lout_selection \u001b[39m=\u001b[39m \u001b[39mzip\u001b[39m(\u001b[39m*\u001b[39mindexer)\n\u001b[0;32m-> 1284\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_chunk_getitems(\n\u001b[1;32m 1285\u001b[0m lchunk_coords, lchunk_selection, out, lout_selection,\n\u001b[1;32m 1286\u001b[0m drop_axes\u001b[39m=\u001b[39;49mindexer\u001b[39m.\u001b[39;49mdrop_axes, fields\u001b[39m=\u001b[39;49mfields\n\u001b[1;32m 1287\u001b[0m )\n\u001b[1;32m 1288\u001b[0m \u001b[39mif\u001b[39;00m out\u001b[39m.\u001b[39mshape:\n\u001b[1;32m 1289\u001b[0m \u001b[39mreturn\u001b[39;00m out\n", 291 | "File \u001b[0;32m~/miniforge3/envs/pytorch-zarr-loader/lib/python3.11/site-packages/zarr/core.py:2032\u001b[0m, in \u001b[0;36mArray._chunk_getitems\u001b[0;34m(self, lchunk_coords, lchunk_selection, out, lout_selection, drop_axes, fields)\u001b[0m\n\u001b[1;32m 2030\u001b[0m \u001b[39mfor\u001b[39;00m ckey, chunk_select, out_select \u001b[39min\u001b[39;00m \u001b[39mzip\u001b[39m(ckeys, lchunk_selection, lout_selection):\n\u001b[1;32m 2031\u001b[0m \u001b[39mif\u001b[39;00m ckey \u001b[39min\u001b[39;00m cdatas:\n\u001b[0;32m-> 2032\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_process_chunk(\n\u001b[1;32m 2033\u001b[0m out,\n\u001b[1;32m 2034\u001b[0m cdatas[ckey],\n\u001b[1;32m 
2035\u001b[0m chunk_select,\n\u001b[1;32m 2036\u001b[0m drop_axes,\n\u001b[1;32m 2037\u001b[0m out_is_ndarray,\n\u001b[1;32m 2038\u001b[0m fields,\n\u001b[1;32m 2039\u001b[0m out_select,\n\u001b[1;32m 2040\u001b[0m partial_read_decode\u001b[39m=\u001b[39;49mpartial_read_decode,\n\u001b[1;32m 2041\u001b[0m )\n\u001b[1;32m 2042\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 2043\u001b[0m \u001b[39m# check exception type\u001b[39;00m\n\u001b[1;32m 2044\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_fill_value \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n", 292 | "File \u001b[0;32m~/miniforge3/envs/pytorch-zarr-loader/lib/python3.11/site-packages/zarr/core.py:1946\u001b[0m, in \u001b[0;36mArray._process_chunk\u001b[0;34m(self, out, cdata, chunk_selection, drop_axes, out_is_ndarray, fields, out_selection, partial_read_decode)\u001b[0m\n\u001b[1;32m 1944\u001b[0m \u001b[39mexcept\u001b[39;00m ArrayIndexError:\n\u001b[1;32m 1945\u001b[0m cdata \u001b[39m=\u001b[39m cdata\u001b[39m.\u001b[39mread_full()\n\u001b[0;32m-> 1946\u001b[0m chunk \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_decode_chunk(cdata)\n\u001b[1;32m 1948\u001b[0m \u001b[39m# select data from chunk\u001b[39;00m\n\u001b[1;32m 1949\u001b[0m \u001b[39mif\u001b[39;00m fields:\n", 293 | "File \u001b[0;32m~/miniforge3/envs/pytorch-zarr-loader/lib/python3.11/site-packages/zarr/core.py:2202\u001b[0m, in \u001b[0;36mArray._decode_chunk\u001b[0;34m(self, cdata, start, nitems, expected_shape)\u001b[0m\n\u001b[1;32m 2200\u001b[0m chunk \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_compressor\u001b[39m.\u001b[39mdecode_partial(cdata, start, nitems)\n\u001b[1;32m 2201\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m-> 2202\u001b[0m chunk \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_compressor\u001b[39m.\u001b[39;49mdecode(cdata)\n\u001b[1;32m 2203\u001b[0m 
\u001b[39melse\u001b[39;00m:\n\u001b[1;32m 2204\u001b[0m chunk \u001b[39m=\u001b[39m cdata\n", 294 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 295 | ] 296 | } 297 | ], 298 | "source": [ 299 | "%%time\n", 300 | "# around 35 minutes?\n", 301 | "\n", 302 | "import numpy as np\n", 303 | "import xarray as xr\n", 304 | "from tqdm import tqdm\n", 305 | "from torch.utils.data import Dataset, DataLoader\n", 306 | "import torch\n", 307 | "\n", 308 | "### Define the Dataset\n", 309 | "class ZarrRData(Dataset):\n", 310 | " def __init__(self, path):\n", 311 | " self.data = zarr.open(path, mode='r')\n", 312 | "\n", 313 | " def __len__(self):\n", 314 | " return self.data.shape[-1]\n", 315 | "\n", 316 | " def __getitem__(self, idx):\n", 317 | " image_npy = self.data[..., idx]\n", 318 | " image = torch.as_tensor(image_npy, dtype = torch.float)\n", 319 | " # return image\n", 320 | " return image\n", 321 | "\n", 322 | "data_path = 'data/huge.zarr/0'\n", 323 | "train_data = ZarrRData(data_path)\n", 324 | "print(train_data.__getitem__(0).shape)\n", 325 | "\n", 326 | "### Define and test the Dataloader. 
This will stall for num_workers > 0 and prefetch_factor > 0.\n", 327 | "train_dataloader = DataLoader(train_data, batch_size= 32, num_workers = 0, prefetch_factor=None)\n", 328 | "for X in tqdm(train_dataloader):\n", 329 | " pass" 330 | ] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": 69, 335 | "metadata": {}, 336 | "outputs": [ 337 | { 338 | "data": { 339 | "text/plain": [ 340 | "35" 341 | ] 342 | }, 343 | "execution_count": 69, 344 | "metadata": {}, 345 | "output_type": "execute_result" 346 | } 347 | ], 348 | "source": [ 349 | "5*7" 350 | ] 351 | } 352 | ], 353 | "metadata": { 354 | "kernelspec": { 355 | "display_name": "pytorch-zarr-loader", 356 | "language": "python", 357 | "name": "python3" 358 | }, 359 | "language_info": { 360 | "codemirror_mode": { 361 | "name": "ipython", 362 | "version": 3 363 | }, 364 | "file_extension": ".py", 365 | "mimetype": "text/x-python", 366 | "name": "python", 367 | "nbconvert_exporter": "python", 368 | "pygments_lexer": "ipython3", 369 | "version": "3.11.6" 370 | } 371 | }, 372 | "nbformat": 4, 373 | "nbformat_minor": 2 374 | } 375 | --------------------------------------------------------------------------------