├── src ├── __init__.py └── chug │ ├── version.py │ ├── hfds │ ├── __init__.py │ ├── wrappers.py │ ├── collate.py │ └── loader.py │ ├── text │ ├── __init__.py │ └── tokenization.py │ ├── image │ ├── __init__.py │ ├── transforms_factory.py │ ├── transforms_alb.py │ ├── transforms_torch.py │ ├── build_transforms_image.py │ └── build_transforms_doc.py │ ├── common │ ├── __init__.py │ ├── task_config.py │ ├── random.py │ ├── collate.py │ ├── urls.py │ ├── types.py │ └── config.py │ ├── task_pipeline │ ├── __init__.py │ ├── pipeline_manual.py │ ├── pipeline_factory.py │ ├── pipeline_doc_read.py │ ├── pipeline_image_text.py │ ├── pipeline_gtparse.py │ └── pipeline_doc_vqa.py │ ├── wds │ ├── __init__.py │ ├── helpers.py │ ├── dataset_info.py │ ├── tariterators.py │ ├── shardlists.py │ ├── filters.py │ ├── pipeline.py │ ├── loader.py │ └── decode.py │ ├── doc │ ├── __init__.py │ ├── constants.py │ ├── doc_vqa_processor.py │ ├── doc_read_processor.py │ └── doc_processor.py │ ├── app │ └── test.py │ ├── __init__.py │ └── loader.py ├── requirements.txt ├── pyproject.toml ├── .gitignore ├── LICENSE └── README.md /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/chug/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.2.0dev0' 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | timm 3 | webdataset 4 | datasets 5 | pypdfium2 6 | simple_parsing -------------------------------------------------------------------------------- /src/chug/hfds/__init__.py: -------------------------------------------------------------------------------- 1 | from .loader import create_loader_hf 2 | from .wrappers import SafeDataset -------------------------------------------------------------------------------- /src/chug/text/__init__.py: -------------------------------------------------------------------------------- 1 | from .tokenization import tokenize, text_input_to_target, prepare_text_input, create_text_preprocessor -------------------------------------------------------------------------------- /src/chug/image/__init__.py: -------------------------------------------------------------------------------- 1 | from .transforms_factory import build_image_transforms, create_image_preprocessor 2 | from .build_transforms_doc import build_transforms_doc_better, build_transforms_doc_basic, build_transforms_doc_nougat 3 | from .build_transforms_image import build_transforms_image_timm, build_transforms_image_basic 4 | -------------------------------------------------------------------------------- /src/chug/common/__init__.py: -------------------------------------------------------------------------------- 1 | from .collate import collate 2 | from .config import ImageInputCfg, ImageAugCfg, PreprocessCfg, image_mode_to_chs 3 | from .config import DataArg, DataCfg, DistributedCfg, source_to_shard_spec 4 | from .random import get_pytorch_worker_seed, seed_worker 5 | from .task_config import DataTaskCfg 6 | from .types import SourceSpec, SplitInfo, ShardSpec, SharedCount, LoaderBundle, FeatureInfo, ImageFeatureInfo 7 | from .urls import expand_urls 8 | -------------------------------------------------------------------------------- /src/chug/task_pipeline/__init__.py: 
--------------------------------------------------------------------------------
1 | from .pipeline_doc_read import build_task_pipeline_doc_read, DataTaskDocReadCfg
2 | from .pipeline_doc_vqa import build_task_pipeline_doc_vqa, DataTaskDocVqaCfg
3 | from .pipeline_gtparse import build_task_pipeline_gtparse
4 | from .pipeline_image_text import build_task_pipeline_image_text, DataTaskImageTextCfg
5 | from .pipeline_manual import build_task_pipeline_manual, DataTaskManualCfg
6 | 
7 | from .pipeline_factory import create_task_pipeline
--------------------------------------------------------------------------------
/src/chug/wds/__init__.py:
--------------------------------------------------------------------------------
1 | from .decode import decode_pdf_pages, decode_image_pages, create_image_decoder, DecodeDoc
2 | from .filters import detshuffle_v2, map_v2, map_expand_maybe, map_expand_always, flatten_nested
3 | from .helpers import log_and_continue, expand_urls, get_error_handler
4 | from .loader import create_loader_wds
5 | from .pipeline import build_data_pipeline
6 | from .shardlists import ResampledShardsV2, ShuffledShardList
7 | from .tariterators import group_by_keys_nothrow, tarfile_to_samples_nothrow
8 | 
--------------------------------------------------------------------------------
/src/chug/doc/__init__.py:
--------------------------------------------------------------------------------
1 | from .constants import (
2 |     DEFAULT_DOC_KEY,
3 |     DEFAULT_QUESTION_KEY,
4 |     DEFAULT_QUESTION_ID_KEY,
5 |     DEFAULT_ANSWER_KEY,
6 |     DEFAULT_DOC_KEY_TUPLE,
7 |     DEFAULT_QUESTION_KEY_TUPLE,
8 |     DEFAULT_ANSWER_KEY_TUPLE,
9 |     DEFAULT_DOC_FEAT,
10 |     DEFAULT_QUESTION_FEAT,
11 |     DEFAULT_QUESTION_ID_FEAT,
12 |     DEFAULT_ANSWER_FEAT
13 | )
14 | from .doc_processor import DocProcessor
15 | from .doc_read_processor import DocReadProcessor
16 | from .doc_vqa_processor import DocVqaProcessor
17 | 
--------------------------------------------------------------------------------
/src/chug/task_pipeline/pipeline_manual.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Callable
3 | 
4 | import webdataset as wds
5 | 
6 | from chug.common import DataTaskCfg, FeatureInfo, ImageFeatureInfo
7 | from chug.doc import DocReadProcessor
8 | from chug.wds.helpers import log_and_continue
9 | 
10 | 
11 | @dataclass
12 | class DataTaskManualCfg(DataTaskCfg):
13 |     pass
14 | 
15 | 
16 | def build_task_pipeline_manual(
17 |     cfg: DataTaskManualCfg,
18 | ):
19 |     assert cfg.decode_and_process_fn is not None
20 |     # a pipeline that relies fully on the passed-in decode_and_process_fn, all other cfg fields are ignored
21 |     pipe = [
22 |         wds.map(
23 |             cfg.decode_and_process_fn,
24 |             handler=log_and_continue,
25 |         )
26 |     ]
27 |     return pipe
28 | 
--------------------------------------------------------------------------------
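A minimal usage sketch for the manual task pipeline above; the 'png' / 'txt' sample keys are hypothetical, any single decode + preprocess callable can be plugged in:

from chug import DataTaskManualCfg, build_task_pipeline_manual

def decode_and_process(sample):
    # one callable handles all decoding and preprocessing; the other cfg fields are ignored
    return {'image_input': sample['png'], 'text_input': sample['txt']}

# pipe is a list of webdataset stages, ready to splice into a wds data pipeline
pipe = build_task_pipeline_manual(DataTaskManualCfg(decode_and_process_fn=decode_and_process))
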
/src/chug/doc/constants.py:
--------------------------------------------------------------------------------
1 | from chug.common import FeatureInfo, ImageFeatureInfo
2 | 
3 | DEFAULT_DOC_KEY = "pdf;tif;png;jpeg;jpg;webp;image"
4 | DEFAULT_QUESTION_KEY = "question;query"
5 | DEFAULT_QUESTION_ID_KEY = "question_id;query_id"
6 | DEFAULT_ANSWER_KEY = "answer;answers"
7 | 
8 | DEFAULT_DOC_KEY_TUPLE = tuple(DEFAULT_DOC_KEY.split(';'))
9 | DEFAULT_QUESTION_KEY_TUPLE = tuple(DEFAULT_QUESTION_KEY.split(';'))
10 | DEFAULT_ANSWER_KEY_TUPLE = tuple(DEFAULT_ANSWER_KEY.split(';'))
11 | 
12 | DEFAULT_DOC_FEAT = ImageFeatureInfo('image_input', input_key=DEFAULT_DOC_KEY, image_mode='L')
13 | DEFAULT_QUESTION_FEAT = FeatureInfo(None, input_key=DEFAULT_QUESTION_KEY)
14 | DEFAULT_QUESTION_ID_FEAT = FeatureInfo(None, input_key=DEFAULT_QUESTION_ID_KEY)
15 | DEFAULT_ANSWER_FEAT = FeatureInfo(None, input_key=DEFAULT_ANSWER_KEY)
16 | 
--------------------------------------------------------------------------------
/src/chug/task_pipeline/pipeline_factory.py:
--------------------------------------------------------------------------------
1 | from chug.common import DataTaskCfg
2 | 
3 | from .pipeline_doc_read import build_task_pipeline_doc_read, DataTaskDocReadCfg
4 | from .pipeline_doc_vqa import build_task_pipeline_doc_vqa, DataTaskDocVqaCfg
5 | from .pipeline_image_text import build_task_pipeline_image_text, DataTaskImageTextCfg
6 | # from .pipeline_gtparse import build_task_pipeline_gtparse
7 | from .pipeline_manual import build_task_pipeline_manual, DataTaskManualCfg
8 | 
9 | _cfg_to_create = {
10 |     DataTaskDocReadCfg: build_task_pipeline_doc_read,
11 |     DataTaskDocVqaCfg: build_task_pipeline_doc_vqa,
12 |     DataTaskImageTextCfg: build_task_pipeline_image_text,
13 |     DataTaskManualCfg: build_task_pipeline_manual,
14 | }
15 | 
16 | 
17 | def create_task_pipeline(cfg: DataTaskCfg):
18 |     create_fn = _cfg_to_create[type(cfg)]
19 |     return create_fn(cfg)
20 | 
--------------------------------------------------------------------------------
/src/chug/app/test.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from dataclasses import dataclass, replace
4 | from datetime import datetime
5 | from pprint import pprint
6 | from typing import Dict, Optional, Union
7 | 
8 | import simple_parsing
9 | 
10 | from chug.common import ImageInputCfg, ImageAugCfg, DataArg
11 | from chug.wds import create_loader_wds
12 | 
13 | @dataclass
14 | class TestArgs:
15 |     data: DataArg
16 |     # FIXME need TaskArg form to define subset of task cfg options from command line
17 |     input: ImageInputCfg
18 |     aug: ImageAugCfg
19 | 
20 | 
21 | def main():
22 |     args = simple_parsing.parse(
23 |         TestArgs,
24 |         add_option_string_dash_variants=simple_parsing.DashVariant.DASH,
25 |         argument_generation_mode=simple_parsing.ArgumentGenerationMode.BOTH,
26 |         add_config_path_arg=True,
27 |     )
28 | 
29 |     pprint(args)
30 | 
31 |     loader = create_loader_wds(...)
32 | 
33 |     # FIXME WIP app to demo iteration / analysis for supported datasets
34 | 
35 | 
36 | if __name__ == '__main__':
37 |     main()
38 | 
--------------------------------------------------------------------------------
/src/chug/common/task_config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union
3 | 
4 | 
5 | @dataclass
6 | class DataTaskCfg:
7 |     """
8 |     Attributes:
9 |         ...
10 |         output_tuple: Output tuples instead of dicts from task pipeline.
11 |         filter_valid: Filter out invalid / incomplete samples.
12 |         flatten_json: Flatten json dicts into parent sample dict. Common to have in wds datasets.
13 |         error_handler: Specifies which error (exception) handler should be used; 'ignore_and_continue'
14 |             is good for training, 'reraise_exception' for debugging purposes.
15 |     """
16 |     decode_fn: Optional[Callable] = None
17 |     image_process_fn: Optional[Callable] = None
18 |     text_process_fn: Optional[Callable] = None
19 |     decode_and_process_fn: Optional[Callable] = None
20 |     output_tuple: bool = False  # output features as tuple instead of dictionary
21 |     filter_valid: bool = False  # enable filter to keep samples with valid key-values
22 |     flatten_json: bool = True  # flatten nested 'json' dicts into parent sample
23 |     error_handler: str = 'reraise_exception'
--------------------------------------------------------------------------------
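A sketch of driving the pipeline factory from a task config; the option values are illustrative:

from chug import DataTaskImageTextCfg, create_task_pipeline

cfg = DataTaskImageTextCfg(
    output_tuple=True,                 # yield (image, text) tuples instead of dicts
    filter_valid=True,                 # drop samples missing an image or caption
    error_handler='log_and_continue',  # keep iterating through bad samples
)
pipe = create_task_pipeline(cfg)  # dispatches on type(cfg) to build_task_pipeline_image_text
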
15 | """ 16 | decode_fn: Optional[Callable] = None 17 | image_process_fn: Optional[Callable] = None 18 | text_process_fn: Optional[Callable] = None 19 | decode_and_process_fn: Optional[Callable] = None 20 | output_tuple: bool = False # output features as tuple instead of dictionary 21 | filter_valid: bool = False # enable filter to keep samples with valid key-values 22 | flatten_json: bool = True # flatten nested 'json' dicts into parent sample 23 | error_handler: str = 'reraise_exception' -------------------------------------------------------------------------------- /src/chug/__init__.py: -------------------------------------------------------------------------------- 1 | from .common import ( 2 | ImageInputCfg, 3 | ImageAugCfg, 4 | LoaderBundle, 5 | ImageFeatureInfo, 6 | FeatureInfo, 7 | ShardSpec, 8 | SourceSpec, 9 | DataArg, 10 | DataCfg, 11 | DistributedCfg, 12 | ) 13 | from .hfds import create_loader_hf 14 | from .image import ( 15 | build_image_transforms, 16 | build_transforms_image_basic, 17 | build_transforms_image_timm, 18 | build_transforms_doc_basic, 19 | build_transforms_doc_better, 20 | build_transforms_doc_nougat, 21 | create_image_preprocessor, 22 | ) 23 | from .loader import create_loader, create_loader_from_config_hf, create_loader_from_config_wds 24 | from .task_pipeline import ( 25 | create_task_pipeline, 26 | build_task_pipeline_doc_read, 27 | build_task_pipeline_doc_vqa, 28 | build_task_pipeline_gtparse, 29 | build_task_pipeline_image_text, 30 | build_task_pipeline_manual, 31 | DataTaskDocReadCfg, 32 | DataTaskDocVqaCfg, 33 | DataTaskImageTextCfg, 34 | DataTaskManualCfg, 35 | ) 36 | from .text import tokenize, text_input_to_target, prepare_text_input, create_text_preprocessor 37 | from .version import __version__ 38 | from .wds import ( 39 | create_loader_wds, 40 | build_data_pipeline, 41 | decode_image_pages, 42 | decode_pdf_pages, 43 | create_image_decoder, 44 | DecodeDoc, 45 | ) 46 | 47 | -------------------------------------------------------------------------------- /src/chug/common/random.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | 5 | from .types import SharedCount 6 | 7 | 8 | def seed_worker(worker_id): 9 | import torch 10 | worker_seed = torch.initial_seed() 11 | random.seed(worker_seed) 12 | np.random.seed(worker_seed % 2**32) 13 | 14 | 15 | def get_pytorch_worker_seed(increment=0, initial_seed=None): 16 | """get dataloader worker seed from pytorch 17 | """ 18 | from torch.utils.data import get_worker_info 19 | 20 | increment_value = increment.get_value() if isinstance(increment, SharedCount) else increment 21 | worker_info = get_worker_info() 22 | if worker_info is not None: 23 | # favour using the seed already created for pytorch dataloader workers if it exists 24 | seed = worker_info.seed 25 | num_workers = worker_info.num_workers 26 | if increment_value: 27 | # space out seed increments so they can't overlap across workers in different iterations 28 | seed += increment_value * max(1, num_workers) 29 | else: 30 | # a fallback when no dataloader workers are present (num_workers=0) 31 | import torch 32 | 33 | if initial_seed is None: 34 | initial_seed = torch.initial_seed() 35 | 36 | # generate seed from initial via torch.Generator so it matches DL worker seeds 37 | seed = torch.empty((), dtype=torch.int64).random_( 38 | generator=torch.Generator().manual_seed(initial_seed)).item() 39 | 40 | if increment_value: 41 | seed += increment_value 42 | 43 | return seed 44 | 
/src/chug/hfds/wrappers.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset, IterableDataset
2 | 
3 | from chug.common import SharedCount
4 | 
5 | class SafeDataset(Dataset):
6 |     """
7 |     A Dataset wrapper with a try/except around __getitem__, for hfds datasets
8 |     that may contain errors / corrupt data.
9 |     """
10 | 
11 |     def __init__(self, original_dataset, max_retry=10):
12 |         self.ds = original_dataset
13 |         self.max_retry = max_retry
14 | 
15 |     def __len__(self):
16 |         return len(self.ds)
17 | 
18 |     def __getitem__(self, idx):
19 |         err = None
20 |         for try_idx in range(self.max_retry):
21 |             try:
22 |                 # wrap around so retries near the end of the dataset stay in range
23 |                 item = self.ds[(idx + try_idx) % len(self.ds)]
24 |                 return item
25 |             except Exception as e:
26 |                 err = e
27 |                 continue
28 |         raise err
29 | 
30 | 
31 | 
32 | class WrappedIterableDataset(IterableDataset):
33 |     """
34 |     An IterableDataset wrapper that forwards an interval (epoch) count to the wrapped dataset.
35 |     """
36 | 
37 |     def __init__(self, original_dataset, interval_count=None, max_retry=10):
38 |         self.ds = original_dataset
39 |         self.max_retry = max_retry
40 |         self.interval_count = interval_count
41 | 
42 |     def set_interval_count(self, interval_count):
43 |         if isinstance(self.interval_count, SharedCount):
44 |             self.interval_count.set_value(interval_count)
45 |         else:
46 |             self.interval_count = interval_count
47 | 
48 |     def __iter__(self):
49 |         if isinstance(self.interval_count, SharedCount):
50 |             interval_count = self.interval_count.get_value()
51 |         else:
52 |             interval_count = self.interval_count
53 |         self.ds.set_epoch(interval_count)
54 |         for sample in self.ds:
55 |             yield sample
56 | 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["pdm-backend"]
3 | build-backend = "pdm.backend"
4 | 
5 | [project]
6 | name = "chug"
7 | authors = [
8 |     {name = "Ross Wightman", email = "ross@huggingface.co"},
9 | ]
10 | description = ""
11 | readme = "README.md"
12 | requires-python = ">=3.8"
13 | keywords = ["webdataset", "datasets", "sharded", "cluster", "scale", "documents"]
14 | license = {text = "Apache-2.0"}
15 | classifiers = [
16 |     'Development Status :: 3 - Alpha',
17 |     'Intended Audience :: Education',
18 |     'Intended Audience :: Science/Research',
19 |     'License :: OSI Approved :: Apache Software License',
20 |     'Programming Language :: Python :: 3.8',
21 |     'Programming Language :: Python :: 3.9',
22 |     'Programming Language :: Python :: 3.10',
23 |     'Programming Language :: Python :: 3.11',
24 |     'Programming Language :: Python :: 3.12',
25 |     'Topic :: Scientific/Engineering',
26 |     'Topic :: Scientific/Engineering :: Artificial Intelligence',
27 |     'Topic :: Software Development',
28 |     'Topic :: Software Development :: Libraries',
29 |     'Topic :: Software Development :: Libraries :: Python Modules',
30 | ]
31 | dependencies = [
32 |     "webdataset",
33 |     "datasets",
34 |     "timm",
35 |     "torch",
36 |     "simple_parsing",
37 |     "pypdfium2",
38 | ]
39 | dynamic = ["version"]
40 | 
41 | [project.urls]
42 | homepage = "https://github.com/huggingface/chug"
43 | repository = "https://github.com/huggingface/chug"
44 | 
45 | [project.optional-dependencies]
46 | # albumentations (nougat augs)
47 | alb = [
48 |     "albumentations",
49 |     "opencv-python",
50 | ]
51 | 
52 | [tool.pdm.version]
53 | source = "file"
54 | path = "src/chug/version.py"
55 | 
--------------------------------------------------------------------------------
/.gitignore:
-------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # PyCharm 101 | .idea 102 | 103 | output/ 104 | 105 | # PyTorch weights 106 | *.tar 107 | *.pth 108 | *.pt 109 | *.torch 110 | *.gz 111 | Untitled.ipynb 112 | Testing notebook.ipynb 113 | 114 | # Root dir exclusions 115 | /*.csv 116 | /*.yaml 117 | /*.json 118 | /*.jpg 119 | /*.png 120 | /*.zip 121 | /*.tar.* -------------------------------------------------------------------------------- /src/chug/image/transforms_factory.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, Union 2 | 3 | from chug.common import ImageInputCfg, ImageAugCfg 4 | from .build_transforms_doc import build_transforms_doc_better, build_transforms_doc_nougat, build_transforms_doc_basic 5 | from .build_transforms_image import build_transforms_image_timm, build_transforms_image_basic 6 | 7 | _transform_factories = { 8 | "image_basic": build_transforms_image_basic, 9 | "image_timm": build_transforms_image_timm, 10 | "doc_basic": build_transforms_doc_basic, 11 | "doc_nougat": build_transforms_doc_nougat, 12 | "doc_better": build_transforms_doc_better, 13 | } 14 | 15 | def build_image_transforms( 16 | input_cfg: ImageInputCfg, 17 | is_training=True, 18 | do_normalize=True, 19 | do_convert=False, 20 | composed=True, 21 | aug_cfg: Optional[Union[Dict[str, Any], ImageAugCfg]] = None, 22 | ): 23 | common_args = dict( 24 | input_cfg=input_cfg, 25 | is_training=is_training, 26 | do_normalize=do_normalize, 27 | aug_cfg=aug_cfg, 28 | composed=composed, 29 | ) 30 | 31 | tt = input_cfg.transform_type 32 | assert tt in _transform_factories, \ 33 | f"Unrecognized transform type: {tt}. Must be one of {list(_transform_factories.keys())}." 
34 | transforms = _transform_factories[tt](**common_args) 35 | 36 | return transforms 37 | 38 | 39 | def create_image_preprocessor( 40 | input_cfg: ImageInputCfg, 41 | is_training=True, 42 | do_normalize=True, 43 | do_convert=False, 44 | aug_cfg: Optional[Union[Dict[str, Any], ImageAugCfg]] = None, 45 | ): 46 | transforms = build_image_transforms( 47 | input_cfg=input_cfg, 48 | is_training=is_training, 49 | do_normalize=do_normalize, 50 | do_convert=do_convert, 51 | aug_cfg=aug_cfg, 52 | composed=True, 53 | ) 54 | # NOTE for now, a stack of composed transforms are the image pre-processor 55 | return transforms 56 | 57 | -------------------------------------------------------------------------------- /src/chug/task_pipeline/pipeline_doc_read.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Callable, Optional 3 | 4 | import webdataset as wds 5 | 6 | from chug.common import DataTaskCfg, FeatureInfo, ImageFeatureInfo 7 | from chug.doc import DocReadProcessor, DEFAULT_DOC_FEAT 8 | from chug.wds import get_error_handler 9 | 10 | 11 | @dataclass 12 | class DataTaskDocReadCfg(DataTaskCfg): 13 | image_input_feat: ImageFeatureInfo = DEFAULT_DOC_FEAT 14 | text_input_feat: FeatureInfo = FeatureInfo('text_input', input_key='pages') 15 | text_target_feat: Optional[FeatureInfo] = FeatureInfo('text_target', input_key=None) 16 | page_sampling: str = 'random' 17 | render_dpi: int = 150 18 | 19 | 20 | def build_task_pipeline_doc_read( 21 | cfg: DataTaskDocReadCfg, 22 | ): 23 | handler = get_error_handler(cfg.error_handler) 24 | pipe = [] 25 | 26 | # document decoding & pre-processing done together, there is coupling in random page 27 | # selection and in the future, masking of image and/or text 28 | pipe += [ 29 | wds.map( 30 | DocReadProcessor( 31 | image_process_fn=cfg.image_process_fn, 32 | text_process_fn=cfg.text_process_fn, 33 | image_input_feat=cfg.image_input_feat, 34 | text_input_feat=cfg.text_input_feat, 35 | text_target_feat=cfg.text_target_feat, 36 | page_sampling=cfg.page_sampling, 37 | render_dpi=cfg.render_dpi, 38 | flatten_json=cfg.flatten_json, 39 | ), 40 | handler=handler, 41 | ) 42 | ] 43 | 44 | if cfg.output_tuple: 45 | # NOTE in this mode we lose '_parse' key and would need to derive from target 46 | # Unless we add support for parse as the last tuple element? 
47 | if cfg.text_target_feat is not None: 48 | pipe += [ 49 | wds.to_tuple( 50 | cfg.image_input_feat.output_name, 51 | cfg.text_input_feat.output_name, 52 | cfg.text_target_feat.output_name, 53 | ) 54 | ] 55 | else: 56 | pipe += [ 57 | wds.to_tuple( 58 | cfg.image_input_feat.output_name, 59 | cfg.text_input_feat.output_name, 60 | ) 61 | ] 62 | return pipe 63 | -------------------------------------------------------------------------------- /src/chug/wds/helpers.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | from urllib.parse import urlparse 5 | 6 | import braceexpand 7 | import webdataset as wds 8 | 9 | 10 | def urldir(url): 11 | """Return the directory part of a url.""" 12 | parsed_url = urlparse(url) 13 | path = parsed_url.path 14 | directory = os.path.dirname(path) 15 | return parsed_url._replace(path=directory).geturl() 16 | 17 | 18 | def expand_urls(urls, weights=None): 19 | if weights is None: 20 | expanded_urls = wds.shardlists.expand_urls(urls) 21 | return expanded_urls 22 | 23 | if isinstance(urls, str): 24 | urllist = urls.split("::") 25 | weights = weights.split('::') 26 | assert len(weights) == len(urllist), \ 27 | f"Expected the number of data components ({len(urllist)}) and weights({len(weights)}) to match." 28 | weights = [float(weight) for weight in weights] 29 | all_urls, all_weights = [], [] 30 | for url, weight in zip(urllist, weights): 31 | expanded_url = list(braceexpand.braceexpand(url)) 32 | expanded_weights = [weight for _ in expanded_url] 33 | all_urls.extend(expanded_url) 34 | all_weights.extend(expanded_weights) 35 | return all_urls, all_weights 36 | else: 37 | all_urls = list(urls) 38 | return all_urls, weights 39 | 40 | 41 | def log_and_continue(exn): 42 | """Call in an exception handler to ignore any exception, issue a warning, and continue.""" 43 | logging.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.') 44 | return True 45 | 46 | 47 | def dump_and_reraise(exn): 48 | """Dump stack and stop.""" 49 | import traceback 50 | exception_trace = ''.join(traceback.format_tb(exn.__traceback__)) 51 | logging.error(f'Handling webdataset {type(exn)}. 
Exception trace:\n {exception_trace}')
52 |     current_trace = ''.join(traceback.format_stack())
53 |     logging.error(f'Current stack trace:\n {current_trace}')
54 |     raise exn
55 | 
56 | 
57 | _error_handlers = {
58 |     'log_and_continue': log_and_continue,
59 |     'ignore_and_continue': wds.ignore_and_continue,
60 |     'warn_and_continue': wds.warn_and_continue,
61 |     'ignore_and_stop': wds.ignore_and_stop,
62 |     'warn_and_stop': wds.warn_and_stop,
63 |     'dump_and_reraise': dump_and_reraise,
64 |     'reraise_exception': wds.reraise_exception,
65 | }
66 | 
67 | def get_error_handler(name: str):
68 |     return _error_handlers.get(name, dump_and_reraise)
69 | 
--------------------------------------------------------------------------------
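A sketch of selecting one of the handlers registered above for a webdataset pipeline stage:

import webdataset as wds
from chug.wds import get_error_handler

handler = get_error_handler('warn_and_continue')  # unknown names fall back to dump_and_reraise
stage = wds.map(lambda sample: sample, handler=handler)  # identity map stands in for a real decode/process fn
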
/src/chug/wds/dataset_info.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import json
3 | import os
4 | from typing import Dict
5 | 
6 | from webdataset.shardlists import expand_urls
7 | 
8 | from chug.common import SplitInfo
9 | 
10 | 
11 | def get_dataset_size(shards):
12 |     # NOTE webdataset's expand_urls returns a plain list of urls
13 |     shardlist = expand_urls(shards)
14 |     dir_path = os.path.dirname(shardlist[0])
15 | 
16 |     sizes_filename = os.path.join(dir_path, 'sizes.json')
17 |     len_filename = os.path.join(dir_path, '__len__')
18 | 
19 |     if os.path.exists(sizes_filename):
20 |         sizes = json.load(open(sizes_filename, 'r'))
21 |         total_size = sum([int(sizes[os.path.basename(shard)]) for shard in shardlist])
22 |     elif os.path.exists(len_filename):
23 |         total_size = ast.literal_eval(open(len_filename, 'r').read())
24 |     else:
25 |         total_size = None  # num samples undefined
26 | 
27 |     num_shards = len(shardlist)
28 | 
29 |     return total_size, num_shards
30 | 
31 | ## FIXME this is not working / not completed, parsing _info files is a TODO
32 | 
33 | def _parse_split_info(split: str, info: Dict):
34 |     def _info_convert(dict_info):
35 |         return SplitInfo(
36 |             num_samples=dict_info['num_samples'],
37 |             filenames=tuple(dict_info['filenames']),
38 |             shard_lengths=tuple(dict_info['shard_lengths']),
39 |             name=dict_info['name'],
40 |         )
41 | 
42 |     if 'tar' in split or '..' in split:
43 |         # FIXME split is a raw shard spec; parsing a split name / sample count out of the spec is a TODO
44 |         split_name = ''
45 |         num_samples = 0
46 |         split_filenames = expand_urls(split)
47 |         if split_name:
48 |             split_info = info['splits'][split_name]
49 |             if not num_samples:
50 |                 _fc = {f: c for f, c in zip(split_info['filenames'], split_info['shard_lengths'])}
51 |                 num_samples = sum(_fc[f] for f in split_filenames)
52 |                 split_info['filenames'] = tuple(_fc.keys())
53 |                 split_info['shard_lengths'] = tuple(_fc.values())
54 |                 split_info['num_samples'] = num_samples
55 |             split_info = _info_convert(split_info)
56 |         else:
57 |             split_info = SplitInfo(
58 |                 name=split_name,
59 |                 num_samples=num_samples,
60 |                 filenames=split_filenames,
61 |             )
62 |     else:
63 |         if 'splits' not in info or split not in info['splits']:
64 |             raise RuntimeError(f"split {split} not found in info ({info.get('splits', {}).keys()})")
65 |         split_info = info['splits'][split]
66 |         split_info = _info_convert(split_info)
67 | 
68 |     return split_info
--------------------------------------------------------------------------------
/src/chug/wds/tariterators.py:
--------------------------------------------------------------------------------
1 | import re
2 | 
3 | from webdataset.filters import pipelinefilter
4 | from webdataset.tariterators import valid_sample, url_opener, tar_file_expander
5 | 
6 | from .helpers import log_and_continue
7 | 
8 | BASE_RE = re.compile(r"^((?:.*/|)[^.]+)[.]([^/]*)$")
9 | 
10 | 
11 | def base_plus_ext(path):
12 |     """Split off all file extensions.
13 | 
14 |     Args:
15 |         path: path with extensions
16 | 
17 |     Returns:
18 |         (base, allext) tuple, or (None, None) if the path does not match
19 |     """
20 |     match = re.match(BASE_RE, path)
21 |     if not match:
22 |         return None, None
23 |     return match.group(1), match.group(2)
24 | 
25 | 
26 | def group_by_keys_nothrow(
27 |         data,
28 |         keys=base_plus_ext,
29 |         lcase=True,
30 |         suffixes=None,
31 |         handler=None,
32 | ):
33 |     """Group key, value pairs from an iterator into samples.
34 | 
35 |     :param keys: function that splits the key into key and extension (base_plus_ext)
36 |     :param lcase: convert suffixes to lower case (Default value = True)
37 |     """
38 |     current_sample = None
39 |     for filesample in data:
40 |         assert isinstance(filesample, dict)
41 |         fname, value = filesample["fname"], filesample["data"]
42 |         prefix, suffix = keys(fname)
43 |         if prefix is None:
44 |             continue
45 |         if lcase:
46 |             suffix = suffix.lower()
47 |         # FIXME wds version throws if suffix in current_sample, but we have a potential for
48 |         # this happening in the current LAION400m dataset if a tar ends with same prefix as the next
49 |         # begins, rare, but can happen since prefix aren't unique across tar files in that dataset
50 |         if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
51 |             if valid_sample(current_sample):
52 |                 yield current_sample
53 |             current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
54 |         if suffixes is None or suffix in suffixes:
55 |             current_sample[suffix] = value
56 |     if valid_sample(current_sample):
57 |         yield current_sample
58 | 
59 | 
60 | def tarfile_samples_nothrow(src, handler=log_and_continue):
61 |     # NOTE this is a re-impl of the wds impl with group_by_keys that doesn't throw
62 |     streams = url_opener(src, handler=handler)
63 |     files = tar_file_expander(streams, handler=handler)
64 |     samples = group_by_keys_nothrow(files, handler=handler)
65 |     return samples
66 | 
67 | 
68 | tarfile_to_samples_nothrow = pipelinefilter(tarfile_samples_nothrow)
--------------------------------------------------------------------------------
/src/chug/hfds/collate.py:
--------------------------------------------------------------------------------
1 | from typing import List, Callable, Optional
2 | 
3 | import torch
4 | from torch.utils.data import IterableDataset, DataLoader
5 | 
6 | from chug.common import collate
7 | 
8 | 
9 | def invoke(f, *args, **kwargs):
10 |     if isinstance(f, (IterableDataset, DataLoader)) and len(args) == 0:
11 |         return iter(f)
12 |     if isinstance(f, list):
13 |         return iter(f)
14 |     if callable(f):
15 |         result = f(*args, **kwargs)
16 |         return result
17 |     raise ValueError(f"{f}: not a valid pipeline stage")
18 | 
19 | 
20 | """
21 | pipe = [wds.rename(image='jpg'), wds.map_dict(image=tf), wds.to_tuple('image', 'cls')]
22 | hfc = HfCollate(pipe)
23 | dl = DataLoader(ds, batch_size=32, num_workers=4, persistent_workers=True, collate_fn=hfc)
24 | """
25 | 
26 | def flatten_bytes(data):
27 |     for sample in data:
28 |         to_replace = {k for k, v in sample.items() if isinstance(v, dict) and 'bytes' in v}
29 |         if to_replace:
30 |             result = {k: v for k, v in sample.items() if k not in to_replace}
31 |             result.update({k: sample[k]['bytes'] for k in to_replace})
32 |             yield result
33 |         else:
34 |             yield sample
35 | 
36 | 
37 | class HfCollate:
38 |     """ Collation wrapper that applies a processing pipeline for HF datasets use
39 |     """
40 |     def __init__(
41 |             self,
42 |             pipeline: List[Callable],
43 |             collate_fn: Optional[Callable] = None,
44 |
apply_collate: bool = True, 45 | ): 46 | """ 47 | Args: 48 | pipeline: list of pipeline functions 49 | collate_fn: use a custom collation function, otherwise defaults to torch default_collate 50 | """ 51 | self.pipeline = pipeline 52 | self.collate_fn = collate_fn or collate 53 | self.apply_collate = apply_collate 54 | self._debug = False 55 | 56 | def __call__(self, batch): 57 | item = False 58 | if not self.apply_collate and isinstance(batch, dict): 59 | batch = [batch] 60 | item = True 61 | 62 | if self._debug: 63 | for b in batch: 64 | for k, v in b.items(): 65 | print(k, type(v)) 66 | if isinstance(v, torch.Tensor): 67 | print(v.shape) 68 | 69 | if self.pipeline: 70 | for pipe_fn in self.pipeline: 71 | batch = invoke(pipe_fn, batch) 72 | 73 | batch = list(batch) 74 | 75 | if self._debug: 76 | for b in batch: 77 | for k, v in b.items(): 78 | print(k, v) 79 | 80 | if self.apply_collate: 81 | return self.collate_fn(batch) 82 | else: 83 | return batch[0] if item else batch 84 | -------------------------------------------------------------------------------- /src/chug/common/collate.py: -------------------------------------------------------------------------------- 1 | import collections 2 | from typing import Callable, Dict, Optional, Tuple, Type, Union 3 | 4 | from torch.utils.data._utils.collate import default_collate_fn_map, default_collate_err_msg_format 5 | 6 | 7 | def collate(batch): 8 | r""" 9 | A customized collate function that handles collection type of element within each batch. 10 | 11 | This collate function has been tweaked to provide different functionality when handling 12 | dictionary samples. Certain keys are excluded or not tensorized. 13 | 14 | Args: 15 | batch: a single batch to be collated 16 | """ 17 | elem = batch[0] 18 | elem_type = type(elem) 19 | 20 | if elem_type in default_collate_fn_map: 21 | return default_collate_fn_map[elem_type](batch) 22 | 23 | for collate_type in default_collate_fn_map: 24 | if isinstance(elem, collate_type): 25 | return default_collate_fn_map[collate_type](batch) 26 | 27 | if isinstance(elem, collections.abc.Mapping): 28 | try: 29 | out = {} 30 | for key in elem: 31 | if key.startswith('__'): 32 | # skip keys starting with '__', e.g. '__key__', 33 | continue 34 | elif key.startswith('_'): 35 | # do not recurse or tensorize values for keys starting with '_', e.g. '_parse' 36 | out[key] = [d[key] for d in batch] 37 | else: 38 | out[key] = collate([d[key] for d in batch]) 39 | out = elem_type(out) 40 | return out 41 | except TypeError: 42 | # The mapping type may not support `__init__(iterable)`. 43 | return {key: collate([d[key] for d in batch]) for key in elem} 44 | elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple 45 | return elem_type(*(collate(samples) for samples in zip(*batch))) 46 | elif isinstance(elem, collections.abc.Sequence): 47 | # check to make sure that the elements in batch have consistent size 48 | it = iter(batch) 49 | elem_size = len(next(it)) 50 | if not all(len(elem) == elem_size for elem in it): 51 | raise RuntimeError('each element in list of batch should be of equal size') 52 | transposed = list(zip(*batch)) # It may be accessed twice, so we use a list. 53 | 54 | if isinstance(elem, tuple): 55 | return [collate(samples) for samples in transposed] # Backwards compatibility. 56 | else: 57 | try: 58 | return elem_type([collate(samples) for samples in transposed]) 59 | except TypeError: 60 | # The sequence type may not support `__init__(iterable)` (e.g., `range`). 
61 |                 return [collate(samples) for samples in transposed]
62 | 
63 |     raise TypeError(default_collate_err_msg_format.format(elem_type))
--------------------------------------------------------------------------------
/src/chug/common/urls.py:
--------------------------------------------------------------------------------
1 | import os
2 | from numbers import Number
3 | from typing import Sequence
4 | 
5 | 
6 | import braceexpand
7 | import re
8 | 
9 | 
10 | def envlookup(m):
11 |     """Look up a match in the environment with prefix WDS_ or CHUG_.
12 | 
13 |     Args:
14 |         m: a match object
15 | 
16 |     Returns:
17 |         str: the value of the environment variable WDS_<var> or CHUG_<var>
18 |     """
19 |     key = m.group(1)
20 |     for prefix in ('WDS_', 'CHUG_'):
21 |         prefixed = prefix + key
22 |         if prefixed in os.environ:
23 |             return os.environ[prefixed]
24 |     raise KeyError(f"missing WDS/CHUG environment variable for {key}")
25 | 
26 | 
27 | def envsubst(s):
28 |     """Substitute ${var} with the value of the environment variable WDS_var (or CHUG_var).
29 | 
30 |     Args:
31 |         s (str): string to be substituted
32 | 
33 |     Returns:
34 |         str: the substituted string
35 |     """
36 |     return re.sub(r"\$\{(\w+)\}", envlookup, s)
37 | 
38 | 
39 | def _subst_and_expand(url: str):
40 |     for i in range(10):
41 |         last = url
42 |         url = envsubst(url)
43 |         if url == last:
44 |             break
45 |     return braceexpand.braceexpand(url)
46 | 
47 | 
48 | def expand_urls(urls, weights=None):
49 |     """ Expand urls (and optionally weights) if they are strings, otherwise return as lists.
50 |     """
51 |     if weights is None:
52 |         if isinstance(urls, str):
53 |             url_list = urls.split("::")
54 |             result = []
55 |             for url in url_list:
56 |                 result.extend(_subst_and_expand(url))
57 |             return result, None
58 |         else:
59 |             return list(urls), None
60 | 
61 |     if isinstance(urls, str):
62 |         url_list = urls.split('::')
63 | 
64 |         if isinstance(weights, str):
65 |             weights = weights.split('::')
66 |         elif isinstance(weights, Number):
67 |             weights = [weights] * len(url_list)
68 |         assert len(weights) == len(url_list), \
69 |             f"Expected the number of data components ({len(url_list)}) and weights({len(weights)}) to match."
70 |         weights = [float(weight) for weight in weights]
71 |         all_urls, all_weights = [], []
72 |         for url, weight in zip(url_list, weights):
73 |             expanded_url = list(_subst_and_expand(url))
74 |             expanded_weights = [weight] * len(expanded_url)
75 |             all_urls.extend(expanded_url)
76 |             all_weights.extend(expanded_weights)
77 |     else:
78 |         all_urls = list(urls)
79 |         if isinstance(weights, Number):
80 |             # if weights is a scalar, expand to url list
81 |             all_weights = [float(weights)] * len(all_urls)
82 |         else:
83 |             assert len(weights) == len(all_urls), \
84 |                 f"Expected the number of data components ({len(all_urls)}) and weights({len(weights)}) to match."
85 | all_weights = list(weights) 86 | 87 | return all_urls, all_weights 88 | 89 | -------------------------------------------------------------------------------- /src/chug/task_pipeline/pipeline_image_text.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from functools import partial 3 | from typing import Callable, Optional, Union 4 | 5 | import webdataset as wds 6 | 7 | from chug.common import DataTaskCfg, ImageFeatureInfo, FeatureInfo 8 | from chug.wds import get_error_handler, create_image_decoder 9 | 10 | _DEFAULT_IMG_KEY = "jpg;png;jpeg;webp;tif" 11 | _DEFAULT_TXT_KEY = "txt" 12 | _DEFAULT_IMG_KEY_TUPLE = tuple(_DEFAULT_IMG_KEY.split(';')) 13 | _DEFAULT_TXT_KEY_TUPLE = tuple(_DEFAULT_TXT_KEY.split(';')) 14 | 15 | 16 | @dataclass 17 | class DataTaskImageTextCfg(DataTaskCfg): 18 | image_input_feat: ImageFeatureInfo = ImageFeatureInfo() 19 | text_input_feat: FeatureInfo = FeatureInfo('text', input_key=_DEFAULT_TXT_KEY) 20 | 21 | 22 | def filter_incomplete( 23 | sample, 24 | image_key=_DEFAULT_IMG_KEY_TUPLE, 25 | text_key=_DEFAULT_TXT_KEY_TUPLE 26 | ): 27 | has_caption = any(k in sample for k in text_key) 28 | has_image = any(k in sample for k in image_key) 29 | return has_caption and has_image 30 | 31 | 32 | def build_task_pipeline_image_text( 33 | cfg: DataTaskImageTextCfg, 34 | ): 35 | """ Create pipeline for dual image & text input pipelines. 36 | """ 37 | handler = get_error_handler(cfg.error_handler) 38 | pipe = [] 39 | 40 | # FIXME add support for caption target for caption tasks or use a separate pipe? 41 | 42 | if cfg.filter_valid: 43 | filter_fn = partial( 44 | filter_incomplete, 45 | image_key=tuple(cfg.image_input_feat.input_key.split(';')), 46 | text_key=tuple(cfg.text_input_feat.input_key.split(';')), 47 | ) 48 | pipe += [ 49 | wds.select(filter_fn) 50 | ] 51 | 52 | if cfg.decode_and_process_fn: 53 | pipe += [ 54 | wds.map(cfg.decode_and_process_fn) 55 | ] 56 | else: 57 | decode_fn = create_image_decoder( 58 | cfg.decode_fn, 59 | image_mode=cfg.image_input_feat.image_mode, 60 | handler=handler, 61 | ) 62 | 63 | rename_dict = { 64 | cfg.image_input_feat.output_name: cfg.image_input_feat.input_key, 65 | cfg.text_input_feat.output_name: cfg.text_input_feat.input_key, 66 | } 67 | pipe += [ 68 | decode_fn, 69 | wds.rename(**rename_dict, keep=False, handler=handler), 70 | ] 71 | 72 | map_dict = {} 73 | if cfg.image_process_fn is not None: 74 | map_dict[cfg.image_input_feat.output_name] = cfg.image_process_fn 75 | if cfg.text_process_fn is not None: 76 | map_dict[cfg.text_input_feat.output_name] = cfg.text_process_fn 77 | if map_dict: 78 | pipe += [ 79 | wds.map_dict(**map_dict, handler=handler) 80 | ] 81 | 82 | if cfg.output_tuple: 83 | pipe += [ 84 | wds.to_tuple( 85 | cfg.image_input_feat.output_name, 86 | cfg.text_input_feat.output_name, 87 | ) 88 | ] 89 | 90 | return pipe 91 | -------------------------------------------------------------------------------- /src/chug/task_pipeline/pipeline_gtparse.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from functools import partial 3 | from typing import Callable, Optional, Union 4 | 5 | import webdataset as wds 6 | 7 | from chug.common import DataTaskCfg, ImageFeatureInfo, FeatureInfo 8 | from chug.wds import get_error_handler, create_image_decoder 9 | 10 | _DEFAULT_IMG_KEY = "jpg;png;jpeg;webp;tif" 11 | _DEFAULT_TXT_KEY = "ground_truth" 12 | _DEFAULT_IMG_KEY_TUPLE = 
tuple(_DEFAULT_IMG_KEY.split(';'))
13 | _DEFAULT_TXT_KEY_TUPLE = tuple(_DEFAULT_TXT_KEY.split(';'))
14 | 
15 | _DEFAULT_IMAGE_FEAT = ImageFeatureInfo('image_input', input_key=_DEFAULT_IMG_KEY, image_mode='L')
16 | _DEFAULT_TXT_FEAT = FeatureInfo('ground_truth', input_key=_DEFAULT_TXT_KEY)
17 | 
18 | 
19 | @dataclass
20 | class DataTaskGtParseCfg(DataTaskCfg):
21 |     image_input_feat: ImageFeatureInfo = _DEFAULT_IMAGE_FEAT
22 |     text_input_feat: FeatureInfo = _DEFAULT_TXT_FEAT
23 | 
24 | 
25 | def filter_no_caption_or_no_image(
26 |         sample,
27 |         image_key=_DEFAULT_IMG_KEY_TUPLE,
28 |         text_key=_DEFAULT_TXT_KEY_TUPLE
29 | ):
30 |     has_caption = any(k in sample for k in text_key)
31 |     has_image = any(k in sample for k in image_key)
32 |     return has_caption and has_image
33 | 
34 | 
35 | def build_task_pipeline_gtparse(
36 |         cfg: DataTaskGtParseCfg,
37 | ):
38 |     """ Create pipeline for dual image & text input pipelines.
39 |     FIXME add support for caption target for caption tasks or separate pipe?
40 |     """
41 |     handler = get_error_handler(cfg.error_handler)
42 |     pipe = []
43 | 
44 |     if cfg.filter_valid:
45 |         filter_fn = partial(
46 |             filter_no_caption_or_no_image,
47 |             image_key=tuple(cfg.image_input_feat.input_key.split(';')),
48 |             text_key=tuple(cfg.text_input_feat.input_key.split(';')),
49 |         )
50 |         pipe += [
51 |             wds.select(filter_fn)
52 |         ]
53 | 
54 |     if cfg.decode_and_process_fn:
55 |         pipe += [
56 |             wds.map(cfg.decode_and_process_fn)
57 |         ]
58 |     else:
59 |         decode_fn = create_image_decoder(
60 |             cfg.decode_fn,
61 |             image_mode=cfg.image_input_feat.image_mode,
62 |             handler=handler,
63 |         )
64 | 
65 |         rename_dict = {
66 |             cfg.image_input_feat.output_name: cfg.image_input_feat.input_key,
67 |             cfg.text_input_feat.output_name: cfg.text_input_feat.input_key,
68 |         }
69 |         pipe += [
70 |             decode_fn,
71 |             wds.rename(**rename_dict),
72 |         ]
73 | 
74 |     map_dict = {}
75 |     if cfg.image_process_fn is not None:
76 |         map_dict[cfg.image_input_feat.output_name] = cfg.image_process_fn
77 |     if cfg.text_process_fn is not None:
78 |         map_dict[cfg.text_input_feat.output_name] = cfg.text_process_fn
79 | 
80 |     if map_dict:
81 |         pipe += [
82 |             wds.map_dict(**map_dict, handler=handler)
83 |         ]
84 | 
85 |     if cfg.output_tuple:
86 |         pipe += [
87 |             wds.to_tuple(
88 |                 cfg.image_input_feat.output_name,
89 |                 cfg.text_input_feat.output_name,
90 |             )
91 |         ]
92 | 
93 |     return pipe
94 | 
--------------------------------------------------------------------------------
/src/chug/common/types.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from multiprocessing import Value
3 | from typing import Any, Dict, List, Optional, Tuple, Union
4 | from numbers import Number
5 | 
6 | from torch.utils.data import DataLoader, DistributedSampler
7 | 
8 | 
9 | class SharedCount:
10 |     def __init__(self, count: int = 0):
11 |         self.count = Value('i', count)
12 | 
13 |     def set_value(self, epoch):
14 |         self.count.value = epoch
15 | 
16 |     def get_value(self):
17 |         return self.count.value
18 | 
19 | 
20 | @dataclass
21 | class LoaderBundle:
22 |     """
23 |     Bundles a DataLoader with its num_batches / num_samples limits, and exposes the sampler or a
24 |     shared interval counter to allow easy per-interval seed control.
25 | """ 26 | loader: DataLoader 27 | num_batches: int = 0 28 | num_samples: int = 0 29 | sampler: DistributedSampler = None 30 | shared_interval: SharedCount = None 31 | 32 | def set_interval(self, interval): 33 | if self.shared_interval is not None: 34 | self.shared_interval.set_value(interval) 35 | if self.sampler is not None and isinstance(self.sampler, DistributedSampler): 36 | self.sampler.set_epoch(interval) 37 | 38 | def __iter__(self): 39 | return self.loader.__iter__() 40 | 41 | 42 | @dataclass 43 | class SplitInfo: 44 | filenames: Tuple[str] 45 | num_samples: int 46 | shard_lengths: Tuple[int] = () 47 | name: str = '' 48 | 49 | 50 | # @dataclass 51 | # class ShardInfo: 52 | # url: str 53 | # weight: float = 1.0 54 | # num_samples: Optional[int] = None 55 | # 56 | # 57 | 58 | 59 | @dataclass 60 | class SourceSpec: 61 | url: str 62 | split: Optional[str] = None # dataset split 63 | template: Optional[str] = None # template to transform url -> usage 64 | sampling_weight: Optional[float] = None 65 | num_samples: Optional[int] = None 66 | 67 | # TODO resolve dataset info and track base url, shard info (sizes, etc) 68 | # base_url: str = None 69 | # info_url: str = None 70 | 71 | 72 | @dataclass 73 | class SourceInfo(SourceSpec): 74 | split_info: Dict[str, SplitInfo] = None 75 | shard_info: Dict[str, Dict[str, Any]] = None 76 | 77 | 78 | @dataclass 79 | class ShardSpec: 80 | urls: List[str] 81 | weights: Optional[Union[float, List[float]]] = None 82 | sizes: Optional[List[int]] = None 83 | 84 | def __post_init__(self): 85 | num_shards = len(self.urls) 86 | if self.weights is not None: 87 | if isinstance(self.weights, Number): 88 | self.weights = [self.weights] * num_shards 89 | assert len(self.weights) == num_shards 90 | if self.sizes is not None: 91 | assert len(self.sizes) == num_shards 92 | 93 | 94 | @dataclass(frozen=True) 95 | class FeatureInfo: 96 | """ Feature Information 97 | 98 | Attributes: 99 | output_name: output feature name, None if an intermediary feature 100 | input_key: input dataset key(s), ';' delimited for multiple options 101 | """ 102 | output_name: Optional[str] = 'image' 103 | input_key: Optional[str] = 'jpg;png' 104 | #parent: Optional[str] = None 105 | 106 | 107 | @dataclass(frozen=True) 108 | class ImageFeatureInfo(FeatureInfo): 109 | """ Image Feature Information 110 | 111 | Attributes: 112 | image_mode: Image colour mode (e.g. 'RGB', 'RGBA', 'L') 113 | output_name: output feature name, None if an intermediary feature 114 | input_key: input dataset key(s), ';' delimited for multiple options 115 | parent: parent key to search for input_key, e.g. 
'json' 116 | """ 117 | image_mode: str = 'RGB' 118 | -------------------------------------------------------------------------------- /src/chug/image/transforms_alb.py: -------------------------------------------------------------------------------- 1 | 2 | try: 3 | import albumentations as alb 4 | from albumentations.pytorch import ToTensorV2 5 | has_albumentations = True 6 | except ImportError: 7 | has_albumentations = False 8 | 9 | try: 10 | import cv2 11 | has_cv2 = True 12 | except ImportError: 13 | has_cv2 = False 14 | 15 | import numpy as np 16 | 17 | 18 | class AlbWrapper: 19 | def __init__(self, transforms): 20 | self.transforms = transforms 21 | 22 | def __call__(self, im): 23 | return self.transforms(image=np.asarray(im))["image"] 24 | 25 | def __repr__(self) -> str: 26 | format_string = self.__class__.__name__ + "(" 27 | for t in self.transforms: 28 | format_string += "\n" 29 | format_string += f" {t}" 30 | format_string += "\n)" 31 | return format_string 32 | 33 | 34 | if has_albumentations and has_cv2: 35 | 36 | class ErosionAlb(alb.ImageOnlyTransform): 37 | def __init__(self, scale, always_apply=False, p=0.5): 38 | super().__init__(always_apply=always_apply, p=p) 39 | if type(scale) is tuple or type(scale) is list: 40 | assert len(scale) == 2 41 | self.scale = scale 42 | else: 43 | self.scale = (scale, scale) 44 | 45 | def get_transform_init_args_names(self): 46 | return () 47 | 48 | def apply(self, img, **params): 49 | kernel = cv2.getStructuringElement( 50 | cv2.MORPH_ELLIPSE, tuple(np.random.randint(self.scale[0], self.scale[1] + 1, 2)) 51 | ) 52 | img = cv2.erode(img, kernel, iterations=1) 53 | return img 54 | 55 | 56 | class DilationAlb(alb.ImageOnlyTransform): 57 | def __init__(self, scale, always_apply=False, p=0.5): 58 | super().__init__(always_apply=always_apply, p=p) 59 | if type(scale) is tuple or type(scale) is list: 60 | assert len(scale) == 2 61 | self.scale = scale 62 | else: 63 | self.scale = (scale, scale) 64 | 65 | def get_transform_init_args_names(self): 66 | return () 67 | 68 | def apply(self, img, **params): 69 | kernel = cv2.getStructuringElement( 70 | cv2.MORPH_ELLIPSE, 71 | tuple(np.random.randint(self.scale[0], self.scale[1] + 1, 2)) 72 | ) 73 | img = cv2.dilate(img, kernel, iterations=1) 74 | return img 75 | 76 | 77 | class BitmapAlb(alb.ImageOnlyTransform): 78 | def __init__(self, value=0, lower=200, always_apply=False, p=0.5): 79 | super().__init__(always_apply=always_apply, p=p) 80 | self.lower = lower 81 | self.value = value 82 | 83 | def get_transform_init_args_names(self): 84 | return () 85 | 86 | def apply(self, img, **params): 87 | img = img.copy() 88 | img[img < self.lower] = self.value 89 | return img 90 | 91 | 92 | class CropMarginCv2: 93 | 94 | def __init__(self): 95 | pass 96 | 97 | def __call__(self, img): 98 | data = np.array(img.convert("L")) 99 | data = data.astype(np.uint8) 100 | max_val = data.max() 101 | min_val = data.min() 102 | if max_val == min_val: 103 | return img 104 | data = (data - min_val) / (max_val - min_val) * 255 105 | gray = 255 * (data < 200).astype(np.uint8) 106 | 107 | coords = cv2.findNonZero(gray) # Find all non-zero points (text) 108 | a, b, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box 109 | return img.crop((a, b, w + a, h + b)) 110 | 111 | -------------------------------------------------------------------------------- /src/chug/task_pipeline/pipeline_doc_vqa.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, 
field
2 | from functools import partial
3 | from typing import Callable, Dict, Optional, Union
4 | 
5 | import webdataset as wds
6 | 
7 | from chug.common import DataTaskCfg, FeatureInfo, ImageFeatureInfo
8 | from chug.wds import get_error_handler
9 | from chug.doc import (
10 |     DocVqaProcessor,
11 |     DEFAULT_DOC_KEY,
12 |     DEFAULT_QUESTION_KEY,
13 |     DEFAULT_QUESTION_ID_KEY,
14 |     DEFAULT_ANSWER_KEY,
15 |     DEFAULT_DOC_KEY_TUPLE,
16 |     DEFAULT_QUESTION_KEY_TUPLE,
17 |     DEFAULT_ANSWER_KEY_TUPLE,
18 |     DEFAULT_DOC_FEAT,
19 |     DEFAULT_QUESTION_FEAT,
20 |     DEFAULT_QUESTION_ID_FEAT,
21 |     DEFAULT_ANSWER_FEAT
22 | )
23 | 
24 | 
25 | def filter_missing(
26 |         sample,
27 |         image_key=DEFAULT_DOC_KEY_TUPLE,
28 |         question_key=DEFAULT_QUESTION_KEY_TUPLE,
29 |         answer_key=DEFAULT_ANSWER_KEY_TUPLE,
30 | ):
31 |     has_question = any(k in sample for k in question_key)
32 |     has_answer = any(k in sample for k in answer_key)
33 |     has_image = any(k in sample for k in image_key)
34 |     return has_question and has_answer and has_image
35 | 
36 | 
37 | # Currently assuming this schema as default, one set of question/answers per image, images possibly duplicated
38 | # sample = {
39 | #     'png': bytes,
40 | #     'question_id': 33,
41 | #     'doc_id': 55,  # optional
42 | #     'question': 'what is a trumpet?',
43 | #     'answers': ['an instrument', 'a brass instrument']
44 | # }
45 | #
46 | 
47 | 
48 | @dataclass
49 | class DataTaskDocVqaCfg(DataTaskCfg):
50 |     """
51 |     Attributes:
52 |         answer_feat: Feature info for the ground-truth answer(s) key.
53 |         question_feat: Feature info for the question text key.
54 |         question_id_feat: Feature info for the question id key.
55 |         image_input_feat: Feature info for the document / image input.
56 |         text_input_feat: Feature info for the text input fed to the model.
57 |         text_target_feat: Feature info for the text target.
58 |         question_prefix: String prepended to the question text.
59 |         question_suffix: String appended to the question text.
60 |         answer_prefix: String prepended to the answer text.
61 |         answer_suffix: String appended to the answer text.
62 |         render_dpi: DPI used when rendering document (pdf) pages to images.
63 |     """
64 |     answer_feat: FeatureInfo = DEFAULT_ANSWER_FEAT
65 |     question_feat: FeatureInfo = DEFAULT_QUESTION_FEAT
66 |     question_id_feat: FeatureInfo = DEFAULT_QUESTION_ID_FEAT
67 |     image_input_feat: ImageFeatureInfo = DEFAULT_DOC_FEAT
68 |     text_input_feat: FeatureInfo = FeatureInfo('text_input', input_key=None)
69 |     text_target_feat: FeatureInfo = FeatureInfo('text_target', input_key=None)
70 |     question_prefix: Optional[str] = ''
71 |     question_suffix: Optional[str] = ''
72 |     answer_prefix: Optional[str] = ''
73 |     answer_suffix: Optional[str] = ''
74 |     # FIXME prompt templates instead of prefix+suffix above?
75 |     render_dpi: int = 144
76 | 
77 | 
78 | def build_task_pipeline_doc_vqa(
79 |         cfg: DataTaskDocVqaCfg,
80 | ):
81 |     # document decoding & pre-processing done together, there is coupling in random page
82 |     # selection and in the future, masking of image and/or text
83 |     handler = get_error_handler(cfg.error_handler)
84 | 
85 |     pipe = [
86 |         wds.map(
87 |             DocVqaProcessor(
88 |                 image_process_fn=cfg.image_process_fn,
89 |                 text_process_fn=cfg.text_process_fn,
90 |                 image_input_feat=cfg.image_input_feat,
91 |                 question_feat=cfg.question_feat,
92 |                 answer_feat=cfg.answer_feat,
93 |                 question_id_feat=cfg.question_id_feat,
94 |                 render_dpi=cfg.render_dpi,
95 |                 question_prefix=cfg.question_prefix,
96 |                 question_suffix=cfg.question_suffix,
97 |                 answer_prefix=cfg.answer_prefix,
98 |                 answer_suffix=cfg.answer_suffix,
99 |             ),
100 |             handler=handler,
101 |         )
102 |     ]
103 | 
104 |     if cfg.output_tuple:
105 |         # NOTE in this mode we lose '_parse' key and would need to derive from target
106 |         # Unless we add support for parse as the last tuple element?
107 | if cfg.text_target_feat is not None: 108 | pipe += [ 109 | wds.to_tuple( 110 | cfg.image_input_feat.output_name, 111 | cfg.text_input_feat.output_name, 112 | cfg.text_target_feat.output_name, 113 | ) 114 | ] 115 | else: 116 | pipe += [ 117 | wds.to_tuple( 118 | cfg.image_input_feat.output_name, 119 | cfg.text_input_feat.output_name, 120 | ) 121 | ] 122 | return pipe 123 | 124 | -------------------------------------------------------------------------------- /src/chug/text/tokenization.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Callable, Optional, Union 3 | 4 | import torch 5 | 6 | 7 | def prompt_end_pos(tokens: torch.Tensor, prompt_end_token_id: int, empty_default: int = 0) -> int: 8 | end_pos = torch.nonzero(tokens == prompt_end_token_id) 9 | return end_pos[-1].item() if end_pos.numel() > 0 else empty_default 10 | 11 | 12 | def text_input_to_target( 13 | text_input_ids: torch.Tensor, 14 | tokenizer: Optional[Callable] = None, 15 | prompt_end_token: Optional[Union[str, int]] = None, 16 | ignore_id: int = -100, 17 | pad_token_id: Optional[int] = None, 18 | ): 19 | target = text_input_ids.clone() 20 | 21 | if pad_token_id is None: 22 | assert tokenizer is not None, 'tokenizer must be specified if pad_token_id is not.' 23 | pad_token_id = tokenizer.pad_token_id 24 | 25 | # model doesn't need to predict pad token 26 | target[target == pad_token_id] = ignore_id 27 | 28 | if prompt_end_token is not None: 29 | if isinstance(prompt_end_token, str): 30 | assert tokenizer is not None, 'tokenizer must be specified if prompt_end_token_id required.' 31 | prompt_end_token_id = tokenizer.convert_tokens_to_ids(prompt_end_token) 32 | else: 33 | prompt_end_token_id = prompt_end_token # already an int 34 | 35 | # model doesn't need to predict prompt (for VQA) 36 | end_pos = prompt_end_pos(target, prompt_end_token_id) 37 | target[:end_pos + 1] = ignore_id 38 | 39 | return target 40 | 41 | 42 | def tokenize( 43 | text: str, 44 | tokenizer: Callable, 45 | max_length: int, 46 | ids_only: bool = True, 47 | ): 48 | output = tokenizer( 49 | text, 50 | add_special_tokens=False, 51 | return_tensors="pt", 52 | max_length=max_length, 53 | padding="max_length", 54 | truncation=True, 55 | ) 56 | if ids_only: 57 | return output.input_ids[0] 58 | return output 59 | 60 | 61 | def prepare_text_input( 62 | text_input, 63 | tokenizer: Callable, 64 | max_length: int, 65 | task_start_token: Optional[str] = None, 66 | prompt_end_token: Optional[str] = None, 67 | add_eos_token: bool = True, 68 | ignore_id: int = -100, 69 | include_target: bool = True, 70 | return_dict: bool = True, 71 | input_key: str = "text_input", 72 | target_key: str = "text_target", 73 | ): 74 | """ 75 | Simple data preprocessing for raw-text data. 
76 |     """
77 |     if task_start_token:
78 |         text_input = task_start_token + text_input
79 | 
80 |     if add_eos_token:
81 |         text_input += tokenizer.eos_token
82 | 
83 |     text_input_ids = tokenize(text_input, tokenizer=tokenizer, max_length=max_length)
84 | 
85 |     if include_target:
86 |         text_target_ids = text_input_to_target(text_input_ids, tokenizer, prompt_end_token, ignore_id)
87 |         if return_dict:
88 |             return {input_key: text_input_ids, target_key: text_target_ids}
89 |         else:
90 |             return text_input_ids, text_target_ids
91 |     else:
92 |         # FIXME calculate prompt end pos for validation use (target not needed)
93 |         if return_dict:
94 |             return {input_key: text_input_ids}
95 |         else:
96 |             return text_input_ids
97 | 
98 | 
99 | def create_text_preprocessor(
100 |         tokenizer: Union[str, Callable],
101 |         max_length: int = 1024,
102 |         task_start_token: Optional[str] = None,
103 |         prompt_end_token: Optional[str] = None,
104 |         ignore_id: int = -100,
105 |         include_target: bool = True,
106 |         return_dict: bool = True,
107 |         input_key: str = "text_input",
108 |         target_key: str = "text_target",
109 | ):
110 |     if isinstance(tokenizer, str):
111 |         import transformers
112 |         tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer)
113 | 
114 |     # FIXME just binding prepare_text_input fn for now.
115 |     # This is currently tailored to the Donut style image enc + text decoder use case.
116 |     # More functionality / variety required.
117 | 
118 |     preprocess_fn = partial(
119 |         prepare_text_input,
120 |         tokenizer=tokenizer,
121 |         max_length=max_length,
122 |         task_start_token=task_start_token,
123 |         prompt_end_token=prompt_end_token,
124 |         ignore_id=ignore_id,
125 |         include_target=include_target,
126 |         return_dict=return_dict,
127 |         input_key=input_key,
128 |         target_key=target_key,
129 |     )
130 |     return preprocess_fn
--------------------------------------------------------------------------------
/src/chug/wds/shardlists.py:
--------------------------------------------------------------------------------
1 | import random
2 | import sys
3 | 
4 | from torch.utils.data import IterableDataset
5 | 
6 | from chug.common import SharedCount, get_pytorch_worker_seed
7 | from .helpers import expand_urls
8 | 
9 | 
10 | class ShuffledShardList(IterableDataset):
11 |     """An iterable dataset yielding a list of urls that is deterministically shuffled based on epoch."""
12 | 
13 |     def __init__(
14 |             self,
15 |             urls,
16 |             seed=0,
17 |             interval=-1,
18 |             num_sub_intervals=None,
19 |     ):
20 |         """Iterate through the list of shards."""
21 |         super().__init__()
22 |         self.urls = expand_urls(urls)
23 |         assert len(self.urls) and isinstance(self.urls[0], str)
24 |         self.seed = seed
25 |         self.interval = interval
26 |         self.num_sub_intervals = num_sub_intervals  # FIXME experimental feature
27 | 
28 |     def __len__(self):
29 |         return len(self.urls)
30 | 
31 |     def __iter__(self):
32 |         """Return an iterator over the shards."""
33 |         urls = self.urls.copy()
34 | 
35 |         # Set epoch
36 |         if isinstance(self.interval, SharedCount):
37 |             interval = self.interval.get_value()
38 |         else:
39 |             # NOTE: this interval tracking is problematic in a multiprocess (dataloader workers or train)
40 |             # situation, as different workers may wrap at different times (or not at all).
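            # Sketch of the intended alternative (assumed API): the training loop creates
            # interval=SharedCount(count=0) and increments it once per interval so that
            # persistent dataloader workers all observe the same value.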
41 |             self.interval += 1
42 |             interval = self.interval
43 | 
44 |         if self.seed is not None:
45 |             # Shuffle with the same seed across all nodes/workers in each interval or super interval
46 |             if self.num_sub_intervals is None:
47 |                 seed = self.seed + interval
48 |             else:
49 |                 # Keep shuffling consistent across the super epochs
50 |                 seed = self.seed + (interval // self.num_sub_intervals)
51 |             random.Random(seed).shuffle(urls)
52 | 
53 |         # Restrict to shards in the sub epoch if needed
54 |         if self.num_sub_intervals is not None:
55 |             urls = urls[interval % self.num_sub_intervals::self.num_sub_intervals]
56 | 
57 |         # Yield shards
58 |         for url in urls:
59 |             yield dict(url=url)
60 | 
61 | 
62 | class ResampledShardsV2(IterableDataset):
63 |     """An iterable dataset yielding a list of urls."""
64 | 
65 |     def __init__(
66 |             self,
67 |             urls,
68 |             weights=None,
69 |             nshards=sys.maxsize,
70 |             worker_seed_fn=None,
71 |             deterministic=False,
72 |             interval=-1,
73 |             seed=None,
74 |     ):
75 |         """Sample shards from the shard list with replacement.
76 | 
77 |         :param urls: a list of URLs as a Python list or brace notation string
78 |         """
79 |         super().__init__()
80 |         if weights is not None:
81 |             self.urls, self.weights = expand_urls(urls, weights)
82 |             assert len(self.urls) == len(self.weights), \
83 |                 f"Number of urls {len(self.urls)} and weights {len(self.weights)} should match."
84 |         else:
85 |             self.urls = expand_urls(urls)
86 |             self.weights = None
87 |         assert isinstance(self.urls[0], str)
88 |         self.nshards = nshards
89 |         self.rng = random.Random()
90 |         self.worker_seed_fn = worker_seed_fn
91 |         self.deterministic = deterministic
92 |         self.interval = interval
93 |         self.seed = seed  # only used when seed cannot be recovered from DL workers
94 | 
95 |     def __iter__(self):
96 |         """Return an iterator over the shards."""
97 |         if isinstance(self.interval, SharedCount):
98 |             interval = self.interval.get_value()
99 |         else:
100 |             # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
101 |             # situation, as different workers may wrap at different times (or not at all).
102 | self.interval += 1 103 | interval = self.interval 104 | 105 | if self.deterministic: 106 | # reset seed w/ interval if deterministic 107 | if self.worker_seed_fn is None: 108 | # pytorch worker seed should be deterministic (per-worker) 109 | # It is init by process seed, rank, & worker id 110 | seed = get_pytorch_worker_seed(interval, initial_seed=self.seed) 111 | else: 112 | seed = self.worker_seed_fn() + interval 113 | self.rng.seed(seed) 114 | 115 | for _ in range(self.nshards): 116 | if self.weights is None: 117 | yield dict(url=self.rng.choice(self.urls)) 118 | else: 119 | yield dict(url=self.rng.choices(self.urls, weights=self.weights, k=1)[0]) 120 | -------------------------------------------------------------------------------- /src/chug/image/transforms_torch.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | import torchvision.transforms.functional as F 6 | from torchvision import transforms 7 | from PIL import Image, ImageFilter 8 | 9 | 10 | class AlignLongAxis: 11 | def __init__( 12 | self, 13 | input_size, 14 | interpolation=transforms.InterpolationMode.BICUBIC 15 | ): 16 | self.input_size = input_size 17 | self.interpolation = interpolation 18 | 19 | def __call__(self, img): 20 | img_width, img_height = F.get_image_size(img) 21 | if ( 22 | (self.input_size[0] > self.input_size[1] and img_width > img_height) or 23 | (self.input_size[0] < self.input_size[1] and img_width < img_height) 24 | ): 25 | img = F.rotate(img, angle=-90, expand=True, interpolation=self.interpolation) 26 | return img 27 | 28 | 29 | class Bitmap: 30 | def __init__(self, threshold=200): 31 | self.lut = [0 if i < threshold else i for i in range(256)] 32 | 33 | def __call__(self, img): 34 | if img.mode == "RGB" and len(self.lut) == 256: 35 | lut = self.lut + self.lut + self.lut 36 | else: 37 | lut = self.lut 38 | return img.point(lut) 39 | 40 | 41 | class Erosion: 42 | def __init__(self, scale=3): 43 | super().__init__() 44 | if type(scale) is tuple or type(scale) is list: 45 | assert len(scale) == 2 46 | self.scale = scale 47 | else: 48 | self.scale = (scale, scale) 49 | 50 | @staticmethod 51 | def get_params(scale): 52 | if type(scale) is tuple or type(scale) is list: 53 | assert len(scale) == 2 54 | scale = random.choice(scale) 55 | return scale 56 | 57 | def __call__(self, img): 58 | kernel_size = self.get_params(self.scale) 59 | if isinstance(img, torch.Tensor): 60 | padding = kernel_size // 2 61 | img = -torch.nn.functional.max_pool2d(-img, kernel_size=kernel_size, stride=1, padding=padding) # minpool 62 | elif isinstance(img, Image.Image): 63 | img = img.filter(ImageFilter.MinFilter(kernel_size)) 64 | return img 65 | 66 | 67 | class Dilation: 68 | def __init__(self, scale=3): 69 | super().__init__() 70 | self.scale = scale 71 | 72 | @staticmethod 73 | def get_params(scale): 74 | if type(scale) is tuple or type(scale) is list: 75 | assert len(scale) == 2 76 | scale = random.choice(scale) 77 | return scale 78 | 79 | def __call__(self, img): 80 | kernel_size = self.get_params(self.scale) 81 | if isinstance(img, torch.Tensor): 82 | padding = kernel_size // 2 83 | img = torch.nn.functional.max_pool2d(img, kernel_size=kernel_size, stride=1, padding=padding) 84 | elif isinstance(img, Image.Image): 85 | img = img.filter(ImageFilter.MaxFilter(kernel_size)) 86 | return img 87 | 88 | 89 | def python_find_non_zero(image: np.array): 90 | """This is a reimplementation of a findNonZero function equivalent to 
cv2."""
91 |     non_zero_indices = np.column_stack(np.nonzero(image))
92 |     idxvec = non_zero_indices[:, [1, 0]]
93 |     idxvec = idxvec.reshape(-1, 1, 2)
94 |     return idxvec
95 | 
96 | 
97 | def python_bounding_rect(coordinates):
98 |     """This is a reimplementation of a BoundingRect function equivalent to cv2."""
99 |     min_values = np.min(coordinates, axis=(0, 1)).astype(int)
100 |     max_values = np.max(coordinates, axis=(0, 1)).astype(int)
101 |     x_min, y_min = min_values[0], min_values[1]
102 |     width = max_values[0] - x_min + 1
103 |     height = max_values[1] - y_min + 1
104 |     return x_min, y_min, width, height
105 | 
106 | 
107 | class CropMargin:
108 |     def __init__(self) -> None:
109 |         pass
110 | 
111 |     def __call__(
112 |             self,
113 |             image,
114 |             gray_threshold: int = 200,
115 |     ) -> Image.Image:
116 |         # FIXME check tensor vs PIL and convert as needed, this is assuming PIL right now
117 |         assert not isinstance(image, torch.Tensor)
118 |         data = np.array(image.convert("L")).astype(np.uint8)
119 |         max_val = data.max()
120 |         min_val = data.min()
121 |         if max_val == min_val:
122 |             return image
123 |         data = (data - min_val) / (max_val - min_val) * 255
124 |         gray = data < gray_threshold
125 |         coords = python_find_non_zero(gray)
126 |         x_min, y_min, width, height = python_bounding_rect(coords)
127 |         image = F.crop(image, y_min, x_min, height, width)
128 |         return image
129 | 
130 | 
131 | class ConvertColor:
132 |     def __init__(self, mode='RGB'):
133 |         self.mode = mode
134 | 
135 |     def __call__(self, image):
136 |         assert isinstance(image, Image.Image)
137 |         return image.convert(self.mode)
--------------------------------------------------------------------------------
/src/chug/wds/filters.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import random
3 | from typing import Mapping, Sequence, Union
4 | 
5 | import webdataset as wds
6 | from webdataset.filters import _shuffle
7 | 
8 | from chug.common import SharedCount, get_pytorch_worker_seed
9 | 
10 | 
11 | class detshuffle_v2(wds.PipelineStage):
12 | 
13 |     def __init__(
14 |             self,
15 |             bufsize: int = 1000,
16 |             initial: int = 100,
17 |             seed: int = 0,
18 |             interval: Union[int, SharedCount] = -1,
19 |             unique_worker: bool = False,
20 |     ):
21 |         self.bufsize = bufsize
22 |         self.initial = initial
23 |         self.seed = seed
24 |         self.interval = interval
25 |         self.unique_worker = unique_worker
26 | 
27 |     def run(self, src):
28 |         if isinstance(self.interval, SharedCount):
29 |             interval = self.interval.get_value()
30 |         else:
31 |             # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
32 |             # situation, as different workers may wrap at different times (or not at all).
33 |             self.interval += 1
34 |             interval = self.interval
35 | 
36 |         rng = random.Random()
37 |         if self.unique_worker:
38 |             # Use the PyTorch worker's seed, *different* across all nodes/workers
39 |             # but also deterministic if they are set consistently.
40 |             seed = get_pytorch_worker_seed(interval, initial_seed=self.seed)
41 |         else:
42 |             # This seed is deterministic AND the *same* across all nodes/workers in each epoch/interval.
43 |             seed = self.seed + interval
44 |         rng.seed(seed)
45 | 
46 |         return _shuffle(src, self.bufsize, self.initial, rng)
47 | 
48 | 
49 | def _map_v2(data, f, handler=wds.reraise_exception):
50 |     """ Map samples.
51 | 
52 |     This function differs from wds.map in that it only adds '__key__' back to the sample if it already exists.
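    A sketch of intended use in a pipeline (hypothetical `decode_fn`):

        datapipe.extend([map_v2(decode_fn, handler=log_and_continue)])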
53 | 
54 |     """
55 |     for sample in data:
56 |         try:
57 |             result = f(sample)
58 |         except Exception as exn:
59 |             if handler(exn):
60 |                 continue
61 |             else:
62 |                 break
63 |         if result is None:
64 |             continue
65 |         if isinstance(sample, dict) and isinstance(result, dict) and "__key__" in sample:
66 |             result["__key__"] = sample.get("__key__")
67 |         yield result
68 | 
69 | 
70 | map_v2 = wds.pipelinefilter(_map_v2)
71 | 
72 | 
73 | def _expand_maybe(data, f, handler=wds.reraise_exception):
74 |     for sample in data:
75 |         if isinstance(sample, Mapping):
76 |             try:
77 |                 result = f(sample)
78 |             except Exception as exn:
79 |                 if handler(exn):
80 |                     continue
81 |                 else:
82 |                     break
83 |             if result is None:
84 |                 continue
85 |             if "__key__" in sample:
86 |                 result["__key__"] = sample["__key__"]
87 |             yield result
88 |         else:
89 |             assert isinstance(sample, Sequence)
90 |             for subsample in sample:
91 |                 assert isinstance(subsample, Mapping)
92 |                 try:
93 |                     result = f(subsample)
94 |                 except Exception as exn:
95 |                     if handler(exn):
96 |                         continue
97 |                     else:
98 |                         break
99 |                 if result is None:
100 |                     continue
101 |                 if "__key__" in subsample:
102 |                     result["__key__"] = subsample["__key__"]
103 |                 yield result
104 | 
105 | 
106 | map_expand_maybe = wds.pipelinefilter(_expand_maybe)
107 | 
108 | 
109 | def _expand_always(data, f, handler=wds.reraise_exception):
110 |     for sample in itertools.chain(*data):
111 |         assert isinstance(sample, Mapping)
112 |         try:
113 |             result = f(sample)
114 |         except Exception as exn:
115 |             if handler(exn):
116 |                 continue
117 |             else:
118 |                 break
119 |         if result is None:
120 |             continue
121 |         if "__key__" in sample:
122 |             result["__key__"] = sample["__key__"]
123 |         yield result
124 | 
125 | 
126 | map_expand_always = wds.pipelinefilter(_expand_always)
127 | 
128 | 
129 | def _flatten_nested(data, *args, replace_existing=True, remove_original=True):
130 |     """Flatten the selected nested dict fields into the top-level sample dict."""
131 |     for sample in data:
132 |         for k in args:
133 |             nested_dict = sample.pop(k, {}) if remove_original else sample.get(k, {})
134 |             if replace_existing:
135 |                 sample.update(nested_dict)
136 |             elif k in sample:
137 |                 for sk, sv in nested_dict.items():
138 |                     sample.setdefault(sk, sv)
139 |         yield sample
140 | 
141 | 
142 | flatten_nested = wds.pipelinefilter(_flatten_nested)
--------------------------------------------------------------------------------
/src/chug/wds/pipeline.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, List, Optional, Sequence, Union
2 | 
3 | import torch.utils.data
4 | import webdataset
5 | import webdataset as wds
6 | 
7 | from chug.common import ShardSpec, collate
8 | from .filters import detshuffle_v2
9 | from .shardlists import ResampledShardsV2, ShuffledShardList
10 | from .tariterators import tarfile_to_samples_nothrow
11 | 
12 | 
13 | def build_data_pipeline(
14 |         shards: ShardSpec,
15 |         task_pipeline: List[Callable],
16 |         is_training: bool = False,
17 |         batch_size: int = 0,
18 |         resampled: bool = False,
19 |         multi_interval: bool = False,
20 |         seed: int = 0,
21 |         shared_interval_count: int = -1,
22 |         num_batches_per_worker: int = 0,
23 |         sample_shuffle_initial: int = 1,
24 |         sample_shuffle_size: int = 1,
25 |         collate_fn: Optional[Callable] = None,
26 |         batched_task_pipeline: Optional[List[Callable]] = None,
27 |         handler=wds.reraise_exception,
28 | ):
29 |     """
30 | 
31 |     Args:
32 |         shards: ShardSpec w/ shard url list and optional sizes and weights.
33 |         task_pipeline: Task-specific pipeline for processing samples.
34 |         is_training: Train mode.
Enables shuffling of samples and forces consistent batch #s across workers.
35 |         batch_size: Batch size for sample collation. Collation / batching disabled if batch_size == 0.
36 |         resampled: Enable resampling of shards with replacement.
37 |         multi_interval: Enable multi-interval (multi-epoch) operation via with_epoch() wrapping.
38 |         seed: Seed for deterministic shuffling / resampling.
39 |         shared_interval_count: Shared counter (or int) used to sync the interval across workers.
40 |         num_batches_per_worker: Fixed number of batches each worker should yield per interval.
41 |         sample_shuffle_initial: Initial fill of the sample shuffle buffer before yielding.
42 |         sample_shuffle_size: Size of the sample shuffle buffer.
43 |         collate_fn: Collation function, defaults to chug.common.collate.
44 |         batched_task_pipeline: An optional task-specific pipeline for processing batched samples.
45 |         handler: Exception handler.
46 | 
47 |     Returns:
48 | 
49 |     """
50 |     if not isinstance(task_pipeline, (list, tuple)):
51 |         task_pipeline = [task_pipeline]
52 |     assert len(task_pipeline)
53 | 
54 |     if resampled:
55 |         datapipe = [ResampledShardsV2(
56 |             shards.urls,
57 |             weights=shards.weights,
58 |             deterministic=True,
59 |             interval=shared_interval_count,
60 |             seed=seed,
61 |         )]
62 |     else:
63 |         assert shards.weights is None, \
64 |             "upsampling_factors is only supported when sampling with replacement (resampled=False)."
65 |         if is_training:
66 |             datapipe = [ShuffledShardList(
67 |                 shards.urls,
68 |                 seed=seed,
69 |                 interval=shared_interval_count,
70 |             )]
71 |         else:
72 |             datapipe = [wds.SimpleShardList(
73 |                 shards.urls,
74 |             )]
75 | 
76 |     # at this point we have an iterator over all the shards
77 |     if is_training:
78 |         if not resampled:
79 |             datapipe.extend([
80 |                 wds.split_by_node,
81 |                 wds.split_by_worker,
82 |             ])
83 |         # at this point, we have an iterator over the shards assigned to each worker at each node
84 |         datapipe.extend([
85 |             tarfile_to_samples_nothrow(handler=handler),
86 |             detshuffle_v2(
87 |                 bufsize=sample_shuffle_size,
88 |                 initial=sample_shuffle_initial,
89 |                 seed=seed,
90 |                 interval=shared_interval_count,
91 |                 unique_worker=True,
92 |             )
93 |             # wds.shuffle(
94 |             #     bufsize=sample_shuffle_size,
95 |             #     initial=sample_shuffle_initial,
96 |             # ),
97 |         ])
98 |     else:
99 |         datapipe.extend([
100 |             wds.split_by_worker,
101 |             # at this point, we have an iterator over the shards assigned to each worker
102 |             wds.tarfile_to_samples(handler=handler),
103 |         ])
104 | 
105 |     # task specific decode and map pipeline (per-sample)
106 |     datapipe.extend(task_pipeline)
107 | 
108 |     # collation, tensor output disabled with batch_size == 0 or None
109 |     if batch_size:
110 |         # NOTE torch default_collate handles dicts, wds default_collate does not
111 |         collate_fn = collate_fn or collate
112 |         datapipe.extend([
113 |             wds.batched(
114 |                 batch_size,
115 |                 partial=not is_training,
116 |                 collation_fn=collate_fn,
117 |             )
118 |         ])
119 | 
120 |     # task-specific batched pipeline (per-batch)
121 |     if batched_task_pipeline:
122 |         datapipe.extend(batched_task_pipeline)
123 | 
124 |     datapipe = wds.DataPipeline(*datapipe)
125 | 
126 |     if is_training and num_batches_per_worker > 0:
127 |         if multi_interval:
128 |             datapipe = datapipe.with_epoch(num_batches_per_worker)  # each worker is iterating over this
129 |         else:
130 |             datapipe = datapipe.repeat(nbatches=num_batches_per_worker)
131 | 
132 |     return datapipe
133 | 
--------------------------------------------------------------------------------
/src/chug/doc/doc_vqa_processor.py:
--------------------------------------------------------------------------------
1 | import random
2 | from typing import Callable, Dict, List, Optional, Tuple
3 | 
4 | from chug import FeatureInfo, ImageFeatureInfo
5 | from chug.doc import DocProcessor, DEFAULT_QUESTION_FEAT, DEFAULT_QUESTION_ID_FEAT, DEFAULT_ANSWER_FEAT, \
6 |     DEFAULT_DOC_FEAT
7 | from chug.doc.doc_processor import _get_value
8 | 
9 | 
10 | class
DocVqaProcessor(DocProcessor): 11 | def __init__( 12 | self, 13 | image_process_fn: Optional[Callable] = None, 14 | text_process_fn: Optional[Callable] = None, 15 | question_feat: FeatureInfo = DEFAULT_QUESTION_FEAT, 16 | question_id_feat: FeatureInfo = DEFAULT_QUESTION_ID_FEAT, 17 | answer_feat: FeatureInfo = DEFAULT_ANSWER_FEAT, 18 | multi_qa_feat: Optional[FeatureInfo] = None, 19 | image_input_feat: ImageFeatureInfo = DEFAULT_DOC_FEAT, 20 | text_target_feat: FeatureInfo = FeatureInfo('text_target', input_key=None), 21 | question_prefix: Optional[str] = '', 22 | question_suffix: Optional[str] = '', 23 | answer_prefix: Optional[str] = '', 24 | answer_suffix: Optional[str] = '', 25 | render_dpi: int = 150, 26 | squeeze_pages: bool = True, 27 | expand_pages: bool = False, 28 | flatten_json: bool = True, 29 | seed: int = 0, 30 | ): 31 | super().__init__( 32 | image_process_fn=image_process_fn, 33 | text_process_fn=text_process_fn, 34 | image_input_feat=image_input_feat, 35 | text_target_feat=text_target_feat, 36 | render_dpi=render_dpi, 37 | squeeze_pages=squeeze_pages, 38 | expand_pages=expand_pages, 39 | flatten_json=flatten_json, 40 | seed=seed, 41 | ) 42 | self.question_feat = question_feat 43 | self.question_key = question_feat.input_key.split(';') 44 | self.question_id_feat = question_id_feat 45 | self.question_id_key = question_id_feat.input_key.split(';') 46 | self.answer_feat = answer_feat 47 | self.answer_key = answer_feat.input_key.split(';') 48 | if multi_qa_feat is not None: 49 | self.expand_pages = True # override 50 | self.multi_qa_key = multi_qa_feat.input_key.split(';') 51 | else: 52 | # expand pages only used / supported for multi-qa expansion right now 53 | self.expand_pages = False 54 | self.multi_qa_key = None 55 | 56 | # FIXME support flexible q/a prompting formats, do with prefix/suffix or template strings? 57 | # Donut style: '{question}{answer}' 58 | # Common: 'Question: {question} Answer: {answer}' 59 | self.question_prefix = question_prefix or '' 60 | self.question_suffix = question_suffix or '' 61 | self.answer_prefix = answer_prefix or '' 62 | self.answer_suffix = answer_suffix or '' 63 | #self.prompt_template = '{question}' 64 | #self.prompt_template_full = '{question}{answer}' 65 | 66 | def _decode_anno(self, sample) -> Tuple[Dict, List[int], int]: 67 | if self.multi_qa_key: 68 | # FIXME multi qa expansion is a WIP 69 | qa_list = sample[self.multi_qa_key] 70 | assert isinstance(qa_list, (list, tuple)), f'Expected a list or tuple, got {type(qa_list)}.' 71 | assert False, 'WIP' 72 | else: 73 | question = _get_value(self.question_key, sample) 74 | question_id = _get_value(self.question_id_key, sample) 75 | answers = _get_value(self.answer_key, sample) 76 | 77 | if answers is not None and self.text_target_feat is not None: 78 | answer = random.choice(answers) 79 | else: 80 | answer = None 81 | 82 | input_text = self.question_prefix + question + self.question_suffix + self.answer_prefix 83 | if answer is not None: 84 | input_text += answer + self.answer_suffix 85 | 86 | output = { 87 | self.text_input_name: input_text, 88 | '_parse': { 89 | 'question_id': question_id, 90 | 'question': question, 91 | 'answers': answers, # list, all answers included in parse 92 | } 93 | } 94 | 95 | if self.text_process_fn is not None: 96 | processed = self.text_process_fn(input_text) 97 | assert self.text_input_name in processed, \ 98 | f"Text input name '{self.text_input_name}' not found in processed sample." 
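            # e.g. (illustrative, assuming text_process_fn came from create_text_preprocessor):
            #   processed = {'text_input': LongTensor[max_length], 'text_target': LongTensor[max_length]}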
99 | if self.text_target_feat is not None: 100 | assert self.text_target_name in processed, f"Expected a text target named '{self.text_target_name}' in processed sample." 101 | output.update(processed) 102 | else: 103 | output[self.text_input_name] = input_text 104 | 105 | return output, [0], 1 106 | 107 | def _expand_anno(self, anno, count: int): 108 | # FIXME implement expansion for multi-qa (and eventually multi-page option) 109 | pass 110 | -------------------------------------------------------------------------------- /src/chug/hfds/loader.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Callable, List, Optional 3 | 4 | from torch.utils.data import DataLoader, DistributedSampler 5 | 6 | from chug.common import LoaderBundle, DistributedCfg, SharedCount 7 | 8 | from .collate import HfCollate 9 | from .wrappers import SafeDataset, WrappedIterableDataset 10 | 11 | _SAMPLE_SHUFFLE_SIZE = 2000 12 | 13 | 14 | def _disable_decode(ds): 15 | import datasets 16 | 17 | to_set = [] 18 | for k, v in ds.features.items(): 19 | if isinstance(v, datasets.Image): 20 | d = deepcopy(v) 21 | d.decode = False 22 | to_set.append((k, d)) 23 | elif isinstance(v, datasets.Audio): 24 | d = deepcopy(v) 25 | d.decode = False 26 | to_set.append((k, d)) 27 | for k, d in to_set: 28 | ds = ds.cast_column(k, d) 29 | return ds 30 | 31 | 32 | def create_loader_hf( 33 | source: str, 34 | split: str, 35 | task_pipeline: Optional[List[Callable]] = None, 36 | data_dir: Optional[str] = None, 37 | num_samples: Optional[int] = None, 38 | streaming: bool = False, 39 | is_training: bool = False, 40 | batch_size: Optional[int] = 1, 41 | resampled: bool = False, 42 | multi_interval: bool = True, 43 | num_batches_round: str = 'ceil', 44 | num_workers: int = 4, 45 | persistent_workers: bool = True, 46 | start_interval: int = 0, 47 | seed: int = 0, 48 | collate_fn: Optional[Callable] = None, 49 | sample_shuffle_size: int = _SAMPLE_SHUFFLE_SIZE, 50 | distributed: DistributedCfg = DistributedCfg(), 51 | disable_decode: bool = False, 52 | ): 53 | """ 54 | 55 | Args: 56 | source: 57 | split: 58 | task_pipeline: 59 | data_dir: 60 | num_samples: 61 | streaming: 62 | is_training: 63 | batch_size: 64 | resampled: 65 | multi_interval: 66 | num_batches_round: 67 | num_workers: 68 | persistent_workers: 69 | start_interval: 70 | seed: 71 | collate_fn: 72 | sample_shuffle_size: 73 | distributed: 74 | disable_decode: 75 | 76 | Returns: 77 | 78 | """ 79 | from datasets import VerificationMode, load_dataset 80 | batched = batch_size is not None and batch_size >= 0 81 | 82 | if collate_fn is not None: 83 | assert task_pipeline is None, 'task_pipeline should not be set if custom collation function is used.' 
84 |     elif batched or task_pipeline is not None:
85 |         # collation fn applies task pipeline
86 |         assert task_pipeline is not None, 'task_pipeline is needed'
87 |         collate_fn = HfCollate(
88 |             task_pipeline,
89 |             apply_collate=batched,
90 |         )
91 | 
92 |     if streaming:
93 |         from datasets.distributed import split_dataset_by_node
94 | 
95 |         dataset = load_dataset(
96 |             source,
97 |             data_dir=data_dir,
98 |             streaming=True,
99 |         )
100 | 
101 |         if split not in dataset:
102 |             assert False, f'Split {split} not in dataset ({dataset.keys()})'
103 |         dataset = dataset[split]
104 |         if disable_decode:
105 |             dataset = _disable_decode(dataset)
106 | 
107 |         # FIXME num_samples calc, get a reliable estimate from dataset in streaming mode
108 |         if num_samples is None:
109 |             info = dataset.info
110 |             if info.splits is not None and split in info.splits:
111 |                 num_samples = info.splits[split].num_examples
112 | 
113 |         if is_training and multi_interval:
114 |             assert num_samples, (
115 |                 "num_samples must be available in dataset metadata or manually provided for multi-interval training")
116 | 
117 |         if is_training:
118 |             dataset = dataset.shuffle(seed, buffer_size=sample_shuffle_size)
119 | 
120 |         # FIXME split_dataset_by_node has some concerns as currently implemented
121 |         dataset = split_dataset_by_node(dataset, distributed.global_rank, distributed.world_size)
122 |         interval_count = SharedCount(start_interval)
123 |         dataset = WrappedIterableDataset(dataset, interval_count=interval_count)
124 | 
125 |         # HF datasets treats batch_size differently than torch defaults; in torch, batch_size = None
126 |         # disables batching, in HF it returns the full dataset. We restore torch behaviour here.
127 |         #batch_size = batch_size or 1
128 |         base_loader = DataLoader(
129 |             dataset=dataset,
130 |             collate_fn=collate_fn,
131 |             sampler=None,
132 |             shuffle=False,
133 |             drop_last=batched and is_training,  # FIXME improve wrt train vs validation vs sharding specifics
134 |             batch_size=batch_size,
135 |             num_workers=num_workers,
136 |             persistent_workers=persistent_workers,
137 |         )
138 | 
139 |         batch_size = batch_size or 1
140 |         loader = LoaderBundle(
141 |             loader=base_loader,
142 |             num_batches=num_samples // batch_size if num_samples else 0,
143 |             num_samples=num_samples,
144 |             shared_interval=interval_count,
145 |         )
146 | 
147 |     else:
148 |         dataset = load_dataset(
149 |             source,
150 |             data_dir=data_dir,
151 |             verification_mode=VerificationMode.ALL_CHECKS,
152 |         )
153 | 
154 |         if split not in dataset:
155 |             assert False, f'Split {split} not in dataset ({dataset.keys()})'
156 |         dataset = dataset[split]
157 |         if disable_decode:
158 |             dataset = _disable_decode(dataset)
159 |         dataset = SafeDataset(dataset)
160 | 
161 |         sampler = None
162 |         if distributed.world_size > 1:
163 |             sampler = DistributedSampler(
164 |                 dataset,
165 |                 rank=distributed.global_rank,
166 |                 shuffle=is_training,
167 |                 seed=seed,
168 |                 num_replicas=distributed.world_size,
169 |                 drop_last=False,
170 |             )
171 |             sampler.set_epoch(start_interval)
172 | 
173 |         base_loader = DataLoader(
174 |             dataset=dataset,
175 |             collate_fn=collate_fn,
176 |             sampler=sampler,
177 |             drop_last=is_training,
178 |             batch_size=batch_size,
179 |             num_workers=num_workers,
180 |         )
181 | 
182 |         loader = LoaderBundle(
183 |             loader=base_loader,
184 |             num_batches=len(base_loader),
185 |             num_samples=len(dataset),
186 |             sampler=sampler,
187 |         )
188 | 
189 |     return loader
190 | 
--------------------------------------------------------------------------------
/src/chug/wds/loader.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import math
3 | import warnings
4 | from typing import Callable, List, Optional, Union
5 | 
6 | import torch.utils.data
7 | import webdataset as wds
8 | 
9 | from chug.common import DistributedCfg, LoaderBundle, SharedCount, ShardSpec, seed_worker
10 | from .helpers import expand_urls
11 | from .pipeline import build_data_pipeline
12 | 
13 | 
14 | _SAMPLE_SHUFFLE_SIZE = 2000
15 | _SAMPLE_SHUFFLE_INITIAL = 500
16 | 
17 | 
18 | def create_loader_wds(
19 |         shards: Union[str, List[str], ShardSpec],
20 |         task_pipeline: Optional[List[Callable]],
21 |         num_samples: Optional[int] = None,
22 |         is_training: bool = False,
23 |         batch_size: Optional[int] = 1,
24 |         resampled: bool = False,
25 |         multi_interval: bool = True,
26 |         num_batches_round: str = 'ceil',
27 |         num_workers: int = 4,
28 |         persistent_workers: bool = True,
29 |         start_interval: int = 0,
30 |         seed: int = 0,
31 |         handler: Callable = wds.reraise_exception,
32 |         collate_fn: Optional[Callable] = None,
33 |         sample_shuffle_size: int = _SAMPLE_SHUFFLE_SIZE,
34 |         sample_shuffle_initial: int = _SAMPLE_SHUFFLE_INITIAL,
35 |         distributed: DistributedCfg = DistributedCfg(),
36 | ):
37 |     """ Create a webdataset loader
38 | 
39 |     Args:
40 |         shards: Shard urls as a string (brace notation supported), list of strings, or ShardSpec.
41 |         task_pipeline: Task-specific pipeline for processing samples.
42 |         is_training: Train mode, enables sample shuffling and fixed per-worker batch counts.
43 |         resampled: If True, shards are resampled with replacement.
44 |         multi_interval: If True, run loader in multi-interval mode (multi-epoch), num_samples is interpreted
45 |             as num_samples per interval (epoch if # samples == dataset length). Dataset length is set to a
46 |             fixed value to approximate at-least once sample visiting per interval.
47 |             If False, num_samples is treated as total samples to visit over training without paying attention
48 |             to # samples in underlying dataset. Dataset length is not accessible.
49 |         start_interval: Initial value for the shared interval (epoch) counter.
50 |         seed: Seed for shuffling / resampling and dataloader workers.
51 |         num_workers: Number of dataloader worker processes.
52 |         persistent_workers: Keep dataloader workers alive across intervals.
53 |         batch_size: Batch size, collation / batching disabled if 0 or None.
54 |         num_batches_round: Rounding mode ('ceil' or 'floor') for batch count estimation.
55 |         collate_fn: Collation function, defaults to chug.common.collate.
56 |         sample_shuffle_size: Size of the sample shuffle buffer (train only).
57 |         sample_shuffle_initial: Initial fill of the sample shuffle buffer before yielding.
58 |         distributed: Distributed device information.
59 | 
60 |     Returns:
61 | 
62 |     """
63 |     resampled = resampled and is_training
64 | 
65 |     if not isinstance(shards, ShardSpec):
66 |         if isinstance(shards, str):
67 |             shards = expand_urls(shards)
68 |         shards = ShardSpec(
69 |             urls=shards,
70 |         )
71 | 
72 |     num_shards = len(shards.urls)
73 |     if num_samples is None:
74 |         if shards.sizes:
75 |             num_samples = sum(shards.sizes)
76 |         if is_training and not num_samples:
77 |             raise RuntimeError(
78 |                 'The number of dataset samples must be specified for the training dataset '
79 |                 'if no dataset length info is present.')
80 | 
81 |     num_batches_per_worker = 0
82 |     num_workers_nonzero = max(num_workers, 1)
83 |     if is_training:
84 |         assert batch_size >= 1, 'batching must be enabled for train, set batch_size>=1'
85 |         # We want to see the same # of batches on each member of the distributed group (GPU).
86 |         # This is enforced by making each worker produce the same # of batches regardless of the
87 |         # underlying iterator, so we estimate and make the iterator wrap around if the end is hit.
88 |         # This will repeat some samples and may miss others per interval, as shards may be
89 |         # uneven or allocated unevenly across all workers. There are ways to improve on this naive
90 |         # approach, to get closer to the ideal of each sample in an interval (epoch) seen once,
91 |         # but it is difficult to achieve perfectly, and most improvements require full co-ordination
92 |         # across all workers via out-of-band IPC/RPC.
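        # Worked example (illustrative): num_samples=10_000, batch_size=32, world_size=4,
        # num_workers=4 -> global_batch_size=128, num_batches=ceil(10000/128)=79, then
        # num_batches_per_worker=ceil(79/4)=20, num_batches=20*4=80, num_samples=80*128=10_240.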
93 | 94 | # roll over and repeat a few samples to get same number of full batches on each node 95 | round_fn = math.floor if num_batches_round == 'floor' else math.ceil 96 | global_batch_size = batch_size * distributed.world_size 97 | num_batches = round_fn(num_samples / global_batch_size) 98 | num_batches_per_worker = round_fn(num_batches / num_workers_nonzero) # per dataloader worker 99 | num_batches = num_batches_per_worker * num_workers_nonzero 100 | num_samples = num_batches * global_batch_size 101 | else: 102 | # Eval / inference will exhaust the iterator if the size is not specified. 103 | # Eval currently supported for 1 train process only (primary) 104 | # FIXME support distributed eval 105 | # https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel.join 106 | num_samples = num_samples or 0 107 | # last batches are partial, eval is done on a single (primary) process 108 | if batch_size: 109 | num_batches = math.ceil(num_samples / batch_size) 110 | else: 111 | num_batches = 0 112 | 113 | if is_training: 114 | # create a shared epoch store to sync epoch to dataloader worker proc 115 | shared_interval_count = SharedCount(count=start_interval) 116 | if not resampled: 117 | assert num_shards >= num_workers_nonzero * distributed.world_size, 'number of shards must be >= total workers' 118 | else: 119 | shared_interval_count = None 120 | 121 | datapipe = build_data_pipeline( 122 | shards=shards, 123 | task_pipeline=task_pipeline, 124 | is_training=is_training, 125 | batch_size=batch_size, 126 | resampled=resampled, 127 | multi_interval=multi_interval, 128 | seed=seed, 129 | shared_interval_count=shared_interval_count, 130 | num_batches_per_worker=num_batches_per_worker, 131 | sample_shuffle_initial=sample_shuffle_initial, 132 | sample_shuffle_size=sample_shuffle_size, 133 | handler=handler, 134 | collate_fn=collate_fn, 135 | ) 136 | 137 | dl_generator = torch.Generator() 138 | dl_generator.manual_seed(seed) 139 | 140 | dataloader = wds.WebLoader( 141 | datapipe, 142 | batch_size=None, # batching done in data-pipeline 143 | shuffle=False, # shuffling done in data-pipeline 144 | num_workers=num_workers, 145 | persistent_workers=persistent_workers and num_workers > 0, 146 | generator=dl_generator, 147 | worker_init_fn=seed_worker, 148 | ) 149 | 150 | return LoaderBundle( 151 | loader=dataloader, 152 | num_batches=num_batches, 153 | num_samples=num_samples, 154 | shared_interval=shared_interval_count, 155 | ) 156 | 157 | 158 | -------------------------------------------------------------------------------- /src/chug/loader.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List, Optional 2 | 3 | from chug.common import DataCfg, DataTaskCfg, DistributedCfg, LoaderBundle, source_to_shard_spec, SourceSpec 4 | from chug.hfds import create_loader_hf 5 | from chug.task_pipeline import create_task_pipeline 6 | from chug.wds import create_loader_wds, get_error_handler 7 | 8 | 9 | def create_loader( 10 | data_cfg: DataCfg, 11 | task_cfg: DataTaskCfg, 12 | task_pipeline: Optional[List[Callable]] = None, 13 | is_training: bool = False, 14 | start_interval: int = 0, 15 | seed: int = 0, 16 | distributed: DistributedCfg = DistributedCfg(), 17 | ) -> LoaderBundle: 18 | """ 19 | Creates a dataloader for training or validation based on configuration settings. 20 | 21 | Parameters: 22 | data_cfg: Configuration object for the dataset. 
23 | task_cfg: Configuration object for the task specific processing. 24 | task_pipeline: Task specific processing pipeline (takes priority over task_cfg). 25 | is_training : Indicates if the loader is for training data (True) or validation data (False). 26 | start_interval: The starting interval (epoch for full passes) for setting seed, etc. appropriately. 27 | seed: Seed for random operations to ensure reproducibility. 28 | distributed: Distributed device information. 29 | 30 | Returns: 31 | DataLoader: A PyTorch DataLoader instance configured according to the provided settings. 32 | 33 | Note: 34 | Currently supports "wds" and "hf_dataset" as dataset formats. 35 | """ 36 | if data_cfg.format == "wds": 37 | loader = create_loader_from_config_wds( 38 | data_cfg=data_cfg, 39 | task_cfg=task_cfg, 40 | task_pipeline=task_pipeline, 41 | is_training=is_training, 42 | start_interval=start_interval, 43 | seed=seed, 44 | distributed=distributed, 45 | ) 46 | 47 | elif data_cfg.format.startswith("hf"): 48 | loader = create_loader_from_config_hf( 49 | data_cfg=data_cfg, 50 | task_cfg=task_cfg, 51 | task_pipeline=task_pipeline, 52 | is_training=is_training, 53 | start_interval=start_interval, 54 | seed=seed, 55 | distributed=distributed, 56 | ) 57 | 58 | else: 59 | assert False, f"Unsupported dataset format ({data_cfg.format})." 60 | 61 | return loader 62 | 63 | 64 | def _validate_cfgs( 65 | data_cfg: DataCfg, 66 | task_cfg: Optional[DataTaskCfg], 67 | is_training: bool = False, 68 | ): 69 | batch_size = data_cfg.batch_size 70 | if batch_size is not None: 71 | if task_cfg.decode_and_process_fn is None: 72 | # FIXME make validation task specific once we have tasks that don't require both image and text preproc 73 | assert task_cfg.image_process_fn is not None and task_cfg.text_process_fn is not None,\ 74 | 'task_cfg.image_process_fn and task_cfg.text_process_fn must be set if batching enabled' 75 | else: 76 | assert task_cfg.decode_fn is None, \ 77 | 'task_cfg.decode_fn should not be set at the same time as task_cfg.decode_and_process_fn' 78 | 79 | 80 | def create_loader_from_config_wds( 81 | data_cfg: DataCfg, 82 | task_cfg: Optional[DataTaskCfg], 83 | task_pipeline: Optional[List[Callable]] = None, 84 | is_training: bool = False, 85 | start_interval: int = 0, 86 | seed: int = 0, 87 | collate_fn: Optional[Callable] = None, 88 | distributed: DistributedCfg = DistributedCfg(), 89 | ): 90 | """ 91 | 92 | Args: 93 | data_cfg: 94 | task_cfg: 95 | task_pipeline: 96 | is_training: 97 | start_interval: 98 | seed: 99 | collate_fn: 100 | distributed: 101 | 102 | Returns: 103 | 104 | """ 105 | _validate_cfgs(data_cfg, task_cfg, is_training=is_training) 106 | 107 | if task_pipeline is None: 108 | assert task_cfg is not None 109 | task_pipeline = create_task_pipeline( 110 | task_cfg, 111 | ) 112 | 113 | handler = get_error_handler(task_cfg.error_handler) 114 | 115 | return create_loader_wds( 116 | shards=data_cfg.shard_spec, 117 | task_pipeline=task_pipeline, 118 | num_samples=data_cfg.num_samples, 119 | is_training=is_training, 120 | resampled=data_cfg.resampled, 121 | multi_interval=True, # FIXME via config? 
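        # (multi_interval=True here means num_samples is treated as per-interval; see create_loader_wds docs)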
122 |         num_workers=data_cfg.num_workers,
123 |         batch_size=data_cfg.batch_size,
124 |         persistent_workers=data_cfg.persistent_workers,
125 |         collate_fn=collate_fn,
126 |         start_interval=start_interval,
127 |         seed=seed,
128 |         handler=handler,
129 |         distributed=distributed,
130 |     )
131 | 
132 | 
133 | def create_loader_from_config_hf(
134 |         data_cfg: DataCfg,
135 |         task_cfg: DataTaskCfg,
136 |         task_pipeline: Optional[List[Callable]] = None,
137 |         is_training: bool = False,
138 |         start_interval: int = 0,
139 |         seed: int = 0,
140 |         distributed: DistributedCfg = DistributedCfg(),
141 | ):
142 |     """
143 | 
144 |     Args:
145 |         data_cfg:
146 |         task_cfg:
147 |         task_pipeline:
148 |         is_training:
149 |         start_interval:
150 |         seed:
151 |         distributed:
152 | 
153 |     Returns:
154 | 
155 |     """
156 |     assert not isinstance(data_cfg.source, (list, tuple)), "Multiple sources not supported for HF datasets."
157 |     assert isinstance(data_cfg.source, (str, SourceSpec)), \
158 |         "The specified source for HF dataset must be a string or SourceSpec."
159 | 
160 |     _validate_cfgs(data_cfg, task_cfg, is_training=is_training)
161 | 
162 |     if isinstance(data_cfg.source, SourceSpec):
163 |         source = data_cfg.source.url
164 |         split = data_cfg.source.split
165 |         assert split, "Split must be set in SourceSpec with HF datasets."
166 |     else:
167 |         source = data_cfg.source
168 |         split = data_cfg.split
169 |         assert split, "Split must be set in DataCfg when string source is used with HF datasets."
170 | 
171 |     if task_pipeline is None:
172 |         assert task_cfg is not None
173 |         task_pipeline = create_task_pipeline(
174 |             task_cfg,
175 |         )
176 | 
177 |     streaming = 'hfids' in data_cfg.format
178 | 
179 |     return create_loader_hf(
180 |         source=source,
181 |         split=split,
182 |         task_pipeline=task_pipeline,
183 |         streaming=streaming,
184 |         is_training=is_training,
185 |         batch_size=data_cfg.batch_size,
186 |         data_dir=data_cfg.data_dir,
187 |         num_samples=data_cfg.num_samples,
188 |         num_workers=data_cfg.num_workers,
189 |         persistent_workers=data_cfg.persistent_workers,
190 |         seed=seed,
191 |         distributed=distributed,
192 |     )
193 | 
--------------------------------------------------------------------------------
/src/chug/image/build_transforms_image.py:
--------------------------------------------------------------------------------
1 | from dataclasses import asdict
2 | from typing import Any, Dict, Optional, Union
3 | 
4 | from torchvision import transforms
5 | from timm.data import (
6 |     ResizeKeepRatio,
7 |     CenterCropOrPad,
8 |     RandomResizedCropAndInterpolation,
9 |     create_transform,
10 | )
11 | 
12 | from chug.common import ImageInputCfg, ImageAugCfg
13 | from .transforms_torch import ConvertColor
14 | 
15 | 
16 | def build_transforms_image_timm(
17 |         input_cfg: ImageInputCfg,
18 |         is_training: bool = False,
19 |         do_normalize: bool = True,
20 |         do_convert: bool = False,
21 |         aug_cfg: Optional[Union[Dict[str, Any], ImageAugCfg]] = None,
22 |         composed: bool = True,
23 | ):
24 |     """ Build image transforms leveraging timm's create_transform() functionality.
25 | 
26 |     Args:
27 |         input_cfg: Image input config (size, mode, mean/std, interpolation, resize mode).
28 |         is_training: In training mode, apply train transforms w/ augmentations.
29 |         do_normalize: Enable normalization of output tensors by specified mean & std deviation.
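        do_convert: Accepted for interface parity with the basic builder; not currently
            applied in this timm path.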
30 |         aug_cfg: Augmentation config as a dict or ImageAugCfg, defaults to ImageAugCfg.imagenet().
31 | 
32 |     Returns:
33 | 
34 |     """
35 |     interpolation = input_cfg.interpolation or 'bicubic'
36 |     assert interpolation in ['bicubic', 'bilinear', 'random']
37 | 
38 |     resize_mode = input_cfg.resize_mode or 'shortest'
39 |     assert resize_mode in ('shortest', 'longest', 'squash')
40 | 
41 |     if isinstance(aug_cfg, dict):
42 |         aug_cfg = ImageAugCfg(**aug_cfg)
43 |     else:
44 |         aug_cfg = aug_cfg or ImageAugCfg.imagenet()
45 | 
46 |     if is_training:
47 |         aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
48 |         aug_cfg_dict.setdefault('color_jitter', None)  # disable by default
49 |         # FIXME map aug_cfg -> timm.create_transform args more carefully
50 | 
51 |         train_transform = create_transform(
52 |             input_size=input_cfg.size,
53 |             is_training=True,
54 |             use_prefetcher=not do_normalize,  # FIXME prefetcher mode disables normalize, but outputs np.array
55 |             hflip=0.,
56 |             mean=input_cfg.mean,
57 |             std=input_cfg.std,
58 |             re_mode='pixel',
59 |             interpolation=interpolation,
60 |             **aug_cfg_dict,
61 |         )
62 |         return train_transform
63 |     else:
64 |         if resize_mode == 'longest':
65 |             timm_crop_mode = 'border'
66 |         elif resize_mode == 'squash':
67 |             timm_crop_mode = 'squash'
68 |         else:
69 |             assert resize_mode == 'shortest'
70 |             timm_crop_mode = 'center'
71 | 
72 |         eval_transform = create_transform(
73 |             input_size=input_cfg.size,
74 |             is_training=False,
75 |             use_prefetcher=not do_normalize,  # FIXME prefetcher mode disables normalize, but outputs np.array
76 |             mean=input_cfg.mean,
77 |             std=input_cfg.std,
78 |             crop_pct=1.0,
79 |             crop_mode=timm_crop_mode,
80 |             # FIXME
81 |             # composed=composed,
82 |         )
83 |         return eval_transform
84 | 
85 | 
86 | def build_transforms_image_basic(
87 |         input_cfg: ImageInputCfg,
88 |         is_training: bool = False,
89 |         do_normalize: bool = True,
90 |         do_convert: bool = False,
91 |         aug_cfg: Optional[Union[Dict[str, Any], ImageAugCfg]] = None,
92 |         composed: bool = True,
93 | ):
94 |     """ Build image transforms leveraging torchvision transforms.
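    A minimal usage sketch (illustrative; assumes ImageInputCfg accepts these fields
    as keyword args):

        in_cfg = ImageInputCfg(size=(512, 512), mode='RGB')
        train_tf = build_transforms_image_basic(in_cfg, is_training=True)
        img_tensor = train_tf(pil_image)  # PIL.Image -> normalized torch.Tensor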
95 |     """
96 |     if do_normalize:
97 |         normalize = transforms.Normalize(mean=input_cfg.mean, std=input_cfg.std)
98 |     else:
99 |         normalize = None
100 | 
101 |     interpolation = input_cfg.interpolation or 'bicubic'
102 |     assert interpolation in ['bicubic', 'bilinear', 'random']
103 |     # NOTE 'random' is ignored for interpolation_mode, so inference defaults to BICUBIC if it is set
104 |     if interpolation == 'bilinear':
105 |         interpolation_mode = transforms.InterpolationMode.BILINEAR
106 |     else:
107 |         interpolation_mode = transforms.InterpolationMode.BICUBIC
108 | 
109 |     resize_mode = input_cfg.resize_mode or 'shortest'
110 |     assert resize_mode in ('shortest', 'longest', 'squash')
111 | 
112 |     if isinstance(aug_cfg, dict):
113 |         aug_cfg = ImageAugCfg(**aug_cfg)
114 |     else:
115 |         aug_cfg = aug_cfg or ImageAugCfg.imagenet()
116 | 
117 |     if is_training:
118 |         image_size = input_cfg.size
119 |         if resize_mode == 'shortest':
120 |             if isinstance(image_size, (tuple, list)) and image_size[0] == image_size[1]:
121 |                 image_size = image_size[0]  # scalar size => final resize in RRC uses shortest edge
122 |             # FIXME note we don't have a good option for 'longest' resizing w/ RRC
123 | 
124 |         transform_list = [
125 |             # like torchvision.transforms.RandomResizedCrop but supports randomized interpolation for robustness
126 |             RandomResizedCropAndInterpolation(
127 |                 image_size,
128 |                 scale=aug_cfg.scale or (1.0, 1.0),
129 |                 ratio=aug_cfg.ratio or (1.0, 1.0),
130 |                 interpolation=interpolation,
131 |             ),
132 |         ]
133 | 
134 |         if do_convert:
135 |             transform_list.append(ConvertColor(mode=input_cfg.mode))
136 | 
137 |         if aug_cfg.color_jitter_prob:
138 |             assert aug_cfg.color_jitter is not None and len(aug_cfg.color_jitter) == 4
139 |             if aug_cfg.color_jitter_prob is not None:
140 |                 transform_list.append(
141 |                     transforms.RandomApply(
142 |                         [transforms.ColorJitter(
143 |                             *aug_cfg.color_jitter,
144 |                         )],
145 |                         p=aug_cfg.color_jitter_prob,
146 |                     )
147 |                 )
148 |         elif aug_cfg.color_jitter is not None:
149 |             transform_list.append(transforms.ColorJitter(*aug_cfg.color_jitter))
150 | 
151 |         if aug_cfg.grayscale_prob:
152 |             transform_list.append(transforms.RandomGrayscale(aug_cfg.grayscale_prob))
153 | 
154 |         if aug_cfg.gaussian_blur_prob:
155 |             gaussian_blur_kernel = aug_cfg.gaussian_blur_kernel_size or 23
156 |             transform_list.append(transforms.RandomApply(
157 |                 [transforms.GaussianBlur(
158 |                     kernel_size=gaussian_blur_kernel,
159 |                 )],
160 |                 p=aug_cfg.gaussian_blur_prob,
161 |             ))
162 | 
163 |     else:
164 |         image_size = input_cfg.size
165 | 
166 |         if resize_mode == 'longest':
167 |             transform_list = [
168 |                 ResizeKeepRatio(image_size, interpolation=interpolation_mode, longest=1),
169 |                 CenterCropOrPad(image_size, fill=input_cfg.fill_color)
170 |             ]
171 |         elif resize_mode == 'squash':
172 |             if isinstance(image_size, int):
173 |                 image_size = (image_size, image_size)
174 | 
175 |             transform_list = [
176 |                 transforms.Resize(image_size, interpolation=interpolation_mode, antialias=True),
177 |             ]
178 |         else:
179 |             assert resize_mode == 'shortest'
180 |             if not isinstance(image_size, (tuple, list)):
181 |                 image_size = (image_size, image_size)
182 | 
183 |             if image_size[0] == image_size[1]:
184 |                 # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
185 |                 transform_list = [
186 |                     transforms.Resize(image_size[0], interpolation=interpolation_mode, antialias=True)
187 |                 ]
188 |             else:
189 |                 # resize shortest edge to matching target dim for non-square target
190 |                 transform_list = [ResizeKeepRatio(image_size)]
191 | 
192 |             transform_list +=
[transforms.CenterCrop(image_size)] 193 | 194 | if do_convert: 195 | transform_list.append(ConvertColor(mode=input_cfg.mode)) 196 | # end if is_training 197 | 198 | transform_list += [transforms.ToTensor()] 199 | 200 | if normalize is not None: 201 | transform_list += [normalize] 202 | 203 | return transforms.Compose(transform_list) if composed else transform_list 204 | -------------------------------------------------------------------------------- /src/chug/doc/doc_read_processor.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Callable 2 | 3 | from chug import ImageFeatureInfo, FeatureInfo 4 | from chug.doc import DocProcessor, DEFAULT_DOC_FEAT 5 | from chug.doc.doc_processor import get_next_valid_page_index, _get_value, _logger 6 | 7 | 8 | class DocReadProcessor(DocProcessor): 9 | """ Process documents w/ OCR annotation for reading tasks. 10 | """ 11 | 12 | def __init__( 13 | self, 14 | image_process_fn: Optional[Callable] = None, 15 | text_process_fn: Optional[Callable] = None, 16 | image_input_feat: ImageFeatureInfo = DEFAULT_DOC_FEAT, 17 | text_input_feat: FeatureInfo = FeatureInfo('text_input', input_key='pages'), 18 | text_target_feat: FeatureInfo = FeatureInfo('text_target', input_key=None), 19 | line_break: str = '\n', 20 | page_sampling: str = 'random', 21 | render_dpi: int = 150, 22 | squeeze_pages: bool = True, 23 | expand_pages: bool = False, 24 | flatten_json: bool = True, 25 | seed: int = 0, 26 | ): 27 | super().__init__( 28 | image_process_fn=image_process_fn, 29 | text_process_fn=text_process_fn, 30 | image_input_feat=image_input_feat, 31 | text_input_feat=text_input_feat, 32 | text_target_feat=text_target_feat, 33 | render_dpi=render_dpi, 34 | page_sampling=page_sampling, 35 | squeeze_pages=squeeze_pages, 36 | expand_pages=expand_pages, 37 | flatten_json=flatten_json, 38 | seed=seed, 39 | ) 40 | self.line_break = line_break 41 | assert page_sampling in ('random', 'first', 'all_valid', 'all') 42 | 43 | def _process_anno_pages(self, anno): 44 | assert isinstance(anno, (list, tuple)), f"Annotation should be a list of pages" 45 | num_pages = len(anno) 46 | if not num_pages: 47 | raise RuntimeError("Empty annotation. Skipping...") 48 | 49 | # FIXME for initial behaviour we will randomly sample one of N pages 50 | # TODO determine if we want to train in multi-page mode, use another sampling strategy? 51 | page_indices = [] 52 | try: 53 | if self.page_sampling == 'random': 54 | n_wanted_pages = min(1, num_pages) # TODO increase for multi-page processing, rand start+end? 55 | current_index = self.generator.randrange(-1, num_pages - 1) 56 | for _ in range(n_wanted_pages): 57 | current_index = get_next_valid_page_index(current_index, num_pages, anno) 58 | page_indices.append(current_index) 59 | elif self.page_sampling == 'first': 60 | current_index = get_next_valid_page_index(-1, num_pages, anno) 61 | page_indices.append(current_index) 62 | elif self.page_sampling == 'all_valid': 63 | current_index = -1 64 | for _ in range(num_pages): 65 | current_index = get_next_valid_page_index(current_index, num_pages, anno) 66 | page_indices.append(current_index) 67 | elif self.page_sampling == 'all': 68 | page_indices = list(range(num_pages)) 69 | except RuntimeError: 70 | pass 71 | 72 | if not page_indices: 73 | raise RuntimeError("No valid annotated pages. 
Skipping...")
74 | 
75 |         text_pages = []
76 |         tokenized_text_pages = []
77 |         target_pages = []
78 |         for current_index in page_indices:
79 |             # FIXME currently encoding each page separately with its own start/end tokens.
80 |             # For multi-page, should consider encoding in one sequence w/ page-break tokens.
81 |             anno_page = anno[current_index]
82 |             if 'lines' in anno_page:
83 |                 # Two supported formats right now
84 |                 # {
85 |                 #     'pages': [
86 |                 #         {
87 |                 #             'text': [],  # these are lines
88 |                 #             'bbox': [],
89 |                 #         }
90 |                 #     ]
91 |                 # }
92 |                 #
93 |                 # OR
94 |                 #
95 |                 # {
96 |                 #     'pages': [
97 |                 #         {
98 |                 #             'lines': {
99 |                 #                 'text': [],
100 |                 #                 'bbox': [],
101 |                 #             },
102 |                 #             'words': {
103 |                 #                 'text': [],
104 |                 #                 'bbox': [],
105 |                 #             }
106 |                 #         }
107 |                 #     ]
108 |                 # }
109 |                 #
110 |                 #
111 |                 anno_page = anno_page['lines']
112 | 
113 |             # Currently page text is created by concatenating lines of text with the line-break string.
114 |             # Additions could involve:
115 |             # * using different line-break tokens between lines
116 |             # * using word-level bbox anno information to mask words and construct partial lines
117 |             # * grouping lines into blocks (or using block annos) to treat blocks / paragraphs of text as units
118 |             if not anno_page["text"]:
119 |                 raise RuntimeError("No text on page, skipping sample...")
120 | 
121 |             text = self.line_break.join(anno_page["text"])
122 | 
123 |             # FIXME cleanup, split process and decode for more flexibility
124 |             # tokenize, and generate the training target if enabled
125 |             if self.text_process_fn is not None:
126 |                 processed = self.text_process_fn(text)
127 |                 assert self.text_input_name in processed, \
128 |                     f"Text input name '{self.text_input_name}' not found in processed sample."
129 |                 tokenized_text_pages.append(processed[self.text_input_name])
130 |                 if self.text_target_name in processed:
131 |                     target_pages.append(processed[self.text_target_name])
132 |                 else:
133 |                     if self.text_target_feat is not None:
134 |                         assert False, f"Expected a text target named '{self.text_target_name}' in processed sample."
135 |             else:
136 |                 # FIXME warn / assert that target is not supported w/o text preprocessing?
137 |                 tokenized_text_pages.append(text)
138 | 
139 |             text_pages.append(anno_page["text"])  # unencoded text added as lines
140 | 
141 |         gt_parse = {
142 |             'num_pages': num_pages,  # total # of pages in doc
143 |             'page_indices': page_indices,  # page indices sampled
144 |             'page_text': text_pages,  # text of sampled page indices pages[].lines[]
145 |         }
146 | 
147 |         output = {
148 |             self.text_input_name: tokenized_text_pages,
149 |             '_parse': gt_parse,
150 |         }
151 |         if target_pages:
152 |             output[self.text_target_name] = target_pages
153 | 
154 |         return output
155 | 
156 |     def _decode_anno(self, sample):
157 |         anno = _get_value(self.text_input_key, sample)
158 |         assert anno is not None, f"No annotation found with keys ({self.text_input_key})."
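        # e.g. (illustrative) anno == [{'text': ['line 1', 'line 2'], 'bbox': [...]}, ...],
        # one entry per page, matching the formats documented in _process_anno_pages above.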
159 | 160 | try: 161 | page_anno = self._process_anno_pages(anno) 162 | except Exception as exn: 163 | _logger.error(f'Issue processing annotation for {sample["__url__"]}, {sample["__key__"]}.') 164 | #_logger.error(json.dumps(anno, indent=4)) 165 | raise exn 166 | 167 | # extract info from the _parse 168 | info = page_anno.get('_parse', {}) 169 | page_indices = info.get('page_indices', [0]) # the samples page indices 170 | num_anno_pages = info.get('num_pages', 1) 171 | 172 | # TODO support 'image info' to relay details such as text bbox, layout 173 | # page_image_info = info.get('image_info', None) 174 | # if page_image_info is not None: 175 | # assert len(page_image_info) == len(page_indices) 176 | 177 | return page_anno, page_indices, num_anno_pages 178 | 179 | def _expand_anno(self, anno, count: int): 180 | expanded_annos = [] 181 | for i in range(count): 182 | sample = {} 183 | for k, v in anno.items(): 184 | if k == '_parse': 185 | gt_parse = {} 186 | gt_parse['num_pages'] = v['num_pages'] 187 | gt_parse['page_indices'] = [v['page_indices'][i]] 188 | gt_parse['page_text'] = [v['page_text'][i]] 189 | sample[k] = gt_parse 190 | else: 191 | sample[k] = v[i] if isinstance(v, (list, tuple)) else v 192 | expanded_annos.append(sample) 193 | return expanded_annos 194 | -------------------------------------------------------------------------------- /src/chug/doc/doc_processor.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import json 3 | import logging 4 | import random 5 | from typing import Callable, Dict, List, Optional, Tuple 6 | 7 | from chug.common import FeatureInfo, ImageFeatureInfo 8 | from chug.wds import decode_image_pages, decode_pdf_pages 9 | 10 | 11 | from .constants import ( 12 | DEFAULT_DOC_FEAT, 13 | ) 14 | 15 | _logger = logging.getLogger(__name__) 16 | 17 | 18 | def get_next_valid_page_index( 19 | current_index: int, 20 | num_pages: int, 21 | page_annos: list, 22 | retries: int = 10, 23 | wanted: str = 'lines', 24 | ): 25 | """ 26 | Get the index of the next valid page which contains text. If it doesn't find any non-empty page 27 | after 'retries' attempts, it raises a RuntimeError. 28 | 29 | Parameters: 30 | current_index (int): Current page index. 31 | num_pages (int): Total number of pages. 32 | page_annos (list): List of page annotations. 33 | retries (int): Number of maximum retries for a given document. 34 | 35 | Returns: 36 | int: The index of the next non-empty page. 37 | """ 38 | for _ in range(retries): 39 | # Get the next index, wrap around to 0 if it exceeds num_pages (in case of random init) 40 | current_index = (current_index + 1) % num_pages 41 | anno_page = page_annos[current_index] 42 | anno_page = anno_page.get(wanted, anno_page) # use 'lines' / 'words' level if exists 43 | if anno_page["text"]: 44 | return current_index 45 | raise RuntimeError(f"No non-empty page found after {retries} attempts") 46 | 47 | 48 | def _get_value(keys, sample): 49 | if isinstance(keys, (list, tuple)): 50 | value = None 51 | for k in keys: 52 | if (value := sample.get(k, None)) is not None: 53 | break 54 | return value 55 | else: 56 | return sample.get(keys, None) 57 | 58 | 59 | class DocProcessor: 60 | """ Process documents w/ OCR annotation for reading tasks. 
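    Base class behaviour: handles json flattening, page image decoding (pdf or raster
    formats), and image pre-processing. Subclasses implement `_decode_anno()` to parse
    task-specific annotations and may override `_expand_anno()` / `_squeeze_anno()`.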
61 | """ 62 | 63 | def __init__( 64 | self, 65 | image_process_fn: Optional[Callable] = None, 66 | text_process_fn: Optional[Callable] = None, 67 | image_input_feat: ImageFeatureInfo = DEFAULT_DOC_FEAT, 68 | text_input_feat: FeatureInfo = FeatureInfo('text_input', input_key='pages'), 69 | text_target_feat: FeatureInfo = FeatureInfo('text_target', input_key=None), 70 | page_sampling: str = 'random', 71 | render_dpi: int = 150, 72 | squeeze_pages: bool = True, 73 | expand_pages: bool = False, 74 | flatten_json: bool = True, 75 | seed: int = 0, 76 | ): 77 | """ 78 | 79 | Args: 80 | image_process_fn: 81 | text_process_fn: 82 | page_sampling: 83 | render_dpi: 84 | seed: 85 | """ 86 | self.image_process_fn = image_process_fn 87 | self.text_process_fn = text_process_fn 88 | 89 | self.image_input_feat = image_input_feat 90 | self.image_input_name = image_input_feat.output_name 91 | self.image_input_key = image_input_feat.input_key.split(';') 92 | self.text_input_feat = text_input_feat 93 | self.text_input_name = text_input_feat.output_name 94 | self.text_input_key = text_input_feat.input_key.split(';') 95 | self.text_target_feat = text_target_feat 96 | self.text_target_name = text_target_feat.output_name 97 | 98 | self.page_sampling = page_sampling 99 | self.render_dpi = render_dpi 100 | self.squeeze_pages = squeeze_pages 101 | self.expand_pages = expand_pages 102 | self.flatten_json = flatten_json 103 | self.generator = random.Random() 104 | self.generator.seed(seed) 105 | # FIXME note, should move to torchvision v2 annotations at some point 106 | # * they should all eventually have a generator arg for better handling random state 107 | # * they have forms that accept bbox/points args to transform annotations in sync with image 108 | 109 | def _preprocess_image_pages(self, decoded_pages, page_image_info=None): 110 | if self.image_process_fn is None: 111 | return decoded_pages 112 | 113 | if page_image_info is not None: 114 | # FIXME, WIP. If train objective involves masking or otherwise processing image 115 | # with knowledge of annotations / text content, anno info should contain 116 | # mask locations, etc. 
For such a task, we need to pass it to image preprocess 117 | decoded_pages = [self.image_process_fn(dp, page_info=pi) for dp, pi in zip(decoded_pages, page_image_info)] 118 | else: 119 | decoded_pages = [self.image_process_fn(dp) for dp in decoded_pages] 120 | 121 | return decoded_pages 122 | 123 | def _decode_image_pages( 124 | self, 125 | sample, 126 | ext, 127 | page_indices, 128 | num_anno_pages, 129 | ): 130 | image_mode = self.image_input_feat.image_mode 131 | 132 | decoded_pages, num_image_pages = decode_image_pages( 133 | sample[ext], 134 | image_mode=image_mode, 135 | page_indices=page_indices, 136 | ) 137 | if num_image_pages != num_anno_pages: 138 | _logger.warning( 139 | f'Mismatch between num image and num annotation pages {num_image_pages} != {num_anno_pages}' 140 | f' for sample {sample["__url__"]}, {sample["__key__"]}.') 141 | 142 | decoded_pages = self._preprocess_image_pages(decoded_pages) 143 | 144 | return decoded_pages, num_image_pages 145 | 146 | def _decode_pdf_pages( 147 | self, 148 | sample, 149 | ext, 150 | page_indices, 151 | num_anno_pages, 152 | ): 153 | image_mode = self.image_input_feat.image_mode 154 | decoded_pages, num_image_pages = decode_pdf_pages( 155 | sample[ext], 156 | image_mode=image_mode, 157 | page_indices=page_indices, 158 | ) 159 | if num_anno_pages is not None and num_image_pages != num_anno_pages: 160 | _logger.warning( 161 | f'Mismatch between num image and num annotation pages {num_image_pages} != {num_anno_pages}' 162 | f' for sample {sample["__url__"]}, {sample["__key__"]}.') 163 | 164 | decoded_pages = self._preprocess_image_pages(decoded_pages) 165 | 166 | return decoded_pages, num_image_pages 167 | 168 | @abc.abstractmethod 169 | def _decode_anno(self, sample) -> Tuple[Dict, List[int], int]: 170 | pass 171 | 172 | def _expand_anno(self, anno, count: int): 173 | expanded_annos = [ 174 | {k: v[i] if isinstance(v, (list, tuple)) else v for k, v in anno.items()} 175 | for i in range(count) 176 | ] 177 | return expanded_annos 178 | 179 | def _squeeze_anno(self, anno): 180 | anno = {k: v[0] if isinstance(v, (list, tuple)) else v for k, v in anno.items()} 181 | return anno 182 | 183 | def __call__(self, sample): 184 | if 'json' in sample and isinstance(sample['json'], bytes): 185 | # decode json if present and in undecoded state 186 | sample['json'] = json.loads(sample['json']) 187 | 188 | if self.flatten_json and 'json' in sample: 189 | # flatten json into sample 190 | sample.update(sample.pop('json')) 191 | 192 | # FIXME separate decode & preprocess interfaces 193 | 194 | # decode page annotations / text 195 | page_anno, page_indices, num_anno_pages = self._decode_anno(sample) 196 | 197 | # decode page images 198 | page_images = [] 199 | for ext in self.image_input_key: 200 | if ext in sample: 201 | if ext == 'pdf': 202 | images, num_image_pages = self._decode_pdf_pages( 203 | sample, 204 | ext, 205 | page_indices, 206 | num_anno_pages, 207 | ) 208 | else: 209 | images, num_image_pages = self._decode_image_pages( 210 | sample, 211 | ext, 212 | page_indices, 213 | num_anno_pages, 214 | ) 215 | page_images.extend(images) 216 | # process one document type per doc, should be ordered by priority 217 | break 218 | 219 | assert len(page_images), 'No page images present' 220 | 221 | if self.expand_pages and len(page_images) > 1: 222 | # expand pages and page annotations into multiple samples (return list of sample dicts) 223 | page_anno = self._expand_anno(page_anno, len(page_images)) 224 | decoded = [{self.image_input_name: pi, **pa} for pi, pa 
in zip(page_images, page_anno)]
225 |         else:
226 |             if self.squeeze_pages and len(page_images) == 1:
227 |                 # squeeze page & annotation lists into singular items
228 |                 page_images = page_images[0]
229 |                 page_anno = self._squeeze_anno(page_anno)
230 |             decoded = {self.image_input_name: page_images, **page_anno}
231 | 
232 |         return decoded
233 | 
234 | 
235 | 
--------------------------------------------------------------------------------
/src/chug/common/config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field, fields, replace
2 | from typing import Any, Dict, List, Optional, Tuple, Union, Sequence, Callable
3 | 
4 | from simple_parsing.helpers import Serializable
5 | 
6 | from .types import ImageFeatureInfo, ShardSpec, SourceSpec
7 | from .urls import expand_urls
8 | 
9 | 
10 | def image_mode_to_chs(fmt: str):
11 |     if fmt is None:
12 |         return None
13 |     assert fmt in ('L', 'RGB')  # could support more...
14 |     return 1 if fmt == 'L' else 3
15 | 
16 | 
17 | @dataclass
18 | class ImageInputCfg(Serializable):
19 |     size: Optional[Tuple[int, int]] = (512, 512)
20 |     mode: Optional[str] = 'L'
21 |     mean: Optional[Union[float, Tuple[float, ...]]] = 0.5
22 |     std: Optional[Union[float, Tuple[float, ...]]] = 0.5
23 |     interpolation: Optional[str] = 'bicubic'
24 |     fill_color: Optional[Union[int, Tuple[int, ...]]] = 255
25 |     crop_margin: Optional[bool] = False
26 |     align_long_axis: Optional[bool] = False
27 |     transform_type: Optional[str] = 'image_basic'
28 |     resize_mode: Optional[str] = 'shortest'
29 | 
30 |     @property
31 |     def image_chs(self):
32 |         return image_mode_to_chs(self.mode)
33 | 
34 |     def __post_init__(self):
35 |         image_chs = self.image_chs
36 |         if image_chs is not None:
37 |             # ensure mean/std attributes match image_chs
38 |             for attr_name in ('mean', 'std'):
39 |                 attr = getattr(self, attr_name)
40 |                 if attr is None: continue  # skip unset values, len(None) would fail below
41 |                 if not isinstance(attr, Sequence): attr = (attr,)
42 |                 if image_chs == 1 and len(attr) > image_chs:
43 |                     attr = (sum(attr) / len(attr),)
44 |                 if image_chs > 1 and len(attr) == 1:
45 |                     attr = attr * image_chs
46 |                 assert len(attr) == image_chs
47 |                 setattr(self, attr_name, attr)
48 | 
49 |             # ensure fill color matches image_chs
50 |             if self.fill_color is not None:
51 |                 if not isinstance(self.fill_color, Sequence):
52 |                     self.fill_color = (self.fill_color,)
53 |                 if image_chs == 1 and len(self.fill_color) > image_chs:
54 |                     self.fill_color = (int(sum(self.fill_color) / len(self.fill_color)),)
55 |                 if image_chs > 1 and len(self.fill_color) == 1:
56 |                     self.fill_color = self.fill_color * image_chs
57 | 
58 |     @classmethod
59 |     def empty(cls):
60 |         return cls(**{f.name: None for f in fields(cls)})
61 | 
62 |     def set_default(self, right, inplace=False):
63 |         # set left fields from right fields if left fields are not initialized (None)
64 |         changes = {
65 |             f.name: v for f in fields(self)
66 |             if (v := getattr(right, f.name)) is not None and getattr(self, f.name) is None
67 |         }
68 |         if inplace:
69 |             for k, v in changes.items():
70 |                 setattr(self, k, v)
71 |             return self
72 |         else:
73 |             return replace(self, **changes)
74 | 
75 |     def merge(self, right, inplace=False):
76 |         # merge from right to left for right fields that are not None
77 |         changes = {f.name: v for f in fields(self) if (v := getattr(right, f.name)) is not None}
78 |         if inplace:
79 |             for k, v in changes.items():
80 |                 setattr(self, k, v)
81 |             return self
82 |         else:
83 |             return replace(self, **changes)
84 | 
85 | 
86 | @dataclass
87 | class ImageAugCfg(Serializable):
88 |     """
| """ 89 | A simple flat config struct for overriding common augmentation defaults. 90 | 91 | Each image transform type supports different augmentations and have their own defaults, 92 | this struct is intended to override defaults for common values, not necessarily to 93 | cover all cases and define all augmentation possibilities across all schemes. 94 | """ 95 | 96 | # resize scale bounds (1.0 = middle point = same scale) 97 | scale: Optional[Tuple[float, float]] = None 98 | 99 | # resize aspect ratio bounds (1.0 = 1:1) 100 | ratio: Optional[Tuple[float, float]] = None 101 | 102 | # color jitter, per item probs 103 | color_jitter: Optional[Union[float, Tuple[float, float, float], Tuple[float, float, float, float]]] = None 104 | 105 | # for simclr, control prob for applying any of the jitter probs above 106 | color_jitter_prob: Optional[float] = None 107 | 108 | # for preprocess w/ grayscale (simclr), control prob of converting to graysacle 109 | grayscale_prob: Optional[float] = None 110 | 111 | gaussian_blur_prob: Optional[float] = None 112 | gaussian_blur_kernel_size: Optional[int] = None 113 | 114 | # probability of applying random-erasing (timm style aug) 115 | re_prob: Optional[float] = None 116 | 117 | # number of random-erasing blocks (timm style aug) 118 | re_count: Optional[int] = None 119 | 120 | @classmethod 121 | def clip(cls, **kwargs): 122 | aug = cls( 123 | scale=(0.9, 1.0), 124 | ratio=(0.75, 1. / 0.75), 125 | ) 126 | aug = replace(aug, **kwargs) 127 | return aug 128 | 129 | @classmethod 130 | def imagenet(cls, **kwargs): 131 | aug = cls( 132 | scale=(0.08, 1.0), 133 | ratio=(0.75, 1. / 0.75), 134 | color_jitter=(0.4, 0.4, 0.4), 135 | ) 136 | aug = replace(aug, **kwargs) 137 | return aug 138 | 139 | @classmethod 140 | def simclr(cls, **kwargs): 141 | aug = cls( 142 | scale=(0.08, 1.0), 143 | ratio=(0.75, 1. / 0.75), 144 | color_jitter=(0.4, 0.4, 0.4, 0.1), 145 | color_jitter_prob=0.8, 146 | grayscale_prob=0.2, 147 | gaussian_blur_prob=0.5, 148 | #gaussian_blur_kernel_size=23, 149 | ) 150 | aug = replace(aug, **kwargs) 151 | return aug 152 | 153 | 154 | # Vision preprocessing config 155 | @dataclass 156 | class PreprocessCfg(Serializable): 157 | image_input: ImageInputCfg = field(default_factory=ImageInputCfg) 158 | aug_cfg: Optional[ImageAugCfg] = None 159 | 160 | 161 | @dataclass 162 | class DataArg(Serializable): 163 | """ Data source argument in an argument friendly form (multiple sources represented in string) 164 | """ 165 | source: str 166 | split: Optional[str] = None 167 | sampling_weight: Optional[str] = None 168 | template: Optional[str] = None # template to transform url for use 169 | num_samples: Optional[Union[int, str]] = None 170 | data_dir: Optional[str] = None 171 | 172 | batch_size: int = 1 173 | format: str = "wds" # e.g. 
"hfds", "hfids", or "wds" 174 | 175 | resampled: bool = False # sample shards with replacement 176 | multi_interval: bool = True 177 | persistent_workers: bool = True 178 | num_workers: int = 4 179 | 180 | 181 | def split_sources( 182 | source: str, 183 | split: Optional[str] = None, 184 | sampling_weights: Optional[Union[str, List[float]]] = None, 185 | num_samples: Optional[Union[int, str, List[int]]] = None, 186 | ): 187 | if '::' in source: 188 | source_split = source.split('::') 189 | else: 190 | source_split = [source] 191 | num_sources = len(source_split) 192 | 193 | if sampling_weights is not None: 194 | if isinstance(sampling_weights, str): 195 | weights_split = sampling_weights.split('::') 196 | sampling_weights = [float(w) for w in weights_split] 197 | assert len(sampling_weights) == num_sources 198 | 199 | num_samples_per_source = None 200 | if num_samples is not None: 201 | if isinstance(num_samples, str): 202 | num_samples_split = num_samples.split('::') 203 | num_samples = [int(s) for s in num_samples_split] 204 | 205 | try: 206 | len(num_samples) 207 | except Exception: 208 | num_samples_per_source = [None] * num_sources 209 | else: 210 | num_samples_per_source = num_samples 211 | num_samples = sum(num_samples_per_source) 212 | finally: 213 | assert len(num_samples_per_source) == num_sources 214 | 215 | output = [] 216 | for i, s in enumerate(source_split): 217 | output.append(SourceSpec( 218 | url=s, 219 | split=split, 220 | sampling_weight=None if sampling_weights is None else sampling_weights[i], 221 | num_samples=None if num_samples_per_source is None else num_samples_per_source[i], 222 | )) 223 | 224 | return output, num_samples 225 | 226 | 227 | # FIXME add code to resolve shard information from _info.yaml or .json files (see dataset_info.py) 228 | 229 | 230 | def source_to_shard_spec( 231 | source: Union[str, SourceSpec, List[SourceSpec]], 232 | ): 233 | if isinstance(source, str): 234 | source_list = [SourceSpec(url=source)] 235 | elif isinstance(source, SourceSpec): 236 | source_list = [source] 237 | else: 238 | assert isinstance(source[0], SourceSpec) 239 | source_list = source 240 | 241 | # process weights first in case some are set and some are not 242 | if not all(s.sampling_weight is None for s in source_list): 243 | weights = [s.sampling_weight if s.sampling_weight else 1.0 for s in source_list] 244 | else: 245 | weights = [None] * len(source_list) 246 | 247 | all_urls = [] 248 | all_weights = [] 249 | for s, w in zip(source_list, weights): 250 | expanded_urls, expanded_weights = expand_urls(s.url, weights=w) 251 | all_urls.extend(expanded_urls) 252 | if expanded_weights: 253 | all_weights.extend(expanded_weights) 254 | all_weights = all_weights or None 255 | sizes = None # FIXME resolve sizes 256 | 257 | ss = ShardSpec(urls=all_urls, weights=all_weights, sizes=sizes) 258 | return ss 259 | 260 | 261 | @dataclass 262 | class DataCfg(Serializable): 263 | source: Union[str, SourceSpec, List[SourceSpec]] 264 | split: Optional[str] = None 265 | num_samples: Optional[int] = None # overrides num_samples across sources if set 266 | data_dir: Optional[str] = None 267 | 268 | batch_size: Optional[int] = 1 269 | format: str = "wds" # e.g. "hfds", "hfids", or "wds". 
270 | 
271 |     resampled: bool = False  # sample shards with replacement
272 |     multi_interval: bool = True
273 |     persistent_workers: bool = True
274 |     num_workers: int = 4
275 | 
276 |     @classmethod
277 |     def from_arg(cls, data_arg: DataArg):
278 |         sources, num_samples = split_sources(
279 |             data_arg.source,
280 |             split=data_arg.split, sampling_weights=data_arg.sampling_weight,
281 |             num_samples=data_arg.num_samples,
282 |         )
283 |         return cls(
284 |             source=sources,
285 |             num_samples=num_samples,
286 |             data_dir=data_arg.data_dir,
287 |             batch_size=data_arg.batch_size,
288 |             format=data_arg.format,
289 |             resampled=data_arg.resampled,
290 |             multi_interval=data_arg.multi_interval,
291 |             persistent_workers=data_arg.persistent_workers,
292 |             num_workers=data_arg.num_workers,
293 |         )
294 | 
295 |     @property
296 |     def shard_spec(self):
297 |         return source_to_shard_spec(self.source)
298 | 
299 |     def __post_init__(self):
300 |         if self.num_workers == 0:
301 |             self.persistent_workers = False
302 | 
303 | @dataclass
304 | class DistributedCfg:
305 |     world_size: int = 1
306 |     local_rank: int = 0
307 |     global_rank: int = 0
308 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 | 
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 | 1. Definitions.
8 | 
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 | 
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 | 
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 | 
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 | 
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 | 
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 | 
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2024 Hugging Face Inc. 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/src/chug/wds/decode.py:
--------------------------------------------------------------------------------
1 | import io
2 | import logging
3 | import os
4 | import random
5 | import re
6 | import warnings
7 | from typing import Any, Callable, Dict, Optional, Tuple, Union
8 | 
9 | import numpy as np
10 | import webdataset as wds
11 | from PIL import Image
12 | 
13 | from .helpers import log_and_continue
14 | 
15 | 
16 | # IMPORTANT fitz aka PyMuPDF is AGPL licensed w/ a commercial purchase option, manual intervention required to use it.
17 | _USE_AGPL_PYMUPDF = int(os.environ.get('CHUG_USE_AGPL_PYMUPDF', -1))
18 | if _USE_AGPL_PYMUPDF < 1:
19 |     import importlib.util
20 |     if importlib.util.find_spec('fitz') is not None and _USE_AGPL_PYMUPDF == -1:
21 |         # warn if not explicitly disabled by setting to 0
22 |         warnings.warn(
23 |             "The fitz/pymupdf library is installed but disabled as the environment variable CHUG_USE_AGPL_PYMUPDF is"
24 |             " not set to 1. Please be aware of the licensing concerns in your use cases if you enable it.")
25 |     fitz = None
26 | else:
27 |     try:
28 |         import fitz
29 |     except ImportError as e:
30 |         fitz = None
31 | 
32 | # defaults to pypdfium2 when fitz is not installed and/or not enabled
33 | if fitz is None:
34 |     try:
35 |         import pypdfium2
36 |     except ImportError as e:
37 |         pypdfium2 = None
38 | else:
39 |     pypdfium2 = None
40 | 
41 | _logger = logging.getLogger(__name__)
42 | 
43 | 
44 | PDF_EXTENSIONS = {
45 |     'pdf',
46 | }
47 | 
48 | 
49 | def decode_pdf_pages(
50 |         data: bytes,
51 |         image_mode: str = 'L',
52 |         page_indices: Optional[Tuple[int]] = None,
53 |         select_random: Optional[int] = None,
54 |         render_dpi: int = 144,
55 | ):
56 |     rendered_pages = []
57 | 
58 |     with io.BytesIO(data) as b:
59 |         # FIXME test and use an alternate pdf reader/renderer as default
60 |         if fitz is not None:
61 |             doc = fitz.Document(stream=b)
62 |             num_doc_pages = doc.page_count
63 | 
64 |             if page_indices is not None:
65 |                 page_indices = [p % num_doc_pages for p in page_indices]  # support -ve indexing
66 |             else:
67 |                 page_indices = range(num_doc_pages)
68 | 
69 |             if select_random:
70 |                 if select_random == 1:
71 |                     page_indices = [random.choice(page_indices)]
72 |                 else:
73 |                     page_indices = random.sample(page_indices, select_random)
74 |                     page_indices.sort()
75 | 
76 |             for i in page_indices:
77 |                 page = doc.load_page(i)
78 |                 if image_mode == 'L':
79 |                     fitz_cs = fitz.csGRAY
80 |                     fitz_mode = 'L'
81 |                     alpha = False
82 |                 elif image_mode == 'RGB' or image_mode == 'BGR':
83 |                     fitz_cs = fitz.csRGB
84 |                     fitz_mode = 'RGB'
85 |                     alpha = False
86 |                 elif image_mode == 'RGBA' or image_mode == 'BGRA':
87 |                     fitz_cs = fitz.csRGB
88 |                     fitz_mode = 'RGBA'
89 |                     alpha = True
90 |                 else:
91 |                     assert False
92 |                 pixmap = page.get_pixmap(dpi=render_dpi, colorspace=fitz_cs, alpha=alpha)
93 |                 page_image = Image.frombuffer(fitz_mode, (pixmap.width, pixmap.height), pixmap.samples)
94 |                 if image_mode != page_image.mode:
95 |                     page_image = page_image.convert(image_mode)
96 | 
97 |                 rendered_pages += [page_image]
98 | 
99 |         elif pypdfium2 is not None:
100 |             grayscale = image_mode == "L"
101 |             reverse = 'RGB' in image_mode
102 |             doc = pypdfium2.PdfDocument(data)
103 |             num_doc_pages = len(doc)
104 |             page_indices = [p % num_doc_pages for p in page_indices] if page_indices is not None else range(num_doc_pages)  # support -ve indexing
105 |             if select_random:
106 |                 page_indices = [random.choice(page_indices)] if select_random == 1 else sorted(random.sample(page_indices, select_random))
107 |             for i in page_indices:
108 |                 page = doc[i]
109 |                 page_image = page.render(
110 |                     scale=render_dpi / 72,
111 |                     grayscale=grayscale,
112 |                     rev_byteorder=reverse,  # RGB instead of BGR(X)
113 |                 ).to_pil()
114 |                 if image_mode != page_image.mode:
115 |                     page_image = page_image.convert(image_mode)
116 | 
117 |                 rendered_pages += [page_image]
118 |         else:
119 |             assert False, "No PDF decoding library installed, please install one of pypdfium2 or fitz (PyMuPDF). " \
120 |                           "NOTE: pypdfium2 is Apache-2.0 / BSD-3-Clause licensed and fitz is AGPL."
121 | 
122 |     return rendered_pages, num_doc_pages
123 | 
124 | 
125 | def decode_image_pages(
126 |         data: bytes,
127 |         image_mode: str = 'L',
128 |         page_indices: Optional[Tuple[int]] = None,
129 |         select_random: Optional[int] = None,
130 | ):
131 |     """ decode multi-page image (e.g. TIFF)"""
132 |     decoded_pages = []
133 | 
134 |     if isinstance(data, Image.Image):
135 |         doc_image = data
136 |     else:
137 |         doc_image = Image.open(io.BytesIO(data))
138 | 
139 |     num_image_pages = getattr(doc_image, 'n_frames', 1)
140 | 
141 |     if page_indices is not None:
142 |         page_indices = [p % num_image_pages for p in page_indices]  # support -ve indexing
143 |     else:
144 |         page_indices = range(num_image_pages)
145 | 
146 |     if select_random:
147 |         if select_random == 1:
148 |             page_indices = [random.choice(page_indices)]
149 |         else:
150 |             page_indices = random.sample(page_indices, select_random)
151 |             page_indices.sort()
152 | 
153 |     for page_index in page_indices:
154 |         assert page_index < num_image_pages
155 |         if num_image_pages > 1:
156 |             doc_image.seek(page_index)
157 |         else:
158 |             assert page_index == 0, "not a multi-page image"
159 |             doc_image.load()
160 | 
161 |         page_image = doc_image.convert(image_mode)
162 |         decoded_pages.append(page_image)
163 | 
164 |     return decoded_pages, num_image_pages
165 | 
166 | 
167 | class DecodeDoc:
168 | 
169 |     def __init__(
170 |             self,
171 |             imagespec,
172 |             num_pages=1,
173 |             page_sampling='first',
174 |     ):
175 |         """Create a PDF handler.
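        Instances are intended to be passed as a pre-handler to `wds.decode()`, e.g.
        `wds.decode(DecodeDoc('pill'), 'pill')`, so 'pdf' and multi-page 'tif'/'tiff'
        entries get rendered to images per `imagespec`, while other extensions fall
        through (return None) to the remaining decoders.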
176 | 
177 |         Args:
178 |             imagespec: short string indicating the type of decoding
179 |                 The `imagespec` specifies whether the image is decoded
180 |                 to numpy/torch/pil, decoded to uint8/float, and decoded
181 |                 to l/rgb/rgba:
182 | 
183 |                 - l8: numpy uint8 l
184 |                 - rgb8: numpy uint8 rgb
185 |                 - rgba8: numpy uint8 rgba
186 |                 - l: numpy float l
187 |                 - rgb: numpy float rgb
188 |                 - rgba: numpy float rgba
189 |                 - torchl8: torch uint8 l
190 |                 - torchrgb8: torch uint8 rgb
191 |                 - torchrgba8: torch uint8 rgba
192 |                 - torchl: torch float l
193 |                 - torchrgb: torch float rgb
194 |                 - torch: torch float rgb
195 |                 - torchrgba: torch float rgba
196 |                 - pill: pil None l
197 |                 - pil: pil None rgb
198 |                 - pilrgb: pil None rgb
199 |                 - pilrgba: pil None rgba
200 | 
201 |         """
202 |         if imagespec not in wds.autodecode.imagespecs:
203 |             raise ValueError("Unknown imagespec: %s" % imagespec)
204 |         self.imagespec = imagespec.lower()
205 |         # FIXME need to work out padding / selection issues for multi-page support
206 |         assert num_pages == 1, "Only 1-page decoding supported at present"
207 |         self.num_pages = num_pages
208 |         assert page_sampling in {'random', 'first', 'last'}  # TODO add 'all' w/ multi-page support
209 |         self.page_sampling = page_sampling
210 | 
211 |     def __call__(self, key, data):
212 |         """
213 |         Args:
214 |             key: file name extension
215 |             data: data to be decoded
216 |         """
217 |         extension = re.sub(r".*[.]", "", key)
218 |         if extension not in {'pdf', 'tiff', 'tif'}:
219 |             return None
220 | 
221 |         imagespec = self.imagespec
222 |         atype, etype, mode = wds.autodecode.imagespecs[imagespec]
223 | 
224 |         select_random = False
225 |         if self.page_sampling == 'random':
226 |             page_indices = None
227 |             select_random = True
228 |         elif self.page_sampling == 'first':
229 |             page_indices = [0]  # first page
230 |         elif self.page_sampling == 'last':
231 |             page_indices = [-1]
232 |         else:
233 |             assert False
234 | 
235 |         if extension == 'pdf':
236 |             # pdf document
237 |             result, num_pages = decode_pdf_pages(
238 |                 data,
239 |                 image_mode=mode.upper(),
240 |                 page_indices=page_indices,
241 |                 select_random=select_random,
242 |             )
243 |         else:
244 |             # multi-page image doc (e.g. tiff)
245 |             result, num_pages = decode_image_pages(
246 |                 data,
247 |                 image_mode=mode.upper(),
248 |                 page_indices=page_indices,
249 |                 select_random=select_random,
250 |             )
251 | 
252 |         if atype == "pil":
253 |             return result
254 | 
255 |         result = np.asarray(result[0])  # single page, num_pages == 1 enforced in __init__
256 | 
257 |         if etype == "float":
258 |             result = result.astype(np.float32) / 255.0
259 | 
260 |         assert result.ndim in [2, 3], result.shape
261 |         assert mode in ["l", "rgb", "rgba"], mode
262 | 
263 |         if mode == "l":
264 |             if result.ndim == 3:
265 |                 result = np.mean(result[:, :, :3], axis=2)
266 |         elif mode == "rgb":
267 |             if result.ndim == 2:
268 |                 result = np.repeat(result[:, :, np.newaxis], 3, axis=2)
269 |             elif result.shape[2] == 4:
270 |                 result = result[:, :, :3]
271 |         elif mode == "rgba":
272 |             if result.ndim == 2:
273 |                 result = np.repeat(result[:, :, np.newaxis], 4, axis=2)
274 |                 result[:, :, 3] = 255
275 |             elif result.shape[2] == 3:
276 |                 result = np.concatenate(
277 |                     [result, 255 * np.ones((*result.shape[:2], 1), dtype=result.dtype)], axis=2
278 |                 )
279 | 
280 |         assert atype in ["numpy", "torch"], atype
281 | 
282 |         if atype == "numpy":
283 |             return result
284 |         elif atype == "torch":
285 |             import torch
286 | 
287 |             if result.ndim == 3:
288 |                 return torch.from_numpy(result.transpose(2, 0, 1))
289 |             else:
290 |                 return torch.from_numpy(result)
291 | 
292 |         return None
293 | 
294 | 
295 | def create_image_decoder(
296 |         decode_fn: Optional[Callable] = None,
297 |         image_mode: str = "RGB",
298 |         enable_doc: bool = False,  # FIXME enable doc support by default once tested?
299 |         handler: Callable = log_and_continue,
300 | ):
301 |     if decode_fn is None:
302 |         if image_mode == "L":
303 |             img_type = "pill"
304 |         elif image_mode == "RGB":
305 |             img_type = "pilrgb"
306 |         else:
307 |             assert False, f"Unsupported image_mode ({image_mode})"
308 |         if enable_doc:
309 |             # FIXME, generic img + pdf decode WIP
310 |             decode_fn = wds.decode(DecodeDoc(img_type), img_type, handler=handler)
311 |         else:
312 |             decode_fn = wds.decode(img_type, handler=handler)
313 |     elif isinstance(decode_fn, (list, tuple)):
314 |         decode_fn = wds.decode(*decode_fn, handler=handler)
315 |     else:
316 |         assert callable(decode_fn)
317 |         decode_fn = wds.map(decode_fn)
318 | 
319 |     return decode_fn
320 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Chugging Data
2 | 
3 | A library to help w/ efficient training for multi-modal data. Initially focused on image & document + text tasks.
4 | 
5 | `chug` currently leverages `webdataset` and Hugging Face `datasets`.
6 | 
7 | `webdataset` tar files and dataset pipelines are preferred for scalable pretraining.
8 | 
9 | Hugging Face `datasets` are supported and work great for exploration, validation, and fine-tuning use cases.
10 | 
11 | `chug` provides on-the-fly PDF decoding and rendering via either pypdfium2 (https://github.com/pypdfium2-team/pypdfium2) as the default, or fitz/pymupdf (https://github.com/pymupdf/PyMuPDF) if your use case is okay with its AGPL-3.0 license. `fitz` support must be manually enabled. The pdf handling is implemented at the webdataset level, so you can plug it into other webdataset pipelines. This enables large-scale sharded streaming of native .pdf files without needing to pre-render to .png/.tiff, etc.
12 | 
13 | ## Status
14 | 
15 | This library is still a WIP; consider this an alpha release (pre-announcement).
Major features should be working, and the library has been tested with several PDF datasets that we will shortly make public. However, do expect breaking changes, lots of improvements, etc.
16 | 
17 | `pip install --pre chug` will install the current dev version.
18 | 
19 | ### TODOs
20 | 
21 | #### Nearish
22 | * Cleanup and refinement, the codebase will change
23 | * Documentation & unit-tests
24 | * Support reading of info .json/.yaml files for automatic shard info resolution for webdatasets (like timm)
25 | 
26 | #### Mediumish
27 | * Option to output bbox annotations for lines (or word + word output) for tasks that leverage layout
28 | * Unified preprocessor functions for combined image + text tokenization (img+text token interleaving, etc.)
29 | * Image token (patch) packing a la NaViT. Online bin-packing based algorithms integrated with image preprocessing and pipeline.
30 | 
31 | #### Longish
32 | * Increase range of task pipelines for other tasks, modelling needs
33 | * Support additional modalities & targets (video, audio, detection/dense pixel targets, image/video/audio targets)
34 | * Explore alternatives to .tar shards (array_record, arrow, etc.)
35 | 
36 | ## Design
37 | 
38 | ### Submodule Hierarchy
39 | 
40 | The library has been designed so that functions and classes at different levels can be used independently.
41 | 
42 | If one wants to build a loader & pipeline with JSON/YAML-serializable configs, use the top-level `chug.create_loader()` in `chug/loader.py`. Depending on dataset sources, one can easily switch this between webdataset and HF datasets (and, in the future, other sources).
43 | 
44 | Bypassing the highest level, one can also call the `build_task_pipeline_*` functions in `task_pipeline` and then call `create_loader_wds` with a full array of args for `wds`-only use cases.
45 | 
46 | If one doesn't want to use `chug` loaders and pipelines at all, the `image`, `text`, and `wds` (especially decoder) functionality may be useful in other projects.
47 | 
48 | #### Library modules (highest to lowest level)
49 | 
50 | The dependencies of modules within the library are intended to follow the hierarchy below, e.g. `doc` depends on `wds`, but `wds` should never depend on `doc`.
51 | 
52 | ```
53 | app
54 |  |
55 | loader (chug/loader.py)
56 |  |
57 | task_pipeline
58 |  |
59 | doc
60 |  |
61 | wds, hfds, image, text
62 |  |
63 | common
64 | ```
65 | 
66 | ### Submodules
67 | 
68 | #### `common`
69 | 
70 | Configs and structures (dataclasses) for general use across the library.
71 | 
72 | #### `wds`
73 | 
74 | Webdataset (`wds` for short) specific code. Extensions and alterations of webdataset functionality to fit the covered use cases and improve robustness.
75 | 
76 | All data pipelines in `chug` currently leverage `wds` pipelines, even when not using `wds` datasets.
77 | 
78 | Document-oriented decoding (pdf decoder) is present in `chug/wds/decode.py`; it can be used with any webdataset pipeline as a decoder, e.g. `wds.decode(chug.wds.DecodeDoc('pill'), 'pill')`
79 | 
80 | #### `hfds`
81 | 
82 | Hugging Face `datasets` support. A minimal wrapper that allows `datasets` to be used with chug processing pipelines.
83 | 
84 | The processing pipelines remain webdataset-based when using `datasets`; they are invoked by a custom collate class.
85 | 
86 | #### `image`
87 | 
88 | Image processing, `torchvision` and `albumentations` based transform building code. A mix of generic image (imagenet, simclr) transforms and document-specific transforms, including an implementation of `albumentations`-based `nougat` transforms.
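A minimal standalone sketch of using the image transforms outside any chug loader (assuming a local `page.png`; `ImageInputCfg` / `create_image_preprocessor` are the same entry points used in the Usage examples below):

```python
import chug
from PIL import Image

# build a document-style eval transform, no chug loader / pipeline involved
img_cfg = chug.ImageInputCfg(size=(1024, 768), mode='L', transform_type='doc_better')
img_fn = chug.create_image_preprocessor(input_cfg=img_cfg, is_training=False)

page = Image.open('page.png').convert(img_cfg.mode)  # hypothetical local page image
tensor = img_fn(page)  # torch.Tensor in (C, H, W) layout, resized, padded & normalized
```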
89 | 
90 | #### `text`
91 | 
92 | Text processing, tokenization code.
93 | 
94 | #### `doc`
95 | 
96 | Document processing code. Currently focused on processors that apply image/pdf decoders and process document OCR or VQA annotations.
97 | 
98 | #### `task_pipeline`
99 | 
100 | Task-specific pipelines, where dataset formats meet modelling needs.
101 | 
102 | Inputs to task pipelines are sample dictionaries based on the dataset form; they are decoded and then processed into outputs that match model input requirements.
103 | 
104 | Task-specific pipelines that handle the data <--> model input interface are inserted into an encompassing data pipeline which handles shard lists, shuffling, wrapping, distributed worker splitting, batching, etc.
105 | 
106 | #### `chug.loader`
107 | 
108 | This lone top-level file includes the main factory methods for creating loaders w/ associated pipelines from config dataclasses.
109 | 
110 | #### `app`
111 | 
112 | Most applications using `chug` will exist outside of the lib in training libraries, etc. Some built-in utility / exploration apps will be included here.
113 | 
114 | ## Concepts
115 | 
116 | WIP
117 | 
118 | ## Datasets
119 | 
120 | Datasets that work well with this library can be found on the Hugging Face Hub under the `pixparse` organization (https://huggingface.co/pixparse).
121 | 
122 | We'll add links to other noteworthy datasets that can be used as we become aware of them.
123 | 
124 | 
125 | ## Usage / Examples
126 | 
127 | ### Document Reading, Training w/ IDL
128 | ```python
129 | import chug
130 | img_cfg = chug.ImageInputCfg(size=(1024, 768), transform_type='doc_better')
131 | img_fn = chug.create_image_preprocessor(input_cfg=img_cfg, is_training=True)
132 | txt_fn = chug.create_text_preprocessor(
133 |     'naver-clova-ix/donut-base',
134 |     prompt_end_token='',
135 |     task_start_token='',  # NOTE needs to be added to tokenizer
136 | )
137 | 
138 | task_cfg = chug.DataTaskDocReadCfg(
139 |     image_process_fn=img_fn,
140 |     text_process_fn=txt_fn,
141 |     page_sampling='random',
142 |     error_handler='dump_and_reraise',
143 | )
144 | data_cfg = chug.DataCfg(
145 |     source='pipe:curl -s -f -L https://huggingface.co/datasets/pixparse/idl-wds/resolve/main/idl-train-0{0000..2999}.tar',
146 |     batch_size=8,
147 |     num_samples=3144726,
148 |     format='wds',
149 | )
150 | lb = chug.create_loader(
151 |     data_cfg,
152 |     task_cfg,
153 |     is_training=True,
154 | )
155 | ii = iter(lb)
156 | sample = next(ii)
157 | ```
158 | 
159 | ### Document Reading, Exploring IDL
160 | ```python
161 | import chug
162 | task_cfg = chug.DataTaskDocReadCfg(page_sampling='all')
163 | data_cfg = chug.DataCfg(
164 |     source='pixparse/idl-wds',
165 |     split='train',
166 |     batch_size=None,
167 |     format='hfids',
168 |     num_workers=0,
169 | )
170 | lb = chug.create_loader(
171 |     data_cfg,
172 |     task_cfg,
173 | )
174 | ii = iter(lb)
175 | sample = next(ii)
176 | ```
177 | 
178 | ### Document Reading, Training with PDFA
179 | 
180 | ```python
181 | import chug
182 | img_cfg = chug.ImageInputCfg(size=(1024, 768), transform_type='doc_nougat')
183 | img_fn = chug.create_image_preprocessor(input_cfg=img_cfg, is_training=True)
184 | txt_fn = chug.create_text_preprocessor(
185 |     'naver-clova-ix/donut-base',
186 |     prompt_end_token='',
187 |     task_start_token='',  # NOTE needs to be added to tokenizer
188 | )
189 | 
190 | task_cfg = chug.DataTaskDocReadCfg(
191 |     image_process_fn=img_fn,
192 |     text_process_fn=txt_fn,
193 |     page_sampling='random',
194 | )
195 | data_cfg = chug.DataCfg(
196 |     source='pipe:curl -s -f -L https://huggingface.co/datasets/pixparse/pdfa-english-train/resolve/main/pdfa-eng-train-{000000..005000}.tar',
197 |     batch_size=8,
198 |     num_samples=1000000,  # FIXME replace with actual
199 |     format='wds',
200 | )
201 | lb = chug.create_loader(
202 |     data_cfg,
203 |     task_cfg,
204 |     is_training=True,
205 | )
206 | ii = iter(lb)
207 | sample = next(ii)
208 | ```
209 | 
210 | ### Document Reading, Exploring PDFA
211 | 
212 | ```python
213 | import chug
214 | 
215 | task_cfg = chug.DataTaskDocReadCfg(
216 |     page_sampling='all',
217 | )
218 | data_cfg = chug.DataCfg(
219 |     source='pixparse/pdfa-eng-wds',
220 |     split='train',
221 |     batch_size=None,
222 |     format='hfids',
223 |     num_workers=0,
224 | )
225 | lb = chug.create_loader(
226 |     data_cfg,
227 |     task_cfg,
228 | )
229 | ii = iter(lb)
230 | sample = next(ii)
231 | ```
232 | 
233 | 
234 | ### Image + Text
235 | 
236 | #### Training
237 | 
238 | ```python
239 | import chug
240 | import transformers
241 | from functools import partial
242 | img_cfg = chug.ImageInputCfg(size=(512, 512), transform_type='image_timm')
243 | img_fn = chug.create_image_preprocessor(input_cfg=img_cfg, is_training=True)
244 | tokenizer = transformers.AutoTokenizer.from_pretrained('laion/CLIP-ViT-H-14-laion2B-s32B-b79K')
245 | txt_fn = partial(chug.tokenize, max_length=1000, tokenizer=tokenizer)
246 | task_cfg = chug.DataTaskImageTextCfg(
247 |     image_process_fn=img_fn,
248 |     text_process_fn=txt_fn,
249 | )
250 | data_cfg = chug.DataCfg(
251 |     source='pipe:curl -s -f -L https://huggingface.co/datasets/pixparse/cc12m-wds/resolve/main/cc12m-train-{0000..2175}.tar',
252 |     batch_size=8,
253 |     num_samples=10968539,
254 |     format='wds',
255 | )
256 | lb = chug.create_loader(
257 |     data_cfg,
258 |     task_cfg,
259 |     is_training=True,
260 | )
261 | ii = iter(lb)
262 | sample = next(ii)
263 | ```
264 | 
265 | ### Document VQA
266 | 
267 | #### Training, Fine-tuning
268 | ```python
269 | import chug
270 | from chug.task_pipeline import create_task_pipeline
271 | img_cfg = chug.ImageInputCfg(size=(1024, 768), transform_type='doc_basic')
272 | img_fn = chug.create_image_preprocessor(img_cfg, is_training=True)
273 | txt_fn = chug.create_text_preprocessor(
274 |     'naver-clova-ix/donut-base-finetuned-docvqa',
275 |     prompt_end_token='',
276 |     task_start_token='',
277 | )
278 | 
279 | task_cfg = chug.DataTaskDocVqaCfg(
280 |     image_process_fn=img_fn,
281 |     text_process_fn=txt_fn,
282 | )
283 | data_cfg = chug.DataCfg(
284 |     source='pipe:curl -s -f -L https://huggingface.co/datasets/pixparse/docvqa-wds/resolve/main/docvqa-train-{000..383}.tar',
285 |     batch_size=8,
286 |     format='wds',
287 |     num_samples=39463,
288 | )
289 | lb = chug.create_loader(
290 |     data_cfg,
291 |     task_cfg,
292 |     is_training=True,
293 | )
294 | ii = iter(lb)
295 | sample = next(ii)
296 | ```
297 | 
298 | #### Exploration
299 | 
300 | ```python
301 | import chug
302 | from chug.task_pipeline import create_task_pipeline
303 | task_cfg = chug.DataTaskDocVqaCfg(
304 |     question_prefix='Question: ',
305 |     question_suffix='',
306 |     answer_prefix='Answer: ',
307 |     answer_suffix=''
308 | )
309 | data_cfg = chug.DataCfg(
310 |     source='pixparse/docvqa-single-page-questions',
311 |     split='validation',
312 |     batch_size=None,
313 |     format='hfids',
314 |     num_workers=0,
315 | )
316 | lb = chug.create_loader(
317 |     data_cfg,
318 |     task_cfg
319 | )
320 | ii = iter(lb)
321 | sample = next(ii)
322 | ```
323 | 
324 | ## Acknowledgement
325 | 
326 | `chug` evolved from the `webdataset` data pipeline used successfully in the [OpenCLIP](https://github.com/mlfoundations/open_clip) project. Thanks to all the contributors of that project. Future work will likely involve closing the loop and leveraging `chug` in OpenCLIP for increased capability.
327 | 
328 | The image/document augmentations in `chug` draw on a number of external influences. Our document-oriented `doc_better` torchvision augmentations are influenced by `nougat`, and `doc_nougat` is a direct adaptation of the [`albumentations`](https://albumentations.ai/) + `cv2` document pipeline in [`nougat`](https://github.com/facebookresearch/nougat). Several image augmentations leverage existing work in the `timm` library.
329 | 
330 | Also, big thanks to the maintainers of [`webdataset`](https://github.com/webdataset/webdataset) and Hugging Face [`datasets`](https://github.com/huggingface/datasets).
331 | 
--------------------------------------------------------------------------------
/src/chug/image/build_transforms_doc.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | 
3 | from chug.common import ImageInputCfg, ImageAugCfg
4 | 
5 | from .transforms_torch import AlignLongAxis, Bitmap, Erosion, Dilation, CropMargin
6 | 
7 | # NOTE, chug currently depends on some timm aug impls, this should be flipped if timm ends up
8 | # leveraging chug data pipelines.
9 | from timm.data import str_to_interp_mode, ResizeKeepRatio, CenterCropOrPad, RandomCropOrPad
10 | 
11 | from torchvision import transforms
12 | 
13 | 
14 | def build_transforms_doc_basic(
15 |         input_cfg: ImageInputCfg,
16 |         is_training: bool = False,
17 |         do_normalize: bool = True,
18 |         aug_cfg: Optional[ImageAugCfg] = None,
19 |         composed: bool = True,
20 | ):
21 |     # a basic torchvision + custom op transform pipeline (no albumentations)
22 |     image_size = input_cfg.size
23 |     interpolation_mode = str_to_interp_mode(input_cfg.interpolation)
24 | 
25 |     pp = []
26 | 
27 |     if input_cfg.crop_margin:
28 |         pp += [CropMargin()]
29 | 
30 |     if input_cfg.align_long_axis:
31 |         pp += [AlignLongAxis(image_size, interpolation=interpolation_mode)]
32 | 
33 |     if is_training:
34 |         pp += [
35 |             RandomCropOrPad(image_size, fill=input_cfg.fill_color),
36 |             transforms.CenterCrop(image_size),
37 |         ]
38 |     else:
39 |         pp += [
40 |             ResizeKeepRatio(image_size, longest=1, interpolation=input_cfg.interpolation),
41 |             CenterCropOrPad(image_size, fill=input_cfg.fill_color),
42 |         ]
43 | 
44 |     pp += [transforms.ToTensor()]
45 | 
46 |     if do_normalize:
47 |         pp += [transforms.Normalize(input_cfg.mean, input_cfg.std)]
48 | 
49 |     return transforms.Compose(pp) if composed else pp
50 | 
51 | 
52 | def build_transforms_doc_better(
53 |         input_cfg: ImageInputCfg,
54 |         is_training: bool = False,
55 |         do_normalize: bool = True,
56 |         aug_cfg: Optional[ImageAugCfg] = None,
57 |         composed: bool = True,
58 | ):
59 |     # an improved torchvision + custom op transforms (no albumentations)
60 |     image_size = input_cfg.size
61 |     interpolation_mode = str_to_interp_mode(input_cfg.interpolation)
62 |     pp = []
63 | 
64 |     if input_cfg.crop_margin:
65 |         pp += [CropMargin()]
66 | 
67 |     if input_cfg.align_long_axis:
68 |         pp += [AlignLongAxis(image_size, interpolation=interpolation_mode)]
69 | 
70 |     if is_training:
71 |         # FIXME merge defaults w/ aug_cfg
72 |         defaults = dict(
73 |             scale_prob=0.05,
74 |             scale_range=(0.85, 1.04),
75 |             ratio_prob=0.05,
76 |             ratio_range=(.9, 1.11),
77 |             bitmap_prob=0.55,
78 |             erosion_dilation_prob=0.02,
79 |             shear_prob=0.05,
80 |             shear_range_x=(0, 3.),
81 |             shear_range_y=(-3, 0),
82 |             shift_scale_rotate_prob=0.03,
83 |             shift_range_x=0.04,
84 |             shift_range_y=0.03,
85 |             rotate_range=3,
86 |             elastic_prob=0.04,
87 |             elastic_alpha=50.,
88 |             elastic_sigma=12.,
89 |             brightness_contrast_prob=0.04,
90 |             brightness_range=0.1,
91 |             contrast_range=0.1,
92 |             gaussian_blur_prob=0.03,
93 |             gaussian_blur_kernel=3,
94 |         )
95 |         params = defaults
96 | 
97 |         pp += [
98 |             ResizeKeepRatio(
99 |                 image_size,
100 |                 longest=1,
101 |                 interpolation=input_cfg.interpolation,
102 |                 random_scale_prob=params['scale_prob'],
103 |                 random_scale_range=params['scale_range'],
104 |                 random_aspect_prob=params['ratio_prob'],
105 |                 random_aspect_range=params['ratio_range'],
106 |             ),
107 |             transforms.RandomApply([
108 |                 Bitmap()
109 |                 ],
110 |                 p=params['bitmap_prob']
111 |             ),
112 |             transforms.RandomApply([
113 |                 transforms.RandomChoice([
114 |                     Erosion(3),
115 |                     Dilation(3),
116 |                 ])],
117 |                 p=params['erosion_dilation_prob']
118 |             ),
119 |             transforms.RandomApply([
120 |                 transforms.RandomAffine(
121 |                     degrees=0,
122 |                     shear=params['shear_range_x'] + params['shear_range_y'],
123 |                     interpolation=interpolation_mode,
124 |                     fill=input_cfg.fill_color,
125 |                 )],
126 |                 p=params['shear_prob'],
127 |             ),
128 |             transforms.RandomApply([
129 |                 transforms.RandomAffine(
130 |                     degrees=params['rotate_range'],
131 |                     translate=(params['shift_range_x'], params['shift_range_y']),
132 |                     interpolation=interpolation_mode,
133 |                     fill=input_cfg.fill_color,
134 |                 )],
135 |                 p=params['shift_scale_rotate_prob'],
136 |             ),
137 |             transforms.RandomApply([
138 |                 transforms.ElasticTransform(
139 |                     alpha=params['elastic_alpha'],
140 |                     sigma=params['elastic_sigma'],
141 |                     interpolation=interpolation_mode,
142 |                     fill=input_cfg.fill_color,
143 |                 )],
144 |                 p=params['elastic_prob'],
145 |             ),
146 |             transforms.RandomApply([
147 |                 transforms.ColorJitter(
148 |                     brightness=params['brightness_range'],
149 |                     contrast=params['contrast_range'],
150 |                 )],
151 |                 p=params['brightness_contrast_prob'],
152 |             ),
153 |             transforms.RandomApply([
154 |                 transforms.GaussianBlur(
155 |                     params['gaussian_blur_kernel'],
156 |                     sigma=(0.1, 0.8),
157 |                 )],
158 |                 p=params['gaussian_blur_prob'],
159 |             ),
160 |             RandomCropOrPad(image_size, fill=input_cfg.fill_color),
161 |             transforms.CenterCrop(image_size),
162 |         ]
163 |     else:
164 |         pp += [
165 |             ResizeKeepRatio(image_size, longest=1, interpolation=input_cfg.interpolation),
166 |             CenterCropOrPad(image_size, fill=input_cfg.fill_color),
167 |         ]
168 | 
169 |     pp += [transforms.ToTensor()]
170 | 
171 |     if do_normalize:
172 |         pp += [transforms.Normalize(input_cfg.mean, input_cfg.std)]
173 | 
174 |     return transforms.Compose(pp) if composed else pp
175 | 
176 | 
177 | def build_transforms_doc_nougat(
178 |         input_cfg: ImageInputCfg,
179 |         is_training: bool = False,
180 |         do_normalize: bool = True,
181 |         aug_cfg: Optional[ImageAugCfg] = None,
182 |         composed: bool = True,
183 | ):
184 |     import albumentations as alb
185 |     from chug.image.transforms_alb import BitmapAlb, ErosionAlb, DilationAlb, AlbWrapper, CropMarginCv2
186 | 
187 |     # albumentations + custom opencv transforms from nougat
188 |     image_size = input_cfg.size
189 |     if input_cfg.interpolation == 'bilinear':
190 |         interpolation_mode = 1
191 |     else:
192 |         interpolation_mode = 2  # bicubic
193 |     border_mode = 0
194 | 
195 |     tv_pp = []
196 |     alb_pp = []
197 | 
198 |     if input_cfg.crop_margin:
199 |         tv_pp += [CropMarginCv2()]
200 | 
201 |     if input_cfg.align_long_axis:
202 |         tv_pp += [AlignLongAxis(image_size)]
203 | 
204 |     if is_training:
205 |         # FIXME merge defaults w/ aug_cfg
206 |         defaults = dict(
207 |             #scale_prob=0.05,
208 |             scale_range=(0.85, 1.03),  # done as part of shift_scale_rotate
209 |             #ratio_prob=0.05,
210 |             #ratio_range=(.9, 1.11),
211 |             bitmap_prob=0.05,
212 |             erosion_dilation_prob=0.02,
213 |             erosion_dilation_scale=(2, 3),
214 |             shear_prob=0.03,
215 |             shear_range_x=(0, 3.),
216 |             shear_range_y=(-3, 0),
217 |             shift_scale_rotate_prob=0.03,
218 |             shift_range_x=(0, 0.04),
219 |             shift_range_y=(0, 0.03),
220 |             rotate_range=2.,
221 |             grid_distort_prob=0.04,
222 |             grid_distort_range=0.05,
223 |             elastic_prob=0.04,
224 |             elastic_alpha=50.,
225 |             elastic_sigma=12.,
226 |             brightness_contrast_prob=0.03,
227 |             brightness_range=0.1,
228 |             contrast_range=0.1,
229 |             gaussian_noise_prob=0.08,
230 |             gaussian_noise_range=20.,  # variance range
231 |             gaussian_blur_prob=0.03,
232 |             gaussian_blur_kernel_range=(3, 3),
233 |             image_compression_prob=0.1,
234 |         )
235 |         params = defaults
236 |         scale_range_centered = tuple(x - 1 for x in params['scale_range'])
237 |         params['scale_range'] = scale_range_centered
238 | 
239 |         tv_pp += [
240 |             # this should be equivalent to initial resize & pad in Donut prepare_input()
241 |             ResizeKeepRatio(image_size, longest=1, interpolation=input_cfg.interpolation),
242 |             RandomCropOrPad(image_size, fill=input_cfg.fill_color),
243 |         ]
244 | 
245 |         alb_pp += [
246 |             BitmapAlb(p=params['bitmap_prob']),
247 |             alb.OneOf([
248 |                 ErosionAlb(params['erosion_dilation_scale']),
249 |                 DilationAlb(params['erosion_dilation_scale'])
250 |                 ],
251 |                 p=params['erosion_dilation_prob']
252 |             ),
253 |             alb.Affine(
254 |                 shear={
255 |                     "x": params['shear_range_x'],
256 |                     "y": params['shear_range_y']
257 |                 },
258 |                 cval=input_cfg.fill_color,
259 |                 p=params['shear_prob']
260 |             ),
261 |             alb.ShiftScaleRotate(
262 |                 shift_limit_x=params['shift_range_x'],
263 |                 shift_limit_y=params['shift_range_y'],
264 |                 scale_limit=params['scale_range'],
265 |                 rotate_limit=params['rotate_range'],
266 |                 border_mode=border_mode,
267 |                 interpolation=interpolation_mode,
268 |                 value=input_cfg.fill_color,
269 |                 p=params['shift_scale_rotate_prob'],
270 |             ),
271 |             alb.GridDistortion(
272 |                 distort_limit=params['grid_distort_range'],
273 |                 border_mode=border_mode,
274 |                 interpolation=interpolation_mode,
275 |                 value=input_cfg.fill_color,
276 |                 p=params['grid_distort_prob'],
277 |             ),
278 |             alb.Compose(
279 |                 [
280 |                     alb.Affine(
281 |                         translate_px=(0, 5),
282 |                         always_apply=True,
283 |                         cval=input_cfg.fill_color,
284 |                     ),
285 |                     alb.ElasticTransform(
286 |                         p=1.0,
287 |                         alpha=params['elastic_alpha'],
288 |                         sigma=params['elastic_sigma'],
289 |                         alpha_affine=12.,  # FIXME no common param, alpha_affine unique to alb
290 |                         border_mode=border_mode,
291 |                         value=input_cfg.fill_color,
292 |                     ),
293 |                 ],
294 |                 p=params['elastic_prob'],
295 |             ),
296 |             alb.RandomBrightnessContrast(
297 |                 brightness_limit=params['brightness_range'],
298 |                 contrast_limit=params['contrast_range'],
299 |                 brightness_by_max=True,
300 |                 p=params['brightness_contrast_prob'],
301 |             ),
302 |             alb.ImageCompression(
303 |                 quality_lower=95,
304 |                 p=params['image_compression_prob'],
305 |             ),
306 |             alb.GaussNoise(
307 |                 var_limit=params['gaussian_noise_range'],
308 |                 p=params['gaussian_noise_prob']
309 |             ),
310 |             alb.GaussianBlur(
311 |                 blur_limit=params['gaussian_blur_kernel_range'],
312 |                 p=params['gaussian_blur_prob'],
313 |             ),
314 |         ]
315 |     else:
316 |         # inference / eval
317 |         tv_pp += [
318 |             ResizeKeepRatio(image_size, longest=1, interpolation=input_cfg.interpolation),
319 |             CenterCropOrPad(image_size, fill=input_cfg.fill_color),
320 |         ]
321 | 
322 |     #alb_pp +=
[alb.pytorch.ToTensorV2()] 323 | if alb_pp: 324 | # FIXME leave alb uncomposed too if composed=False? 325 | tv_pp += [AlbWrapper(alb.Compose(alb_pp))] 326 | 327 | tv_pp += [transforms.ToTensor()] 328 | if do_normalize: 329 | tv_pp += [transforms.Normalize(input_cfg.mean, input_cfg.std)] 330 | 331 | return transforms.Compose(tv_pp) if composed else tv_pp 332 | --------------------------------------------------------------------------------