├── LICENSE
├── README.md
├── examples
│   ├── configs
│   │   └── example.yaml
│   └── image_simple.py
├── requirements_pt1.txt
├── requirements_pt2.txt
├── sdata
│   ├── __init__.py
│   ├── custom_datapipes.py
│   ├── datapipeline.py
│   ├── dataset.py
│   ├── dummy.py
│   ├── filters
│   │   ├── __init__.py
│   │   ├── base.py
│   │   └── metadata_filters.py
│   └── mappers
│       ├── __init__.py
│       ├── base.py
│       ├── batched_mappers.py
│       └── sample_mappers.py
└── setup.py

/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2023 StabilityAI
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # datapipelines
 2 | 
 3 | 
 4 | Iterable datapipelines for pytorch training.
 5 | 
 6 | The functions `sdata.create_dataset()` and `sdata.create_loader()` provide interfaces for your pytorch training code, where the former returns
 7 | a dataset and the latter a wrapper around a pytorch dataloader.
 8 | 
 9 | A dataset as returned by `sdata.create_dataset()` consists of 5 main modules which should be defined in a yaml-config (a config-style sketch follows the list):
10 | 1. A base [datapipeline](./sdata/datapipeline.py#L306), which reads data as tar files from the local filesystem and assembles them into samples. Each sample comes as a python-dict.
11 | 2. A list of [preprocessors](sdata/dataset.py#L129) which can be used either to transform the entries of a sample or to filter out unsuitable samples. The former kind are called `mappers`, the latter `filters`. This repository provides a basic set of [mappers](sdata/mappers) and [filters](sdata/filters) covering common (not too application-specific) data transforms and filters.
12 | 3. A list of [decoders](sdata/dataset.py#L127) whose elements can either be a string matching one of the predefined webdataset [image decoders](https://github.com/webdataset/webdataset/blob/039d74319ae55e5696dcef89829be9671802cf70/webdataset/autodecode.py#L238) or a custom decoder (in the config-style) for handling more specific needs. Note that decoding will be skipped altogether when setting `decoders=None` (or, in config-style yaml, `decoders: null`).
13 | 4. A list of [postprocessors](sdata/dataset.py#L130) which are used to filter or transform the data after it has been decoded and should again be either `mappers` or `filters`.
14 | 5. `error_handler`: A [webdataset-style function](https://github.com/webdataset/webdataset/blob/main/webdataset/handlers.py) for handling any errors which occur in the `datapipeline`, `preprocessors`, `decoders` or `postprocessors`.
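
The same five modules can be sketched as a plain Python/OmegaConf structure instead of yaml. This is illustrative only: the shard path, the sample keys (`jpg`, `json`) and the parameter values are placeholders, while the argument names and the `target`/`params` layout match what `sdata.create_dataset()` expects.

```python
# Illustrative only: mirrors the structure consumed by sdata.create_dataset().
# Paths, keys and parameter values are placeholders, not a tested configuration.
dataset_config = {
    "urls": ["/path/to/shards"],                 # 1. base datapipeline: where the tar shards live
    "pipeline_config": {
        "shardshuffle": 10000,                   #    shard-level shuffle buffer
        "sample_shuffle": 1000,                  #    sample-level shuffle buffer
    },
    "preprocessors": [                           # 2. mappers/filters applied before decoding
        {"target": "sdata.filters.SimpleKeyFilter",
         "params": {"keys": ["jpg", "json"]}},
    ],
    "decoders": ["pil"],                         # 3. predefined webdataset image decoder, by name
    "postprocessors": [                          # 4. mappers/filters applied after decoding
        {"target": "sdata.mappers.TorchVisionImageTransforms",
         "params": {"key": "jpg",
                    "transforms": [{"target": "torchvision.transforms.ToTensor"}]}},
    ],
    "error_handler": {"target": "sdata.warn_and_continue"},  # 5. error handling strategy
}
```

Passing this dict (or the equivalent yaml) to `sdata.create_dataset(**dataset_config)` builds the iterable dataset described above.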
15 | 
16 | A wrapper around a pytorch dataloader, which can be plugged into your training, is returned by [`sdata.create_loader()`](sdata/dataset.py#L51). You can pass the dataset either as an `IterableDataset` as returned by `sdata.create_dataset()` or via the config which would instantiate this dataset. Apart from the known `batch_size`, `num_workers`, `partial` and `collation_fn` parameters for pytorch dataloaders, the function can be configured via the following arguments; a minimal end-to-end sketch is shown at the end of the Examples section.
17 | 
18 | 1. `batched_transforms`: a list of batched `mappers` and `filters` which transform an entire training batch before it is yielded by the dataloader, defined in the same style as the `preprocessors` and `postprocessors` above.
19 | 2. `loader_kwargs`: additional keyword arguments for the dataloader (such as `prefetch_factor`, ...).
20 | 3. `error_handler`: A [webdataset-style function](https://github.com/webdataset/webdataset/blob/main/webdataset/handlers.py) for handling any errors which occur in the `batched_transforms`.
21 | 
22 | 
23 | ## Examples
24 | 
25 | The configs in `examples/configs/` are the best starting point for the following examples; they show how the pieces described above fit together.
26 | 
27 | For a simple example, see [`examples/image_simple.py`](examples/image_simple.py) together with its config [`examples/configs/example.yaml`](examples/configs/example.yaml).
28 | 
29 | **NOTE:** You have to provide your dataset as tar shards following the [webdataset format](https://github.com/webdataset/webdataset). To find the parts which have to be adapted, search for comments containing `USER:` in the respective config.
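
To make the interplay of `create_dataset()` and `create_loader()` concrete, here is a condensed sketch of the example above with the config built inline (the shard path is a placeholder and the snippet assumes webdataset-format shards already exist):

```python
# Condensed variant of examples/image_simple.py; the shard path is a placeholder.
from omegaconf import OmegaConf
from sdata import create_dataset, create_loader

config = OmegaConf.create({
    "dataset": {
        "urls": ["/path/to/shards"],   # USER: point this at your tar shards
        "decoders": ["pil"],
    },
    "loader": {
        "batch_size": 4,
        "num_workers": 0,
        "loader_kwargs": {},           # e.g. {"prefetch_factor": 2} when num_workers > 0
    },
})

datapipeline = create_dataset(**config.dataset)
loader = create_loader(datapipeline, **config.loader)

batch = next(iter(loader))             # a dict of batched tensors/lists, keys preserved
print({k: type(v) for k, v in batch.items()})
```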
30 | 31 | ## Installation 32 | 33 | ### Pytorch 2 and later 34 | 35 | ```bash 36 | python3 -m venv .pt2 37 | source .pt2/bin/activate 38 | pip3 install wheel 39 | pip3 install -r requirements_pt2.txt 40 | 41 | ``` 42 | 43 | ### Pytorch 1.13 44 | 45 | ```bash 46 | python3 -m venv .pt1 47 | source .pt1/bin/activate 48 | pip3 install wheel 49 | pip3 install -r requirements_pt1.txt 50 | 51 | ``` 52 | 53 | -------------------------------------------------------------------------------- /examples/configs/example.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | urls: 3 | # USER: adapt this path the root of your custom dataset 4 | - "/path/to/data" 5 | pipeline_config: 6 | shardshuffle: 10000 7 | sample_shuffle: 1000 # USER: you might wanna adapt depending on your available RAM 8 | 9 | decoders: 10 | - "pil" 11 | 12 | postprocessors: 13 | - target: sdata.mappers.TorchVisionImageTransforms 14 | params: 15 | key: 'jpg' # USER: you might wanna adapt this for your custom dataset 16 | transforms: 17 | - target: torchvision.transforms.Resize 18 | params: 19 | size: 256 20 | interpolation: 3 21 | - target: torchvision.transforms.ToTensor 22 | - target: sdata.mappers.Rescaler 23 | 24 | - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare 25 | params: 26 | h_key: height # USER: you might wanna adapt this for your custom dataset 27 | w_key: width # USER: you might wanna adapt this for your custom dataset 28 | 29 | loader: 30 | batch_size: 64 31 | num_workers: 6 -------------------------------------------------------------------------------- /examples/image_simple.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | 4 | import torch 5 | from omegaconf import OmegaConf 6 | import numpy as np 7 | 8 | from sdata import create_dataset, create_loader 9 | 10 | num_it = 100 11 | max_vis = 3 12 | 13 | if __name__ == "__main__": 14 | filedir = os.path.realpath(os.path.dirname(__file__)) 15 | config_path = os.path.join(filedir, "configs", "example.yaml") 16 | config = OmegaConf.load(config_path) 17 | 18 | # build config 19 | datapipeline = create_dataset(**config.dataset) 20 | 21 | # build loader 22 | loader = create_loader(datapipeline, **config.loader) 23 | 24 | print(f"Yielding {num_it} batches") 25 | 26 | for i, batch in enumerate(loader): 27 | if i >= num_it: 28 | break 29 | 30 | for key in batch: 31 | if isinstance(batch[key], (torch.Tensor, np.ndarray)): 32 | print(key, batch[key].shape) 33 | elif isinstance(batch[key], (List)): 34 | print(key) 35 | print(batch[key][:max_vis]) 36 | 37 | print("ciao") 38 | -------------------------------------------------------------------------------- /requirements_pt1.txt: -------------------------------------------------------------------------------- 1 | omegaconf 2 | einops 3 | fire 4 | tqdm 5 | webdataset>=0.2.33 6 | --extra-index-url https://download.pytorch.org/whl/cu117 7 | torch==1.13.1 8 | torchdata==0.5.1 9 | torchaudio==0.13.1 10 | torchvision==0.14.1+cu117 11 | torchmetrics 12 | opencv-python==4.6.0.66 13 | -e . -------------------------------------------------------------------------------- /requirements_pt2.txt: -------------------------------------------------------------------------------- 1 | omegaconf 2 | einops 3 | fire 4 | tqdm 5 | webdataset>=0.2.33 6 | torch>=2.0.1 7 | torchaudio>=2.0.2 8 | torchdata==0.6.1 9 | torchmetrics 10 | torchvision>=0.15.2 11 | opencv-python==4.6.0.66 12 | -e . 
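
The example config above expects the data as webdataset-style tar shards. As a rough sketch of how such shards could be produced (not part of this repository; paths, keys and metadata fields are placeholders), webdataset's `ShardWriter` can be used:

```python
# Hypothetical shard-writing helper, not part of this repository: packs
# image/metadata pairs into webdataset-style tar shards that the example
# config can read. Paths, keys and metadata fields are placeholders.
import json
import webdataset as wds

samples = [
    ("/path/to/images/cat.jpg",
     {"caption": "a cat", "original_width": 640, "original_height": 480}),
]

with wds.ShardWriter("/path/to/shards/%06d.tar", maxcount=10000) as sink:
    for idx, (image_path, metadata) in enumerate(samples):
        with open(image_path, "rb") as f:
            image_bytes = f.read()
        sink.write({
            "__key__": f"{idx:09d}",               # files sharing this key form one sample
            "jpg": image_bytes,                    # decoded later by the "pil" decoder
            "json": json.dumps(metadata).encode(),
        })
```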
-------------------------------------------------------------------------------- /sdata/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import create_dataset, create_loader 2 | from .datapipeline import warn_and_continue 3 | from . import mappers, filters 4 | from .dummy import create_dummy_dataset 5 | 6 | 7 | __all__ = [ 8 | "filters", 9 | "mappers", 10 | "create_loader", 11 | "create_dataset", 12 | "warn_and_continue", 13 | "create_dummy_dataset", 14 | ] 15 | -------------------------------------------------------------------------------- /sdata/custom_datapipes.py: -------------------------------------------------------------------------------- 1 | # Custom datapipes, partly copied from release 0.6.0 to also support 0.5.1, shout out to pytorch and torchdata 2 | import os 3 | import tarfile 4 | import time 5 | import warnings 6 | from io import BufferedIOBase, RawIOBase 7 | import re 8 | from typing import ( 9 | Iterator, 10 | List, 11 | Tuple, 12 | Optional, 13 | cast, 14 | IO, 15 | Callable, 16 | Dict, 17 | Union, 18 | ) 19 | import random 20 | import gc 21 | from collections import deque 22 | 23 | 24 | import numpy as np 25 | import torch 26 | from torchdata.datapipes.iter import IterDataPipe 27 | from torchdata.datapipes.iter import TarArchiveLoader 28 | from torchdata.datapipes.utils.common import validate_pathname_binary_tuple 29 | from torch.utils.data.datapipes.utils.common import StreamWrapper 30 | import webdataset as wds 31 | 32 | 33 | 34 | def _is_stream_handle(data): 35 | obj_to_check = data.file_obj if isinstance(data, StreamWrapper) else data 36 | return isinstance(obj_to_check, (BufferedIOBase, RawIOBase)) 37 | 38 | 39 | def _shard_expand(s: str) -> List[str]: 40 | expansion = r"\{[0-9]+\.\.[0-9]+\}" 41 | m = re.search(expansion, s) 42 | if not m: 43 | return [s] 44 | prefix = s[: m.start()] 45 | rest = _shard_expand(s[m.end() :]) 46 | rng = s[m.start() + 1 : m.end() - 1] 47 | lohi = rng.split("..") 48 | if len(lohi[0]) == len(lohi[1]) and lohi[0].startswith("0"): 49 | fmt = "{prefix}{i:0>{l}d}{r}" 50 | elif len(lohi[0]) <= len(lohi[1]): 51 | if lohi[0].startswith("0") and lohi[0] != "0": 52 | raise ValueError( 53 | "shard_expand: low bound must not start with 0 if low bound is shorter" 54 | ) 55 | fmt = "{prefix}{i}{r}" 56 | else: 57 | raise ValueError("shard_expand: low bound must be shorter than high bound") 58 | lo, hi = (int(x) for x in lohi) 59 | if lo >= hi: 60 | raise ValueError(f"shard_expand: bad range in in shard spec {s}.") 61 | result = [] 62 | for i in range(lo, hi + 1): 63 | for r in rest: 64 | expanded: str = fmt.format(prefix=prefix, i=i, r=r, l=len(lohi[1])) 65 | result.append(expanded) 66 | return result 67 | 68 | 69 | class CustomShardExpanderIterDataPipe(IterDataPipe[str]): 70 | r""" 71 | Expands incoming shard strings into shards. 72 | 73 | Sharded data files are named using shell-like brace notation. For example, 74 | an ImageNet dataset sharded into 1200 shards and stored on a web server 75 | might be named `imagenet-{000000..001199}.tar`. 76 | 77 | Note that shard names can be expanded without any server transactions; 78 | this makes `shard_expand` reproducible and storage system independent 79 | (unlike :class `.FileLister` etc.). 80 | 81 | Args: 82 | source_datapipe: a DataPipe yielding a stream of pairs 83 | 84 | Returns: 85 | a DataPipe yielding a stream of expanded pathnames. 
86 | 87 | Example: 88 | from torchdata.datapipes.iter import IterableWrapper 89 | >>> source_dp = IterableWrapper(["ds-{00..05}.tar"]) 90 | >>> expand_dp = source_dp.shard_expand() 91 | >>> list(expand_dp) 92 | ['ds-00.tar', 'ds-01.tar', 'ds-02.tar', 'ds-03.tar', 'ds-04.tar', 'ds-05.tar'] 93 | >>> source_dp = IterableWrapper(["imgs_{00..05}.tar", "labels_{00..05}.tar"]) 94 | >>> expand_dp = source_dp.shard_expand() 95 | >>> list(expand_dp) 96 | ['imgs_00.tar', 'imgs_01.tar', 'imgs_02.tar', 'labels_00.tar', 'labels_01.tar', 'labels_02.tar'] 97 | """ 98 | 99 | def __init__(self, source_datapipe: IterDataPipe[str]) -> None: 100 | super().__init__() 101 | self.source_datapipe: IterDataPipe[str] = source_datapipe 102 | 103 | def __iter__(self) -> Iterator[str]: 104 | for path in self.source_datapipe: 105 | yield from _shard_expand(path) 106 | 107 | 108 | class SeedSetter(IterDataPipe): 109 | """ 110 | Resets the seed on call of __iter__ (invoked in the reset() method 111 | """ 112 | 113 | def __init__(self, datapipe, debug=False): 114 | super().__init__() 115 | self.datapipe = datapipe 116 | self.is_init = False 117 | self.debug = False 118 | 119 | # # def reset(self): 120 | def reset(self): 121 | # this will be called whenever __iter__ is invoked again (this should be kept in mind for shuffling 122 | if not self.is_init: 123 | # we only wanna do this once 124 | self.is_init = True 125 | 126 | worker_info = torch.utils.data.get_worker_info() 127 | 128 | if worker_info: 129 | worker_id = worker_info.id 130 | newseed = np.random.get_state()[1][0] + worker_id 131 | if self.debug: 132 | print(f"Worker #{worker_id} reseeding with {newseed}") 133 | np.random.seed(newseed) 134 | torch.random.manual_seed(newseed) 135 | random.seed(newseed) 136 | 137 | def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]: 138 | # self.set_seed() 139 | # print(f'seed in worker init: {seed}') 140 | for data in self.datapipe: 141 | yield data 142 | 143 | 144 | class SplitByWorker(IterDataPipe): 145 | """ 146 | distributed data across workers to mimic behavior of shard splitting in webdataset 147 | """ 148 | 149 | def __init__(self, datapipe, debug: bool = False): 150 | super().__init__() 151 | self.datapipe = datapipe 152 | # self.drop_last = drop_last 153 | self.worker_id = 0 154 | self.num_workers = 1 155 | self.debug = debug 156 | self.do_print = True 157 | 158 | def reset(self): 159 | # this will be called whenever __iter__ is invoked again (this should be kept in mind for shuffling 160 | worker_info = torch.utils.data.get_worker_info() 161 | if self.debug: 162 | print(f"worker {worker_info} configured") 163 | if worker_info: 164 | self.worker_id = worker_info.id 165 | self.num_workers = worker_info.num_workers 166 | 167 | def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]: 168 | for i, data in enumerate(self.datapipe): 169 | # # avoid hanging due to uneven number of shards per worker 170 | if i % self.num_workers == self.worker_id: 171 | if self.debug and self.do_print: 172 | print(f"data worker {self.worker_id} got first shard {data}") 173 | self.do_print = False 174 | yield data 175 | 176 | 177 | class PrefixResampler(IterDataPipe): 178 | def __init__( 179 | self, 180 | datapipe: IterDataPipe[Tuple], 181 | prefixes: List[str], 182 | ps: List[float] = None, 183 | is_braceexpand: bool = False, 184 | buffersize_per_prefix: Union[int, Dict] = 100000, 185 | custom_seeding: bool=True, 186 | debug: bool = True, 187 | handler: Callable = wds.reraise_exception, 188 | ): 189 | super().__init__() 190 | 
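        # PrefixResampler mixes shards from several dataset prefixes according to
        # the probabilities in `ps`: incoming (key, content) tuples from `datapipe`
        # are buffered in one bounded deque per prefix (overflowing elements are
        # dropped with a one-time warning); in __iter__ a target prefix is drawn via
        # np.random.choice, and once a URL for that prefix arrives, the oldest
        # buffered URL of that prefix is yielded (FIFO) and a new target is drawn.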
self.source = datapipe 191 | if is_braceexpand: 192 | # only dirs 193 | prefixes = [os.path.split(p.split("{")[0])[0] for p in prefixes] 194 | 195 | assert len(set(prefixes)) == len(prefixes), "Prefixes should be unique" 196 | self.ps = {k: p for k, p in zip(prefixes, ps)} 197 | 198 | print(f"{self.__class__.__name__} got the following prefixes: {prefixes}") 199 | 200 | if isinstance(buffersize_per_prefix, int): 201 | buffersize_per_prefix = {pref: buffersize_per_prefix for pref in prefixes} 202 | 203 | assert len(buffersize_per_prefix) == len( 204 | prefixes 205 | ), f"Buffersize per prefix (len={len(buffersize_per_prefix)}) has to have the same length than prefixes (len={len(prefixes)})" 206 | self.url_buffer = { 207 | prf: deque(maxlen=buffersize_per_prefix[prf]) for prf in prefixes 208 | } 209 | self.warn_once = {prf: True for prf in self.url_buffer} 210 | 211 | sum_ = sum(list(self.ps.values())) 212 | self.ps = {k: self.ps[k] / sum_ for k in self.ps} 213 | 214 | print( 215 | f"Got the following (prob, prefix) pairs for {len(self.ps)} prefixes {[(k, p) for k, p in self.ps.items()]}" 216 | ) 217 | 218 | self.handler = handler 219 | self.is_init = not custom_seeding 220 | self.debug = debug 221 | 222 | assert np.isclose( 223 | sum(self.ps.values()), 1.0 224 | ), "Probabilities must have the same length than prefix and must sum up to 1" 225 | 226 | def reset(self): 227 | if self.debug: 228 | 229 | worker_info = torch.utils.data.get_worker_info() 230 | 231 | if worker_info: 232 | worker_id = worker_info.id 233 | print(f"Worker #{worker_id} has seed {np.random.get_state()[1][0]}") 234 | 235 | def __iter__(self): 236 | keep_target = False 237 | target_prefix=None 238 | for url in self.source: 239 | try: 240 | assert isinstance( 241 | url, (tuple, Tuple) 242 | ), f"source datapipe of {self.__class__.__name__} should yield tuples" 243 | key, content = url 244 | if not keep_target: 245 | keep_target = True 246 | target_prefix = np.random.choice( 247 | list(self.ps), 1, p=list(self.ps.values()) 248 | ).item() 249 | current_prefix = list(filter(lambda x: key.startswith(x), self.ps)) 250 | if not len(current_prefix) == 1: 251 | raise ValueError( 252 | f"the received prefix is non-unique and matches " 253 | f"all of {current_prefix}, aborting" 254 | ) 255 | current_prefix = current_prefix[0] 256 | 257 | if ( 258 | len(self.url_buffer[current_prefix]) 259 | >= self.url_buffer[current_prefix].maxlen 260 | ): 261 | maxsize = self.url_buffer[current_prefix].maxlen 262 | if self.warn_once[current_prefix]: 263 | self.warn_once[current_prefix] = False 264 | warnings.warn( 265 | f"buffer size for prefix {current_prefix} in {self.__class__.__name__} exceeds its max buffer size {maxsize}," 266 | f"thus discarding this element. " 267 | f"Is this intended?" 
268 | ) 269 | else: 270 | self.url_buffer[current_prefix].append(url) 271 | 272 | if current_prefix == target_prefix: 273 | keep_target = False 274 | # FIFO 275 | out = self.url_buffer[target_prefix].popleft()[1] 276 | yield out 277 | except Exception as e: 278 | if self.handler(e): 279 | pass 280 | else: 281 | raise e 282 | 283 | 284 | class Dataset2SamplesConverter(IterDataPipe): 285 | def __init__( 286 | self, datapipe: IterDataPipe, handler: Callable = wds.reraise_exception 287 | ): 288 | super().__init__() 289 | self.datapipe = datapipe 290 | self.handler = handler 291 | 292 | def __iter__(self) -> Iterator[Dict]: 293 | try: 294 | for sample in self.datapipe: 295 | try: 296 | # dict-style sample from tuple 297 | key = os.path.split(sample[0][0])[-1].split(".")[0] 298 | url = os.path.split(sample[0][0])[0] 299 | out = {} 300 | for s in sample: 301 | key_ = ( 302 | s[0].split(key)[-1][1:] 303 | if s[0].split(key)[-1].startswith(".") 304 | else s[0].split(key)[-1] 305 | ) 306 | data = s[1] 307 | 308 | if _is_stream_handle(data): 309 | ds = data 310 | # The behavior of .read can differ between streams (e.g. HTTPResponse), hence this is used instead 311 | data = b"".join(data) 312 | ds.close() 313 | del ds 314 | 315 | out[key_] = data 316 | del data 317 | sample = out 318 | del out 319 | sample["__key__"] = key 320 | sample["__url__"] = url 321 | 322 | yield sample 323 | 324 | except Exception as exn: 325 | if self.handler(exn): 326 | gc.collect() 327 | continue 328 | else: 329 | break 330 | 331 | except Exception as e: 332 | if self.handler(e): 333 | print(f"Catched exception in {self.__class__.__name__}: ", e) 334 | else: 335 | print(f"Catched exception in {self.__class__.__name__}: ", e) 336 | raise e 337 | 338 | 339 | class TarArchiveLoaderAndCloser(TarArchiveLoader): 340 | def __init__(self, handler: Callable = wds.reraise_exception, *args, **kwargs): 341 | super().__init__(*args, **kwargs) 342 | self.handler = handler 343 | self.times = None 344 | self.profile = False 345 | 346 | def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]: 347 | for data in self.datapipe: 348 | start = time.perf_counter() 349 | validate_pathname_binary_tuple(data) 350 | pathname, data_stream = data 351 | try: 352 | if isinstance(data_stream, StreamWrapper) and isinstance( 353 | data_stream.file_obj, tarfile.TarFile 354 | ): 355 | tar = data_stream.file_obj 356 | else: 357 | reading_mode = ( 358 | self.mode 359 | if hasattr(data_stream, "seekable") and data_stream.seekable() 360 | else self.mode.replace(":", "|") 361 | ) 362 | # typing.cast is used here to silence mypy's type checker 363 | tar = tarfile.open( 364 | fileobj=cast(Optional[IO[bytes]], data_stream), 365 | mode=reading_mode, 366 | ) 367 | if self.profile: 368 | self.open_times.append(time.perf_counter() - start) 369 | try: 370 | for tarinfo in tar: 371 | start = time.perf_counter() 372 | if not tarinfo.isfile(): 373 | continue 374 | extracted_fobj = tar.extractfile(tarinfo) 375 | if extracted_fobj is None: 376 | warnings.warn( 377 | f"failed to extract file {tarinfo.name} from source tarfile {pathname}" 378 | ) 379 | raise tarfile.ExtractError 380 | inner_pathname = os.path.normpath( 381 | os.path.join(pathname, tarinfo.name) 382 | ) 383 | sw = StreamWrapper(extracted_fobj, data_stream, name=inner_pathname) # type: ignore[misc] 384 | 385 | if self.profile: 386 | self.extract_times.append(time.perf_counter() - start) 387 | yield inner_pathname, sw 388 | # sw.autoclose() 389 | del sw 390 | # close tarfile after it's been exceeded 391 | 
finally: 392 | tar.close() 393 | del tar 394 | del tarinfo 395 | 396 | if _is_stream_handle(data_stream): 397 | data_stream.autoclose() 398 | del data_stream 399 | gc.collect() 400 | except Exception as e: 401 | warnings.warn( 402 | f"Unable to extract files from corrupted tarfile stream {pathname} due to: {e}, abort!" 403 | ) 404 | if self.handler(e): 405 | if hasattr(e, "args") and len(e.args) > 0: 406 | e.args = (e.args[0] + " @ " + str(pathname),) + e.args[1:] 407 | else: 408 | raise e 409 | 410 | 411 | 412 | -------------------------------------------------------------------------------- /sdata/datapipeline.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | from typing import Callable, Optional, Union, List, Any, Dict 4 | import re 5 | from packaging import version 6 | from operator import itemgetter 7 | import functools 8 | import time 9 | import warnings 10 | import threading 11 | 12 | 13 | from omegaconf import ListConfig, DictConfig 14 | import torchdata 15 | import torch.distributed as dist 16 | 17 | from torch.utils.data.datapipes.iter import IterableWrapper, FileOpener 18 | from torchdata.datapipes.iter import IterKeyZipper 19 | import webdataset as wds 20 | 21 | from .custom_datapipes import ( 22 | CustomShardExpanderIterDataPipe, 23 | SplitByWorker, 24 | PrefixResampler, 25 | TarArchiveLoaderAndCloser, 26 | SeedSetter, 27 | Dataset2SamplesConverter, 28 | _is_stream_handle, 29 | ) 30 | 31 | class TimeoutError(Exception): 32 | pass 33 | 34 | 35 | def timeout_wrapper(func): 36 | def wrapper(*args, **kwargs): 37 | if ( 38 | "SDATA_MAX_EXC_TIME" not in os.environ 39 | or not os.environ["SDATA_MAX_EXC_TIME"] 40 | ): 41 | res = func(*args, **kwargs) 42 | del args 43 | del kwargs 44 | return res 45 | 46 | timeout = float(os.environ["SDATA_MAX_EXC_TIME"]) 47 | 48 | result = [None] 49 | exception = [None] 50 | event = threading.Event() 51 | 52 | def wrapped_func(): 53 | try: 54 | result[0] = func(*args, **kwargs) 55 | except Exception as e: 56 | exception[0] = e 57 | finally: 58 | event.set() 59 | 60 | thread = threading.Thread(target=wrapped_func) 61 | thread.start() 62 | event.wait(timeout) 63 | 64 | if not event.is_set(): 65 | raise TimeoutError( 66 | f"Function call timed out (longer than {timeout } secs)." 
67 | ) 68 | 69 | thread.join() 70 | 71 | if exception[0] is not None: 72 | raise exception[0] 73 | 74 | del thread 75 | del exception 76 | del wrapped_func 77 | del event 78 | del args 79 | del kwargs 80 | 81 | return result[0] 82 | 83 | return wrapper 84 | 85 | 86 | def warn_and_continue(exn): 87 | """Call in an exception handler to ignore any exception, isssue a warning, and continue.""" 88 | print(exn) 89 | warnings.warn(repr(exn)) 90 | time.sleep(0.05) 91 | return True 92 | 93 | 94 | def time_measure(name: str = "function"): 95 | def wrapper(fn): 96 | def measure_time(*args, **kwargs): 97 | start = time.perf_counter() 98 | r = fn(*args, **kwargs) 99 | end = time.perf_counter() 100 | if "SDATA_PROFILE" in os.environ and os.environ["SDATA_PROFILE"]: 101 | if r is None: 102 | return r 103 | try: 104 | if isinstance(r, Dict): 105 | r[f"{name}-time"] = end - start 106 | else: 107 | args[1][f"{name}-time"] = end - start 108 | 109 | except Exception as e: 110 | print(f"Exception raised when measuring time for {name}") 111 | raise e 112 | 113 | del args 114 | del kwargs 115 | 116 | return r 117 | 118 | return measure_time 119 | 120 | return wrapper 121 | 122 | 123 | def instantiate(config: Union[Dict, DictConfig]) -> Any: 124 | if not "target" in config: 125 | if config == "__is_first_stage__": 126 | return None 127 | elif config == "__is_unconditional__": 128 | return None 129 | raise KeyError("Expected key `target` to instantiate.") 130 | return create_obj(config["target"])(**config.get("params", dict())) 131 | 132 | 133 | def make_callable(config): 134 | return functools.partial( 135 | create_obj(config["target"]), **config.get("params", dict()) 136 | ) 137 | 138 | 139 | def create_obj(string: str, reload: bool = False, invalidate_cache: bool = True) -> Any: 140 | module, cls = string.rsplit(".", 1) 141 | if invalidate_cache: 142 | importlib.invalidate_caches() 143 | if reload: 144 | module_imp = importlib.import_module(module) 145 | importlib.reload(module_imp) 146 | return getattr(importlib.import_module(module, package=None), cls) 147 | 148 | 149 | class KeyPassThroughDecoder(wds.Decoder): 150 | def __init__(self, *args, passthrough_keys=None, **kwargs): 151 | super().__init__(*args, **kwargs) 152 | self.passthrough_keys = passthrough_keys 153 | if self.passthrough_keys is None: 154 | self.passthrough_keys = [] 155 | 156 | def decode1(self, key, data): 157 | # if data is a stream handle, we need to read all the content before decoding 158 | if _is_stream_handle(data): 159 | ds = data 160 | # The behavior of .read can differ between streams (e.g. HTTPResponse), hence this is used instead 161 | data = b"".join(data) 162 | ds.close() 163 | 164 | key = "." + key 165 | for f in self.handlers: 166 | result = f(key, data) 167 | if isinstance(result, wds.autodecode.Continue): 168 | key, data = result.key, result.data 169 | continue 170 | if result is not None: 171 | del data 172 | return result 173 | return data 174 | 175 | @timeout_wrapper 176 | @time_measure(name="KeyPassThroughDecoder") 177 | def decode(self, sample): 178 | """Decode an entire sample. 
179 | 180 | :param sample: the sample, a dictionary of key value pairs 181 | """ 182 | result = {} 183 | assert isinstance(sample, dict), sample 184 | for k, v in list(sample.items()): 185 | if k[0] == "_": 186 | if isinstance(v, bytes): 187 | v = v.decode("utf-8") 188 | result[k] = v 189 | continue 190 | if self.only is not None and k not in self.only: 191 | result[k] = v 192 | continue 193 | assert v is not None 194 | if self.partial: 195 | if isinstance(v, bytes) or k in self.passthrough_keys: 196 | result[k] = self.decode1(k, v) 197 | else: 198 | result[k] = v 199 | else: 200 | assert ( 201 | isinstance(v, bytes) or k in self.passthrough_keys 202 | ), f"key: {k}; passthrough_keys: {self.passthrough_keys}" 203 | result[k] = self.decode1(k, v) 204 | return result 205 | 206 | 207 | def tarfilter(x): 208 | ret = x.endswith(".tar") 209 | del x 210 | return ret 211 | 212 | 213 | def grouper(x): 214 | key = x[0].split("/")[-1].split(".")[0] 215 | del x 216 | return key 217 | 218 | 219 | def tuple_grouper(x): 220 | key = x[0][0].split("/")[-1].split(".")[0] 221 | del x 222 | return key 223 | 224 | 225 | def merge_samples(s1, s2, meta_urls): 226 | s1_files = [os.path.splitext(s[0])[1] for s in s1] 227 | meta_key_list = [mk for mk in meta_urls if mk in s2[0][0]] 228 | if len(meta_key_list) == 0: 229 | raise ValueError( 230 | f"no matching meta key found for the following file(s): {os.path.splitext(s2[0][0])[0]}" 231 | ) 232 | elif len(meta_key_list) > 1: 233 | raise ValueError( 234 | f"More than one matching meta key found for the following file(s): {os.path.splitext(s2[0][0])[0]}" 235 | ) 236 | 237 | meta_key = meta_key_list[0] 238 | outs2 = [ 239 | s 240 | if os.path.splitext(s[0])[1] not in s1_files 241 | else (os.path.splitext(s[0])[0] + meta_key + os.path.splitext(s[0])[1], s[1]) 242 | for s in s2 243 | ] 244 | del s2 245 | return list(s1) + outs2 246 | 247 | 248 | def merge_them(u1, u2): 249 | # concat lists: these lists should contain all tarfiles from the same prefix but 250 | # with different filenames 251 | return u1[1] + [ 252 | u2, 253 | ] 254 | 255 | 256 | def identity(x): 257 | return True 258 | 259 | 260 | def map_to_tuple(x): 261 | return ( 262 | os.path.join(os.path.split(x)[0], os.path.splitext(os.path.split(x)[1])[0]), 263 | [ 264 | x, 265 | ], 266 | ) 267 | 268 | 269 | def filter_with_meta_set(x, meta_set): 270 | return itemgetter(0)(x) in meta_set 271 | 272 | 273 | def get_ref_key(x, suffix): 274 | return os.path.splitext(x.replace("_" + suffix, ""))[0] 275 | 276 | 277 | def list_files_in_datapipe( 278 | urls: Union[List, ListConfig], 279 | is_braceexpand: bool, 280 | tar_sampler: Callable = identity, 281 | ) -> torchdata.datapipes.iter.IterDataPipe: 282 | """ 283 | 284 | :param datapipe: 285 | :param is_braceexpand: 286 | :return: 287 | """ 288 | datapipe = IterableWrapper(urls) 289 | 290 | if version.parse(torchdata.__version__) >= version.parse("0.6.0"): 291 | if is_braceexpand: 292 | datapipe = CustomShardExpanderIterDataPipe(datapipe) 293 | else: 294 | datapipe = datapipe.list_files(recursive=True).filter(tarfilter) 295 | else: 296 | if is_braceexpand: 297 | datapipe = CustomShardExpanderIterDataPipe(datapipe) 298 | else: 299 | datapipe = datapipe.list_files(recursive=True).filter(tarfilter) 300 | 301 | datapipe = datapipe.filter(tar_sampler) 302 | 303 | return datapipe 304 | 305 | 306 | class StableDataPipeline(wds.DataPipeline, wds.compat.FluidInterface): 307 | """ 308 | Central class for reading data from tars on local fs and building samples based on consecutive 
files with the same keys 309 | """ 310 | 311 | def __init__( 312 | self, 313 | urls: Union[List[str], str, ListConfig], 314 | meta_urls: Optional[Union[List[str], str]] = None, 315 | metadata_buffer_size: Union[int, None] = 10000, 316 | repeat: int = None, 317 | shardshuffle: int = 10000, 318 | sample_shuffle: int = 1, 319 | resample_prefixes: bool = False, 320 | prefix_probs: Optional[List[float]] = None, 321 | split_data_by_worker: bool = True, 322 | tar_sampler: Optional[Union[DictConfig, Dict, Callable]] = identity, 323 | handler: Union[Callable, DictConfig] = wds.reraise_exception, 324 | debug: bool = False, 325 | n_shards: int = 100000, 326 | ): 327 | """ 328 | 329 | :param urls: folders to load the shards from, can be a list of different prefoxes for dataset mixing 330 | :param meta_urls: can be used for aligned metadata files stored as tars 331 | :param metadata_buffer_size: 332 | :param repeat: number of repetitions in the training data. Default is None which means looping perpetually. 333 | :param shardshuffle: Shuffle buffer size for shard shuffling. size 1 means no shufflin. Default is 10k. 334 | :param sample_shuffle: Shuffle buffer for sample-level-shuffling. Default is 1 which means no shuffling 335 | :param resample_prefixes: Whether to resample when different prefixes are in the entire dataset. 336 | This can be useful in combination with prefix probs when training on merged datasets of non-equal size. 337 | :param prefix_probs: list containing resampling probabilities for every prefix in `urls` 338 | :param split_data_by_worker: Whether to split shards across worker threads for num_workers > 0 339 | :param handler: handler for handling exceptions as in webdataset 340 | """ 341 | super().__init__() 342 | 343 | if isinstance(urls, (List, ListConfig, list)): 344 | pass 345 | elif isinstance(urls, str): 346 | urls = [urls] 347 | else: 348 | raise TypeError( 349 | "urls need to be path to a S3 prefix or list of paths to more than one prefixes" 350 | ) 351 | 352 | if isinstance(handler, (DictConfig, Dict)): 353 | handler = make_callable(handler) 354 | 355 | 356 | # get some information abt fs where shards live in and the way shards are specified 357 | is_braceexpand = any(["{" in u for u in urls]) 358 | 359 | if is_braceexpand: 360 | brace_expansion = re.compile(r"\{[0-9]+\.\.[0-9]+\}") 361 | assert all(len(re.findall(brace_expansion, u)) == 1 for u in urls), ( 362 | "Specifiying tars in listed prefixes should be consistent. " 363 | "It should be either braceexpand notation or just using some " 364 | "base prefix. If this still fails, you might have some urls with " 365 | "multiple or malformed braceexpands." 
366 | ) 367 | 368 | if isinstance(tar_sampler, (Dict, dict, DictConfig)): 369 | tar_sampler = instantiate(tar_sampler) 370 | 371 | main_datapipe = list_files_in_datapipe( 372 | urls, 373 | is_braceexpand=is_braceexpand, 374 | tar_sampler=tar_sampler, 375 | ).map(fn=map_to_tuple) 376 | 377 | if meta_urls: 378 | print( 379 | f"Zipping together {len(meta_urls)} meta datapipes with the following suffixes {meta_urls} " 380 | f"and adding this to the main datapipes " 381 | ) 382 | 383 | if isinstance(meta_urls, str): 384 | meta_urls = [meta_urls] 385 | 386 | meta_urls_base = [os.path.split(m) for m in urls] 387 | # meta_urls = [[os.path.join(m[0], os.path.splitext(m[1])[0]+f"_{suffix}"+os.path.splitext(m[1])[1]) for m in meta_urls_base] for suffix in meta_urls] 388 | meta_files = [ 389 | [os.path.join(m[0] + f"_{suffix}", m[1]) for m in meta_urls_base] 390 | for suffix in meta_urls 391 | ] 392 | 393 | for suffix, meta_url_collection in zip(meta_urls, meta_files): 394 | # this is the meta data which will be added to the man data 395 | meta_datapipe = list_files_in_datapipe( 396 | meta_url_collection, 397 | is_braceexpand=is_braceexpand, 398 | tar_sampler=tar_sampler, 399 | ) 400 | # filter out non-exisiting shards 401 | meta_set = set([get_ref_key(pth, suffix) for pth in meta_datapipe]) 402 | main_datapipe = main_datapipe.filter( 403 | functools.partial(filter_with_meta_set, meta_set=meta_set) 404 | ) 405 | 406 | # cycle in side branch to avoid exhausting after iterating over the entire dataset 407 | meta_datapipe = meta_datapipe.cycle() 408 | # merging always based on filenames where the metadata shards are expected to have .tar, 409 | # e.g. for a main shard "0000.tar" and an optical flow metadatashard we'd have "0000.tar" for the metadata shard 410 | # and the resulting key would be /path/to/prefix/0000 411 | main_datapipe = IterKeyZipper( 412 | main_datapipe, 413 | ref_datapipe=meta_datapipe, 414 | key_fn=itemgetter(0), 415 | ref_key_fn=functools.partial(get_ref_key, suffix=suffix), 416 | keep_key=True, 417 | merge_fn=merge_them, 418 | buffer_size=metadata_buffer_size, 419 | ) 420 | # main_datapipe = main_datapipe 421 | 422 | # start shuffling accross shards for the first time to mix different datasets 423 | # (can be the same for all workers, just as an additional shuffled initialization) 424 | if shardshuffle > 1 and not resample_prefixes and len(urls) > 1: 425 | # back to datapipes. 
We further apply a map to remove the key, so that the result is the sames than 426 | # for the prefix subsampler 427 | main_datapipe = main_datapipe.shuffle(buffer_size=n_shards).map( 428 | fn=itemgetter(1) 429 | ) 430 | elif resample_prefixes: 431 | main_datapipe = PrefixResampler( 432 | main_datapipe.shuffle(buffer_size=n_shards), 433 | ps=prefix_probs, 434 | prefixes=urls, 435 | is_braceexpand=is_braceexpand, 436 | custom_seeding=split_data_by_worker, 437 | debug=debug 438 | ) 439 | else: 440 | main_datapipe = main_datapipe.map(itemgetter(1)) 441 | 442 | if not resample_prefixes: 443 | shardshuffle = max(shardshuffle, 1) 444 | main_datapipe = main_datapipe.shuffle(buffer_size=shardshuffle) 445 | 446 | main_datapipe = main_datapipe.sharding_filter() 447 | 448 | # after this operation datapipes in the distinct processes contain different tars 449 | if dist.is_available() and dist.is_initialized(): 450 | # after this operation datapipes in the distinct processes contain different tars 451 | 452 | global_rank = dist.get_rank() 453 | world_size = dist.get_world_size() 454 | main_datapipe.apply_sharding(world_size, global_rank) 455 | print("#" * 100) 456 | print(f"distributing shards for worker with global rank {global_rank}") 457 | print("#" * 100) 458 | 459 | else: 460 | print( 461 | f"torch distributed not used, not applying sharding in {self.__class__.__name__}" 462 | ) 463 | 464 | if split_data_by_worker: 465 | print("Distributing shards across the worker threads in every process") 466 | main_datapipe = SplitByWorker( 467 | datapipe=main_datapipe, debug=debug 468 | ) 469 | else: 470 | main_datapipe = SeedSetter(main_datapipe, debug=debug) 471 | 472 | main_datapipe = main_datapipe.cycle(count=repeat) 473 | 474 | # unzip before loading, since here we can be sure that all shards are distributed and shuffled 475 | # aligned with their corresponding metadata shards 476 | meta_len = len(meta_urls) if meta_urls else 0 477 | main_datapipe, *meta_datapipes = main_datapipe.unzip( 478 | sequence_length=meta_len + 1 479 | ) 480 | 481 | 482 | # regular fileopener 483 | main_datapipe = FileOpener(main_datapipe, mode="b") 484 | meta_datapipes = [FileOpener(m, mode="b") for m in meta_datapipes] 485 | 486 | # adapted TarLoader which closes open tarfile handles after exceeding them 487 | # main_datapipe = TarArchiveLoaderAndCloser(datapipe=main_datapipe).groupby(grouper) 488 | # 489 | main_datapipe = TarArchiveLoaderAndCloser( 490 | datapipe=main_datapipe, handler=handler 491 | ).groupby(grouper) 492 | meta_datapipes = [ 493 | TarArchiveLoaderAndCloser(datapipe=m, handler=handler).groupby(grouper) 494 | for m in meta_datapipes 495 | ] 496 | 497 | # zip again, this time we're searching based on the same keys 498 | for meta_dp in meta_datapipes: 499 | # here we da 500 | main_datapipe = IterKeyZipper( 501 | main_datapipe, 502 | ref_datapipe=meta_dp, 503 | key_fn=tuple_grouper, 504 | merge_fn=functools.partial(merge_samples, meta_urls=meta_urls), 505 | buffer_size=metadata_buffer_size, 506 | ) 507 | 508 | if sample_shuffle > 0: 509 | main_datapipe = main_datapipe.shuffle(buffer_size=sample_shuffle) 510 | 511 | main_datapipe = Dataset2SamplesConverter(main_datapipe, handler=handler) 512 | self.append(main_datapipe) 513 | # self.append(dataset2samples(handler=handler)) 514 | 515 | def decode( 516 | self, 517 | *args, 518 | pre=None, 519 | post=None, 520 | only=None, 521 | partial=False, 522 | passthrough_keys=None, 523 | handler=wds.reraise_exception, 524 | ): 525 | handlers = [ 526 | 
wds.autodecode.ImageHandler(x) if isinstance(x, str) else x for x in args 527 | ] 528 | decoder = KeyPassThroughDecoder( 529 | handlers, 530 | passthrough_keys=passthrough_keys, 531 | pre=pre, 532 | post=post, 533 | only=only, 534 | partial=partial, 535 | ) 536 | return self.map(decoder, handler=handler) 537 | -------------------------------------------------------------------------------- /sdata/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | import numpy as np 4 | from omegaconf import DictConfig, ListConfig 5 | from typing import Dict, List, Optional, Union, Callable 6 | import webdataset as wds 7 | 8 | 9 | from .filters import AbstractFilter 10 | from .mappers import AbstractMapper 11 | from .datapipeline import StableDataPipeline, instantiate, make_callable, time_measure 12 | 13 | 14 | @time_measure("Collator") 15 | def dict_collation_fn( 16 | samples: List, combine_tensors: bool = True, combine_scalars: bool = True 17 | ) -> Dict: 18 | """Take a list of samples (as dictionary) and create a batch, preserving the keys. 19 | If `tensors` is True, `ndarray` objects are combined into 20 | tensor batches. 21 | :param dict samples: list of samples 22 | :param bool tensors: whether to turn lists of ndarrays into a single ndarray 23 | :returns: single sample consisting of a batch 24 | :rtype: dict 25 | """ 26 | keys = set.intersection(*[set(sample.keys()) for sample in samples]) 27 | batched = {key: [] for key in keys} 28 | 29 | for s in samples: 30 | [batched[key].append(s[key]) for key in batched] 31 | 32 | result = {} 33 | for key in batched: 34 | if isinstance(batched[key][0], (int, float)): 35 | if combine_scalars: 36 | result[key] = np.array(list(batched[key])) 37 | elif isinstance(batched[key][0], torch.Tensor): 38 | if combine_tensors: 39 | result[key] = torch.stack(list(batched[key])) 40 | elif isinstance(batched[key][0], np.ndarray): 41 | if combine_tensors: 42 | result[key] = np.array(list(batched[key])) 43 | else: 44 | result[key] = list(batched[key]) 45 | 46 | del samples 47 | del batched 48 | return result 49 | 50 | 51 | def create_loader( 52 | datapipeline: Union[wds.DataPipeline, Union[DictConfig, Dict]], 53 | batch_size: int, 54 | num_workers: int, 55 | partial: bool = False, 56 | collation_fn: Optional[Union[Callable, Dict, DictConfig]] = None, 57 | batched_transforms: Optional[ListConfig] = None, 58 | loader_kwargs: Optional[Union[Dict, DictConfig]] = None, 59 | error_handler: Optional[Union[Callable, Dict, DictConfig]] = None, 60 | ) -> torch.utils.data.DataLoader: 61 | if not loader_kwargs: 62 | loader_kwargs = {} 63 | 64 | if not batched_transforms: 65 | batched_transforms = [] 66 | 67 | if not collation_fn: 68 | collation_fn = dict_collation_fn 69 | 70 | if isinstance(collation_fn, (Dict, DictConfig)): 71 | collation_fn = make_callable(collation_fn) 72 | 73 | if not error_handler: 74 | error_handler = {"target": "sdata.warn_and_continue"} 75 | 76 | if isinstance(error_handler, (Dict, DictConfig)): 77 | error_handler = make_callable(error_handler) 78 | 79 | print("#" * 100) 80 | print("Building dataloader with the following parameters") 81 | print(f"batch_size: {batch_size}") 82 | print(f"num_workers: {num_workers}") 83 | for key in loader_kwargs: 84 | print(key, ": ", loader_kwargs[key]) 85 | print("#" * 100) 86 | # create datapipeline from dict if not already instantiated 87 | if isinstance(datapipeline, (DictConfig, Dict)): 88 | datapipeline = 
instantiate(datapipeline) 89 | 90 | # batching 91 | datapipeline = datapipeline.batched( 92 | batch_size, partial=partial, collation_fn=collation_fn 93 | ) 94 | 95 | # apply transforms which act on batched samples 96 | for i, trf in enumerate(batched_transforms): 97 | trf = instantiate(trf) 98 | if isinstance(trf, AbstractFilter): 99 | print( 100 | f"Adding filter {trf.__class__.__name__} as batched transform #{i} " 101 | f"to the datapipeline" 102 | ) 103 | datapipeline = datapipeline.select(trf) 104 | elif isinstance(trf, AbstractMapper): 105 | print( 106 | f"Adding mapper {trf.__class__.__name__} as batched transform #{i} " 107 | f"to the datapipeline" 108 | ) 109 | datapipeline = datapipeline.map(trf, handler=error_handler) 110 | else: 111 | raise TypeError( 112 | "chosen batched transform should be either a subclass of " 113 | "sdata.AbstractMapper or one of sdata.AbstractFilter" 114 | "but is none of both" 115 | ) 116 | 117 | # create loader 118 | loader = torch.utils.data.DataLoader( 119 | datapipeline, batch_size=None, num_workers=num_workers, **loader_kwargs 120 | ) 121 | return loader 122 | 123 | 124 | def create_dataset( 125 | urls: Union[List, ListConfig, str], 126 | pipeline_config: Optional[Union[DictConfig, Dict]] = None, 127 | decoders: Optional[Union[ListConfig, str]] = "pil", 128 | additional_decoder_kwargs: Optional[DictConfig] = None, 129 | preprocessors: Optional[ListConfig] = None, 130 | postprocessors: Optional[ListConfig] = None, 131 | error_handler: Optional[Union[Callable, Dict]] = None, 132 | ) -> wds.DataPipeline: 133 | """ 134 | Create a dataset from several (partly optional) configs and urls defining paths to shards in webdataset/torchdata format 135 | The shards should be located in the local filesystem and can be specified as directories or in braceexpand notation 136 | :param urls: the urls as paths to the shards 137 | :param pipeline_config: additional parameters for configuring the main datapipeline, for a list of 138 | available parameters, see sdata. 139 | :param decoders: 140 | :param additional_decoder_kwargs: Additional keyword args for the decoder. This can be e.g. 
used to define passthrough keys, which shall be 141 | decoded although not having a known decoder key 142 | :param preprocessors: 143 | :param postprocessors: 144 | :param error_handler: The error handler defining a strategy for handling errors in the stages in the pipeline 145 | :return: 146 | """ 147 | if isinstance(urls, str): 148 | urls = [urls] 149 | 150 | if not pipeline_config: 151 | pipeline_config = {} 152 | 153 | if not error_handler: 154 | error_handler = {"target": "sdata.warn_and_continue"} 155 | 156 | pipeline_config.pop("handler", None) 157 | pipeline_config["handler"] = error_handler 158 | 159 | # default for all processors 160 | if not preprocessors: 161 | preprocessors = [] 162 | 163 | if not postprocessors: 164 | postprocessors = [] 165 | 166 | if not additional_decoder_kwargs: 167 | additional_decoder_kwargs = {} 168 | 169 | if decoders and not isinstance(decoders, (List, ListConfig)): 170 | # default case is assuming image decoding 171 | decoders = [decoders] 172 | 173 | elif not decoders: 174 | decoders = [] 175 | 176 | datapipeline = StableDataPipeline(urls=urls, **pipeline_config) 177 | 178 | if isinstance(error_handler, Dict): 179 | error_handler = make_callable(error_handler) 180 | 181 | # instantiate all preprocessors 182 | for i, prepro_config in enumerate(preprocessors): 183 | prepro = instantiate(prepro_config) 184 | if isinstance(prepro, AbstractFilter): 185 | print( 186 | f"Adding filter {prepro.__class__.__name__} as preprocessor #{i} " 187 | f"to the datapipeline" 188 | ) 189 | datapipeline = datapipeline.select(prepro) 190 | elif isinstance(prepro, AbstractMapper): 191 | print( 192 | f"Adding mapper {prepro.__class__.__name__} as preprocessor #{i} " 193 | f"to the datapipeline" 194 | ) 195 | datapipeline = datapipeline.map(prepro, handler=error_handler) 196 | else: 197 | raise TypeError( 198 | f"chosen preprocessor {prepro.__class__.__name__} should be either a subclass of " 199 | "sdata.mappers.AbstractMapper or one of sdata.filters.AbstractFilter" 200 | "but is none of both" 201 | ) 202 | 203 | # do decoding 204 | prepared_decoders = [] 205 | for decoder_spec in decoders: 206 | if isinstance(decoder_spec, (Dict, DictConfig)): 207 | decoder = instantiate(decoder_spec) 208 | print(f"Adding decoder {decoder.__class__.__name__} to decoders.") 209 | prepared_decoders.append(decoder) 210 | elif isinstance(decoder_spec, str): 211 | assert ( 212 | decoder_spec in wds.autodecode.imagespecs 213 | or decoder_spec in wds.autodecode.decoders 214 | ), ( 215 | "when decoder is specified via a string, then it has to be a a " 216 | "decoder known to webdataset" 217 | ) 218 | print(f"Adding decoder {decoder_spec} to decoders.") 219 | prepared_decoders.append(decoder_spec) 220 | else: 221 | raise TypeError(f"{decoder_spec} not a thing for decoders.") 222 | 223 | if decoders: 224 | # default behavior is setting partial to 'True' in decode 225 | partial = additional_decoder_kwargs.pop("partial", True) 226 | # add instantiated decoders to the datapipeline 227 | datapipeline = datapipeline.decode( 228 | *prepared_decoders, 229 | partial=partial, 230 | handler=error_handler, 231 | **additional_decoder_kwargs, 232 | ) 233 | 234 | # instantiate all postprocessors 235 | for i, postro_config in enumerate(postprocessors): 236 | postpro = instantiate(postro_config) 237 | if isinstance(postpro, AbstractFilter): 238 | print( 239 | f"Adding filter {postpro.__class__.__name__} as postprocessor #{i} " 240 | f"to the datapipeline" 241 | ) 242 | datapipeline = datapipeline.select(postpro) 
243 | elif isinstance(postpro, AbstractMapper): 244 | print( 245 | f"Adding mapper {postpro.__class__.__name__} as postprocessor #{i} " 246 | f"to the datapipeline" 247 | ) 248 | datapipeline = datapipeline.map(postpro, handler=error_handler) 249 | else: 250 | raise TypeError( 251 | "chosen postprocessor should be either a subclass of " 252 | "sdata.AbstractMapper or one of sdata.AbstractFilter" 253 | "but is none of both" 254 | ) 255 | 256 | return datapipeline 257 | -------------------------------------------------------------------------------- /sdata/dummy.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import webdataset as wds 4 | from torchdata.datapipes.iter import IterDataPipe 5 | 6 | from .dataset import create_dataset 7 | 8 | 9 | class DummyIterator(IterDataPipe): 10 | def __init__(self, sample: Dict): 11 | super().__init__() 12 | self.sample = sample 13 | 14 | def __iter__(self): 15 | while True: 16 | yield self.sample 17 | 18 | 19 | class DummyDataPipeline(wds.DataPipeline, wds.compat.FluidInterface): 20 | def __init__(self, datapipe: IterDataPipe): 21 | super().__init__() 22 | self.append(datapipe) 23 | 24 | 25 | def create_dummy_dataset(*args, **kwargs): 26 | datapipe = create_dataset(*args, **kwargs) 27 | 28 | sample = next(iter(datapipe)) 29 | del datapipe 30 | iterator = DummyIterator(sample) 31 | 32 | datapipeline = DummyDataPipeline(iterator) 33 | 34 | return datapipeline 35 | -------------------------------------------------------------------------------- /sdata/filters/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import AbstractFilter, LambdaFilter 2 | from .metadata_filters import ( 3 | SimpleSizeFilter, 4 | SimpleKeyFilter, 5 | ) 6 | 7 | __all__ = [ 8 | "AbstractFilter", 9 | "SimpleSizeFilter", 10 | "SimpleKeyFilter", 11 | "LambdaFilter" 12 | ] 13 | -------------------------------------------------------------------------------- /sdata/filters/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Dict, Union, List, Optional, Callable 3 | 4 | from omegaconf import ListConfig, DictConfig 5 | 6 | from ..datapipeline import time_measure, make_callable 7 | 8 | 9 | class AbstractFilter: 10 | def __init__( 11 | self, 12 | exclude_keys: Optional[Union[List[str], ListConfig, str]] = None, 13 | verbose: bool = False, 14 | ): 15 | if not exclude_keys: 16 | exclude_keys = [] 17 | 18 | if isinstance(exclude_keys, str): 19 | exclude_keys = [exclude_keys] 20 | self.exclude_keys = exclude_keys 21 | 22 | self.verbose = verbose 23 | 24 | def skip_this_sample(self, sample: Dict) -> bool: 25 | res = any(map(lambda x: x in sample["__url__"], self.exclude_keys)) 26 | del sample 27 | return res 28 | 29 | @abstractmethod 30 | def __call__(self, sample: Dict) -> bool: 31 | raise NotImplementedError("AbstractFilter should not be called but overwritten") 32 | 33 | 34 | class LambdaFilter(AbstractFilter): 35 | def __init__( 36 | self, 37 | keys: Union[str, List[str], ListConfig], 38 | fn: Union[Dict, DictConfig, Callable], 39 | *args, 40 | **kwargs 41 | ): 42 | super().__init__(*args, **kwargs) 43 | if isinstance(keys, str): 44 | keys = [keys] 45 | 46 | self.keys = keys 47 | 48 | if isinstance(fn, Union[Dict, DictConfig]): 49 | fn = make_callable(fn) 50 | 51 | self.fn = fn 52 | 53 | @time_measure("LambdaMapper") 54 | def __call__(self, sample: Dict) -> bool: 55 | if 
self.skip_this_sample(sample): 56 | del sample 57 | return True 58 | 59 | let_pass = True 60 | for key in self.keys: 61 | let_pass &= self.fn(sample[key]) 62 | 63 | del sample 64 | return let_pass 65 | -------------------------------------------------------------------------------- /sdata/filters/metadata_filters.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, List, Union, Tuple 2 | 3 | from omegaconf import DictConfig 4 | 5 | from .base import AbstractFilter 6 | from sdata.datapipeline import time_measure, timeout_wrapper 7 | 8 | 9 | class SimpleKeyFilter(AbstractFilter): 10 | def __init__(self, keys: Union[str, List[str]], *args, **kwargs): 11 | super().__init__(*args, **kwargs) 12 | if isinstance(keys, str): 13 | keys = [keys] 14 | 15 | self.keys = set(keys) 16 | 17 | @timeout_wrapper 18 | @time_measure("SimpleKeyFilter") 19 | def __call__(self, sample: Dict) -> bool: 20 | try: 21 | if self.skip_this_sample(sample): 22 | return True 23 | result = all(map(lambda x: x in sample, self.keys)) 24 | del sample 25 | return result 26 | except Exception as e: 27 | print(f"{e.__class__.__name__} in {self.__class__.__name__}: {e}") 28 | return False 29 | 30 | 31 | class SimpleSizeFilter(AbstractFilter): 32 | def __init__( 33 | self, 34 | size: Union[int, Tuple[int, int], List[int]], 35 | mode: str = "min", 36 | strict: Optional[Union[bool, Dict]] = None, 37 | width_key: str = "original_width", 38 | height_key: str = "original_height", 39 | *args, 40 | **kwargs, 41 | ): 42 | """ 43 | Simple size filter based on metadata which is already decoded 44 | :param size: The desired min or max size 45 | :param mode: either to filter out all above a min size (min) or all below a max size (max) 46 | :param key: indicates the field in the sample, the field should be a dict 47 | :param subkeys: list of strings defining subkeys for nested dict, i.e. ['foo','bar'] would result 48 | in sample[self.key]['foo']['bar'] being the entry in a nested dict, where the size information sits 49 | :param strict: whether to return True or False when the key is not present 50 | :param width_key: the width key at the final level in the metadata dict i.e. 
for key='json', 51 | subkeys=['foo','bar'] and width_key = 'original_width', the entry sample['json']['foo']['bar']['original_width'] 52 | would be used 53 | :param height_key: same as above but with width 54 | """ 55 | super().__init__(*args, **kwargs) 56 | if isinstance(size, int): 57 | size = [size, size] 58 | 59 | self.size = size 60 | 61 | if mode == "min": 62 | self.relation = self.filter_min 63 | else: 64 | self.relation = self.filter_max 65 | 66 | self.strict = strict 67 | if not isinstance(self.strict, (bool, dict, DictConfig)): 68 | raise TypeError( 69 | f"strict in {self.__class__.__name__} should be bool or Dict" 70 | ) 71 | 72 | self.height_key = height_key 73 | self.width_key = width_key 74 | 75 | def filter_min(self, height: int, width: int) -> bool: 76 | return height >= self.size[0] and width >= self.size[1] 77 | 78 | def filter_max(self, height: int, width: int) -> bool: 79 | return height <= self.size[0] and width <= self.size[1] 80 | 81 | @timeout_wrapper 82 | @time_measure("SimpleSizeFilter") 83 | def __call__(self, sample: Dict) -> bool: 84 | try: 85 | if self.skip_this_sample(sample): 86 | return True 87 | # get height and width 88 | original_width = sample[self.width_key] 89 | original_height = sample[self.height_key] 90 | 91 | result = self.relation(original_height, original_width) 92 | return result 93 | except Exception as e: 94 | if isinstance(self.strict, bool): 95 | return not self.strict 96 | elif isinstance(self.strict, (dict, DictConfig)): 97 | url = sample["__url__"] 98 | key = [k for k in self.strict if k in url][0] 99 | result = not self.strict[key] 100 | return result 101 | else: 102 | raise TypeError( 103 | f"strict in {self.__class__.__name__} should be bool or Dict" 104 | ) 105 | -------------------------------------------------------------------------------- /sdata/mappers/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import AbstractMapper 2 | from .batched_mappers import BatchedEinopsTransform 3 | from .sample_mappers import ( 4 | TorchVisionImageTransforms, 5 | Rescaler, 6 | AddOriginalImageSizeAsTupleAndCropToSquare, 7 | ) 8 | 9 | __all__ = [ 10 | "AbstractMapper", 11 | "AddOriginalImageSizeAsTupleAndCropToSquare", 12 | "BatchedEinopsTransform", 13 | "TorchVisionImageTransforms", 14 | "Rescaler", 15 | ] 16 | -------------------------------------------------------------------------------- /sdata/mappers/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Dict, Union, Optional, List, Callable 3 | import functools 4 | 5 | from omegaconf import ListConfig, DictConfig 6 | 7 | from ..datapipeline import make_callable, time_measure 8 | 9 | 10 | class AbstractMapper(object): 11 | timeout = None 12 | 13 | def __init__( 14 | self, 15 | exclude_keys: Optional[Union[List[str], ListConfig, str]] = None, 16 | timeout: Optional[float] = None, 17 | verbose:bool = False, 18 | ): 19 | self.timeout = timeout 20 | if not exclude_keys: 21 | exclude_keys = [] 22 | 23 | self.verbose = verbose 24 | 25 | if isinstance(exclude_keys, str): 26 | exclude_keys = [exclude_keys] 27 | self.exclude_keys = exclude_keys 28 | 29 | def skip_this_sample(self, sample: Dict) -> bool: 30 | res = any(map(lambda x: x in sample["__url__"], self.exclude_keys)) 31 | del sample 32 | return res 33 | 34 | @abstractmethod 35 | def __call__(self, sample: Dict) -> Union[Dict, None]: 36 | raise NotImplementedError("AbstractMapper should not be 
37 |
38 |
39 |
40 | class LambdaMapper(AbstractMapper):
41 |     def __init__(
42 |         self,
43 |         keys: Union[str, List[str], ListConfig],
44 |         fn: Union[Dict, DictConfig, Callable],
45 |         *args,
46 |         **kwargs
47 |     ):
48 |         super().__init__(*args, **kwargs)
49 |         if isinstance(keys, str):
50 |             keys = [keys]
51 |
52 |         self.keys = keys
53 |
54 |         if isinstance(fn, (Dict, DictConfig)):
55 |             fn = make_callable(fn)
56 |
57 |         self.fn = fn
58 |
59 |     @time_measure("LambdaMapper")
60 |     def __call__(self, sample: Dict) -> Dict:
61 |         if self.skip_this_sample(sample):
62 |             return sample
63 |
64 |         for key in self.keys:
65 |             sample[key] = self.fn(sample[key])
66 |
67 |         return sample
68 |
--------------------------------------------------------------------------------
/sdata/mappers/batched_mappers.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | from einops import rearrange, repeat, reduce
4 |
5 | from .base import AbstractMapper
6 | from ..datapipeline import time_measure
7 |
8 |
9 | class BatchedEinopsTransform(AbstractMapper):
10 |     transforms = {"rearrange": rearrange, "repeat": repeat, "reduce": reduce}
11 |
12 |     def __init__(
13 |         self, pattern: str, key: str, mode: str = "rearrange", *args, **kwargs
14 |     ) -> None:
15 |         super().__init__(*args, **kwargs)
16 |         self.pattern = pattern
17 |         self.key = key
18 |
19 |         assert mode in self.transforms, (
20 |             f"mode parameter for {self.__class__.__name__} has to be "
21 |             f"in {list(self.transforms)}"
22 |         )
23 |         self.mode = mode
24 |
25 |     @time_measure("BatchedEinopsTransform")
26 |     def __call__(self, sample: Dict) -> Dict:
27 |         if self.skip_this_sample(sample):
28 |             return sample
29 |         target = sample[self.key]
30 |
31 |         sample[self.key] = self.transforms[self.mode](target, self.pattern)
32 |
33 |         del target
34 |         return sample
35 |
--------------------------------------------------------------------------------
/sdata/mappers/sample_mappers.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from typing import Union, List, Dict
3 | from omegaconf import DictConfig, ListConfig
4 | import numpy as np
5 | import torch
6 | import torchvision.transforms as TT
7 | from torchvision.transforms.functional import InterpolationMode
8 |
9 | from .base import AbstractMapper
10 | from ..datapipeline import instantiate, time_measure, timeout_wrapper
11 |
12 |
13 |
14 | class Rescaler(AbstractMapper):
15 |     def __init__(
16 |         self,
17 |         key: Union[List[str], ListConfig, str] = "jpg",
18 |         isfloat: bool = True,
19 |         strict: bool = True,
20 |         *args,
21 |         **kwargs,
22 |     ):
23 |         """
24 |
25 |         :param key: the key (or list of candidate keys) of the image entry to rescale
26 |         :param isfloat: bool indicating whether the input is a float in [0, 1]
27 |             or a uint in [0, 255]
28 |         """
29 |         # keeping name of first argument to be 'key' for the sake of backwards compatibility
30 |         super().__init__(*args, **kwargs)
31 |         if isinstance(key, str):
32 |             key = [key]
33 |         self.keys = set(key)
34 |         self.isfloat = isfloat
35 |         self.strict = strict
36 |         self.has_warned = [False, False]
37 |
38 |     @timeout_wrapper
39 |     @time_measure("Rescaler")
40 |     def __call__(self, sample: Dict) -> Dict:
41 |         """
42 |
43 |         :param sample: Dict containing the specified key, which should be a torch.Tensor or numpy array
44 |         :return:
45 |         """
46 |         if self.skip_this_sample(sample):
47 |             return sample
48 |         if not any(map(lambda x: x in sample, self.keys)):
49 |             if self.strict:
50 |                 raise KeyError(
51 |                     f"None of {self.keys} in current sample with keys {list(sample.keys())}"
52 |                 )
53 |             else:
54 |                 if not self.has_warned[0]:
55 |                     self.has_warned[0] = True
56 |                     warnings.warn(
57 |                         f"None of {self.keys} contained in sample "
58 |                         f"(for sample with keys {list(sample.keys())}). "
59 |                         f"Sample is returned unprocessed since strict mode is not enabled"
60 |                     )
61 |                 return sample
62 |
63 |         matching_keys = set(self.keys.intersection(sample))
64 |         if len(matching_keys) > 1:
65 |             if self.strict:
66 |                 raise ValueError(
67 |                     f"more than one matching key of {self.keys} in sample {list(sample.keys())}. This should not be the case"
68 |                 )
69 |             else:
70 |                 if not self.has_warned[1]:
71 |                     warnings.warn(
72 |                         f"more than one matching key of {self.keys} in sample {list(sample.keys())}."
73 |                         f" But strict mode is disabled, so the sample is returned unchanged"
74 |                     )
75 |                     self.has_warned[1] = True
76 |                 return sample
77 |
78 |         key = matching_keys.pop()
79 |
80 |         if self.isfloat:
81 |             sample[key] = sample[key] * 2 - 1.0
82 |         else:
83 |             sample[key] = sample[key] / 127.5 - 1.0
84 |
85 |         return sample
86 |
87 |
88 | class TorchVisionImageTransforms(AbstractMapper):
89 |     def __init__(
90 |         self,
91 |         transforms: Union[Dict, DictConfig, ListConfig],
92 |         key: str = "jpg",
93 |         strict: bool = True,
94 |         *args,
95 |         **kwargs,
96 |     ):
97 |         super().__init__(*args, **kwargs)
98 |         self.strict = strict
99 |         self.key = key
100 |         chained_transforms = []
101 |
102 |         if isinstance(transforms, (DictConfig, Dict)):
103 |             transforms = [transforms]
104 |
105 |         for trf in transforms:
106 |             trf = instantiate(trf)
107 |             chained_transforms.append(trf)
108 |
109 |         self.transform = TT.Compose(chained_transforms)
110 |
111 |     @timeout_wrapper
112 |     @time_measure("TorchVisionImageTransforms")
113 |     def __call__(self, sample: Dict) -> Union[Dict, None]:
114 |         if self.skip_this_sample(sample):
115 |             return sample
116 |         if self.key not in sample:
117 |             if self.strict:
118 |                 del sample
119 |                 return None
120 |             else:
121 |                 return sample
122 |         sample[self.key] = self.transform(sample[self.key])
123 |         return sample
124 |
125 |
126 |
127 | class AddOriginalImageSizeAsTupleAndCropToSquare(AbstractMapper):
128 |     """
129 |     Adds the original image size as params and crops to a square.
130 |     Also adds cropping parameters. Requires that no RandomCrop/CenterCrop has been applied before.
131 |     """
132 |
133 |     def __init__(
134 |         self,
135 |         h_key: str = "original_height",
136 |         w_key: str = "original_width",
137 |         image_key: str = "jpg",
138 |         use_data_key: bool = True,
139 |         data_key: str = "json",
140 |         *args,
141 |         **kwargs,
142 |     ):
143 |         super().__init__(*args, **kwargs)
144 |         self.h_key, self.w_key = h_key, w_key
145 |         self.image_key = image_key
146 |         self.data_key = data_key
147 |         self.use_data_key = use_data_key
148 |
149 |     @timeout_wrapper
150 |     @time_measure("AddOriginalImageSizeAsTupleAndCropToSquare")
151 |     def __call__(self, x: Dict) -> Dict:
152 |         if self.skip_this_sample(x):
153 |             return x
154 |         if self.use_data_key:
155 |             h, w = map(lambda y: x[self.data_key][y], (self.h_key, self.w_key))
156 |         else:
157 |             h, w = map(lambda y: x[y], (self.h_key, self.w_key))
158 |         x["original_size_as_tuple"] = torch.tensor([h, w])
159 |         jpg = x[self.image_key]
160 |         if not isinstance(jpg, torch.Tensor) or jpg.shape[0] not in [1, 3]:
161 |             raise ValueError(
162 |                 f"{self.__class__.__name__} requires input image to be a torch.Tensor with channels-first"
163 |             )
164 |         # x['jpg'] should be a chw tensor in [-1, 1] at this point
165 |         size = min(jpg.shape[1], jpg.shape[2])
166 |         delta_h = jpg.shape[1] - size
167 |         delta_w = jpg.shape[2] - size
168 |         assert not all(
169 |             [delta_h, delta_w]
170 |         )  # we assume the image has already been resized such that the smaller side matches the target size, so either delta_h or delta_w must be zero
171 |         top = np.random.randint(0, delta_h + 1)
172 |         left = np.random.randint(0, delta_w + 1)
173 |         x[self.image_key] = TT.functional.crop(
174 |             jpg, top=top, left=left, height=size, width=size
175 |         )
176 |         x["crop_coords_top_left"] = torch.tensor([top, left])
177 |         return x
178 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | import os
4 |
5 | if __name__ == "__main__":
6 |
7 |     def _read_reqs(relpath):
8 |         fullpath = os.path.join(os.path.dirname(__file__), relpath)
9 |         with open(fullpath) as f:
10 |             return [
11 |                 s.strip()
12 |                 for s in f.readlines()
13 |                 if (s.strip() and not s.startswith("#"))
14 |             ]
15 |
16 |
17 |     setup(
18 |         name="sdata",
19 |         version="0.0.1",
20 |         description="",
21 |         packages=find_packages(),
22 |         # install_requires=REQUIREMENTS,
23 |         py_modules=["sdata"],
24 |     )
25 |
--------------------------------------------------------------------------------
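
Usage sketch (not part of the repository): the mappers above all operate on plain webdataset-style sample dicts, so they can also be exercised outside of a full datapipeline. The snippet below chains Rescaler and AddOriginalImageSizeAsTupleAndCropToSquare on a hand-built sample; the tensor shape, the "dummy.tar" url and the metadata values are assumptions made only for this illustration.

    # Illustrative sketch only: applying two of the sample mappers defined above
    # to a hand-built webdataset-style sample dict.
    import torch

    from sdata.mappers import Rescaler, AddOriginalImageSizeAsTupleAndCropToSquare

    # Dummy sample: a 3x256x320 float image in [0, 1] plus the metadata the mappers expect.
    sample = {
        "__url__": "dummy.tar",  # consulted by the exclude_keys handling in AbstractMapper
        "jpg": torch.rand(3, 256, 320),  # channels-first float image in [0, 1]
        "json": {"original_height": 512, "original_width": 640},
    }

    # Rescale [0, 1] -> [-1, 1] (isfloat=True is the default).
    sample = Rescaler(key="jpg")(sample)

    # Record the original size and take a random square crop along the longer side.
    sample = AddOriginalImageSizeAsTupleAndCropToSquare()(sample)

    print(sample["jpg"].shape)               # torch.Size([3, 256, 256])
    print(sample["original_size_as_tuple"])  # tensor([512, 640])
    print(sample["crop_coords_top_left"])    # e.g. tensor([0, 37])

Since each mapper is an ordinary callable on a dict, the same objects can be composed or unit-tested in isolation like this before being referenced from a config.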