├── .github └── workflows │ ├── python-package.yml │ └── python-publish.yml ├── .gitignore ├── COCO_DATA_FORMAT.md ├── CODE_OF_CONDUCT.md ├── DATA_PREPARATION.md ├── IRIS_DATA_FORMAT.md ├── LICENSE ├── README.md ├── SECURITY.md ├── TSV_FORMAT.md ├── UTIL_COMMANDS.md ├── setup.py ├── tests ├── __init__.py ├── image_matting_test_data.zip ├── resources │ ├── __init__.py │ └── util.py ├── test_box_alteration.py ├── test_coco_adaptor │ ├── __init__.py │ ├── coco_adaptor_base.py │ ├── test_image_caption_coco_manifest_adaptor.py │ ├── test_image_matting_coco_manifest_adaptor.py │ ├── test_image_regression_coco_manifest_adaptor.py │ ├── test_image_text_matching_coco_manifest_adaptor.py │ ├── test_key_value_pair_coco_manifest_adaptor.py │ ├── test_multiclass_classification_coco_manifest_adaptor.py │ ├── test_multilabel_classification_coco_manifest_adaptor.py │ ├── test_multitask_coco_manifest_adaptor.py │ ├── test_object_detection_coco_manifest_adaptor.py │ ├── test_text_2_image_retrieval_coco_manifest_adaptor.py │ ├── test_visual_object_grounding_coco_manifest_adaptor.py │ └── test_visual_question_anwsering_coco_manifest_adaptor.py ├── test_coco_format_manifest.py ├── test_coco_iris_to_kvp_wrapper │ ├── test_classification_as_kvp.py │ ├── test_detection_as_kvp.py │ └── test_vqa_as_kvp.py ├── test_dataset_downloader.py ├── test_dataset_manifest.py ├── test_dataset_registry.py ├── test_detection_as_classification.py ├── test_file_reader.py ├── test_fixtures.py ├── test_ic_od_as_image_text_dataset.py ├── test_iris_format_manifest.py ├── test_manifest │ ├── __init__.py │ ├── test_key_value_pair.py │ ├── test_pickleable.py │ └── test_visual_object_grounding.py ├── test_manifest_dataset.py ├── test_manifest_to_coco │ ├── __init__.py │ └── test_manifest_to_coco.py ├── test_merge_manifest │ ├── __init__.py │ └── test_manifest_merge.py ├── test_sample │ ├── __init__.py │ └── test_sample_manfiest.py ├── test_spawn │ ├── __init__.py │ └── test_spawn.py ├── 
test_stand_along_image_dict_generation │ ├── __init__.py │ └── test_generate_standalone_image_dict.py ├── test_torch_dataest │ ├── __init__.py │ └── test_torch_dataset.py └── test_utils.py ├── tox.ini └── vision_datasets ├── __init__.py ├── commands ├── __init__.py ├── check_dataset.py ├── converter_od_to_ic.py ├── converter_to_aml_coco.py ├── converter_to_line_oriented_format.py ├── converter_tsv_to_coco.py ├── download_dataset.py ├── list_operations_by_data_type.py ├── transform_images.py └── utils.py ├── common ├── __init__.py ├── base64_utils.py ├── constants.py ├── data_manifest │ ├── __init__.py │ ├── coco_manifest_adaptor.py │ ├── data_manifest.py │ ├── iris_data_manifest_adaptor.py │ ├── operations │ │ ├── __init__.py │ │ ├── balanced_instance_weights_generator.py │ │ ├── filter.py │ │ ├── generate_coco.py │ │ ├── generate_stand_alone_image_list_base.py │ │ ├── merge.py │ │ ├── operation.py │ │ ├── remove_categories.py │ │ ├── sample.py │ │ ├── spawn.py │ │ └── split.py │ └── utils.py ├── data_reader │ ├── __init__.py │ ├── dataset_downloader.py │ ├── file_reader.py │ └── image_loader.py ├── dataset │ ├── __init__.py │ ├── base_dataset.py │ └── vision_dataset.py ├── dataset_info.py ├── dataset_management │ ├── __init__.py │ ├── dataset_hub.py │ └── dataset_registry.py ├── factory │ ├── __init__.py │ ├── coco_manifest_adaptor_factory.py │ ├── data_manifest_factory.py │ └── operations │ │ ├── __init__.py │ │ ├── balanced_instance_weights_factory.py │ │ ├── coco_generator_factory.py │ │ ├── manifest_merger_factory.py │ │ ├── sampler_factory.py │ │ ├── spawn_factory.py │ │ ├── split_factory.py │ │ ├── stand_alone_image_dicts_factory.py │ │ └── supported_operations_by_data_type.py └── utils.py ├── image_caption ├── __init__.py ├── coco_manifest_adaptor.py ├── manifest.py └── operations.py ├── image_classification ├── __init__.py ├── classification_as_kvp_dataset.py ├── coco_manifest_adaptor.py ├── manifest.py └── operations.py ├── image_matting ├── __init__.py 
├── coco_manifest_adaptor.py ├── manifest.py └── operations.py ├── image_object_detection ├── __init__.py ├── coco_manifest_adaptor.py ├── detection_as_classification_dataset.py ├── detection_as_kvp_dataset.py ├── manifest.py └── operations.py ├── image_regression ├── __init__.py ├── coco_manifest_adaptor.py ├── manifest.py └── operations.py ├── image_text_matching ├── __init__.py ├── coco_manifest_adaptor.py ├── manifest.py ├── operations.py └── vision_as_image_text_dataset.py ├── key_value_pair ├── __init__.py ├── coco_manifest_adaptor.py ├── manifest.py └── operations.py ├── multi_task ├── __init__.py ├── coco_manifest_adaptor.py └── operations.py ├── text_2_image_retrieval ├── __init__.py ├── coco_manifest_adaptor.py ├── manifest.py └── operations.py ├── torch ├── __init__.py ├── dataset.py └── torch_dataset.py ├── visual_object_grounding ├── __init__.py ├── coco_manifest_adaptor.py ├── manifest.py └── operations.py └── visual_question_answering ├── __init__.py ├── coco_manifest_adaptor.py ├── manifest.py ├── operations.py └── vqa_as_kvp_dataset.py /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python package 5 | 6 | on: 7 | push: 8 | branches: [ main ] 9 | pull_request: 10 | branches: [ main ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | python-version: ["3.8", "3.9", "3.10"] 20 | 21 | steps: 22 | - uses: actions/checkout@v3 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v3 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | python -m 
pip install flake8 pytest 31 | pip install -e .[torch] 32 | - name: Lint with flake8 33 | run: | 34 | # stop the build if there are Python syntax errors or undefined names 35 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 36 | # exit-zero treats all errors as warnings. 37 | flake8 . --count --exit-zero --max-complexity=10 --statistics 38 | - name: Test with pytest 39 | run: | 40 | pytest 41 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 
8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v3 25 | - name: Set up Python 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: '3.x' 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install build 33 | - name: Build package 34 | run: python -m build 35 | - name: Publish package 36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 37 | with: 38 | user: __token__ 39 | password: ${{ secrets.PYPI_API_TOKEN }} 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | *venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | .idea/* 107 | .vscode/* 108 | tests/temp 109 | 110 | temp/ 111 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 
4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /IRIS_DATA_FORMAT.md: -------------------------------------------------------------------------------- 1 | # Iris format 2 | 3 | Here is an example with explanation of what a `DatasetInfo` looks like for `iris` format: 4 | 5 | ```{json} 6 | { 7 | "name": "sampled-ms-coco", 8 | "version": 1, 9 | "description": "A sampled ms-coco dataset.", 10 | "type": "object_detection", 11 | "root_folder": "detection/coco2017_20200401", 12 | "format": "iris", // indicating the annotation data are stored in iris format 13 | "train": { 14 | "index_path": "train_images.txt", // index file for images and labels for training, example can be found in next section 15 | "files_for_local_usage": [ 16 | "train_images.zip", 17 | "train_labels.zip" 18 | ], 19 | }, 20 | "val": { 21 | "index_path": "val_images.txt", 22 | "files_for_local_usage": [ 23 | "val_images.zip", 24 | "val_labels.zip" 25 | ], 26 | }, 27 | "test": { 28 | "index_path": "test_images.txt", 29 | "files_for_local_usage": [ 30 | "test_images.zip", 31 | "test_labels.zip" 32 | ], 33 | }, 34 | "labelmap": "labels.txt", // includes tag names 35 | "image_metadata_path": "image_meta_info.txt", // includes info about image width and height 36 | }, 37 | ``` 38 | 39 | ## Iris image classification format 40 | 41 | Each rows in the index file (`index_path`) is: 42 | 43 | ``` {txt} 44 | 45 | ``` 46 | 47 | Note that the class/label index should start from zero. 48 | 49 | Example: 50 | 51 | ``` {txt} 52 | train_images1.zip@1.jpg 0,1,2 53 | train_images2.zip@1.jpg 2,3 54 | ... 
55 | ``` 56 | 57 | ## Iris object detection format 58 | 59 | The index file for OD is slightly different from IC. Each rows in the index file is: 60 | 61 | ``` {txt} 62 | 63 | ``` 64 | 65 | Same with classification, the class/label index should start from 0. 66 | 67 | Example for `train_images.txt`: 68 | 69 | ``` {txt} 70 | train_images.zip@1.jpg train_labels.zip@1.txt 71 | train_images.zip@2.jpg train_labels.zip@2.txt 72 | ... 73 | ``` 74 | 75 | Formats and example for a label file like `train_labels.zip@1.txt`: 76 | 77 | ``` {txt} 78 | class_index left top right bottom 79 | ``` 80 | 81 | ``` {txt} 82 | 3 200 300 600 1200 // class_id, left, top, right, bottom 83 | 4 100 100 200 200 84 | ... 85 | ``` 86 | 87 | ## Multitask DatasetInfo 88 | 89 | The `DatasetInfo` for multitask is not very different from single task. A `'tasks'` section will be found in the json and the `'type'` of the dataset is `'multitask'`. Within each task, it wraps the 90 | info specific to that task. 91 | 92 | Below is an example for `'iris'` format, but the general idea applies to `'coco'` format as well. 
93 | 94 | ```{json} 95 | { 96 | "name": "coco-vehicle-multitask", 97 | "version": 1, 98 | "type": "multitask", 99 | "root_folder": "classification/coco_vehicle_multitask_20210202", 100 | "format": "iris", 101 | "tasks": { 102 | "vehicle_color": { 103 | "type": "classification_multiclass", 104 | "train": { 105 | "index_path": "train_images_VehicleColor.txt", 106 | "files_for_local_usage": [ 107 | "train_images.zip" 108 | ] 109 | }, 110 | "test": { 111 | "index_path": "test_images_VehicleColor.txt", 112 | "files_for_local_usage": [ 113 | "test_images.zip" 114 | ] 115 | }, 116 | "labelmap": "labels_VehicleColor.txt" 117 | }, 118 | "vehicle_type": { 119 | "type": "classification_multiclass", 120 | "train": { 121 | "index_path": "train_images_VehicleType.txt", 122 | "files_for_local_usage": [ 123 | "train_images.zip" 124 | ] 125 | }, 126 | "test": { 127 | "index_path": "test_images_VehicleType.txt", 128 | "files_for_local_usage": [ 129 | "test_images.zip" 130 | ] 131 | }, 132 | "labelmap": "labels_VehicleType.txt" 133 | } 134 | } 135 | } 136 | ``` 137 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). 
If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 
40 | 41 | 42 | -------------------------------------------------------------------------------- /TSV_FORMAT.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | 3 | TSV format uses a single or multiple *.tvs format to store both the image annotation and image files, where the 4 | - **first column** is the image_id, it can be string, int or any unique identifier 5 | - **second column** being the annotaion 6 | - **third column** being the [base64-encoded](https://en.wikipedia.org/wiki/Base64) string of the image data. 7 | 8 | This repo does not support consuming TSV format, but we provide tools for converting to/from TSV format from/to coco, for limited tasks and data, such as 9 | 10 | ```bash 11 | vision_convert_to_tsv {dataset_name} -r {dataset_registry_json} -k {data storage url} -f {local_dir} [-u Usages] 12 | ``` 13 | 14 | For more commands and help related to TSV, check out [UTIL_COMMANDS.md](./UTIL_COMMANDS.md). 15 | 16 | # Task-wise Format 17 | 18 | ## Image Classifciation 19 | 20 | Below is an example of multiclass classification: 21 | 22 | ``` 23 | 1 [{"class": "dog"}] /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAA..... 24 | 2 [{"class": "cat"}] /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAA..... 25 | 3 [{"class": "wolff"}] /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAA..... 26 | ... 27 | ``` 28 | 29 | For multilabel classification, 30 | 31 | ``` 32 | 1 [{"class": "dog"}, {"class": "canidae"}, {"class": "pet"}] /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAA..... 33 | 2 [{"class": "cat"}, {"class": "Felidae"}, {"class": "pet"}] /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAA..... 34 | 3 [{"class": "wolff"}, {"class": "canidae"}}] /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAA..... 35 | ... 
36 | ``` 37 | 38 | ## Object Detection 39 | 40 | The format of object detection is very similar to mutlilable classification, with an additonal field `rect: [left, top, right, bottom]` 41 | 42 | ``` 43 | 1 [{"class": "dog", "rect": [10, 10, 100, 100]}, {"class": "cat", "rect": [10, 10, 100, 100]}] /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAA..... 44 | 2 [{"class": "cat", "rect": [10, 20, 250, 100]}] /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAA..... 45 | 3 [{"class": "wolff", "rect": [100, 200, 250, 1000]}] /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAA..... 46 | ... 47 | ``` 48 | 49 | 50 | ## Image Caption 51 | 52 | 53 | The format of image caption is straightforward as 54 | 55 | ``` 56 | 1 [{"caption": "dog playing with a cat"}] /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAA..... 57 | 2 [{"caption": "dog eating food"}] /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAA..... 58 | 3 [{"caption": "wolff sitting in snow"] /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAA..... 59 | ... 60 | ``` 61 | -------------------------------------------------------------------------------- /UTIL_COMMANDS.md: -------------------------------------------------------------------------------- 1 | # Helpful commands 2 | 3 | - `vision_download`: help you download the dataset files to local disk for consumption, it can be downloaded/converted to TSV directly as well 4 | - `vision_check_dataset`: check if a dataset or [coco json + images] is problematic or not. 5 | - `vision_convert_od_to_ic`: convert a detection dataset to classification dataset (with or without augmentations). 6 | - `vision_convert_to_aml_coco`: generate a coco that can be used for AzureML 7 | - `vision_list_supported_operations`: list the supported operations by certain data type. 8 | 9 | For each commoand, run `command -h` for more details. 
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | from os import path 3 | 4 | VERSION = '1.0.21' 5 | 6 | # Get the long description from the README file 7 | here = path.abspath(path.dirname(__file__)) 8 | with open(path.join(here, 'README.md'), 'r') as f: 9 | long_description = f.read() 10 | 11 | setuptools.setup(name='vision_datasets', 12 | author='Ping Jin, Shohei Ono', 13 | description='A utility repo for vision dataset access and management.', 14 | long_description=long_description, 15 | long_description_content_type='text/markdown', 16 | url='https://github.com/microsoft/vision-datasets', 17 | version=VERSION, 18 | python_requires='>=3.8', 19 | license='MIT', 20 | keywords='vision datasets classification detection', 21 | packages=setuptools.find_packages(), 22 | package_data={'': ['resources/*']}, 23 | install_requires=[ 24 | 'azure-identity', 25 | 'azure-storage-blob', 26 | 'numpy>=1.18.3', 27 | 'Pillow>=6.2.2', 28 | 'requests>=2.23.0', 29 | 'tenacity>=6.2.0', 30 | 'tqdm' 31 | ], 32 | classifiers=[ 33 | 'Development Status :: 4 - Beta', 34 | 'Intended Audience :: Developers', 35 | 'License :: OSI Approved :: MIT License', 36 | 'Programming Language :: Python :: 3.8', 37 | 'Programming Language :: Python :: 3.9', 38 | 'Programming Language :: Python :: 3.10', 39 | ], 40 | extras_require={ 41 | 'torch': ['torch>=1.6.0'], 42 | 'plot': ['matplotlib'], 43 | }, 44 | entry_points={ 45 | 'console_scripts': ['vision_download=vision_datasets.commands.download_dataset:main', 46 | 'vision_check_dataset=vision_datasets.commands.check_dataset:main', 47 | 'vision_transform_images=vision_datasets.commands.transform_images:main', 48 | 'vision_convert_od_to_ic=vision_datasets.commands.converter_od_to_ic:main', 49 | 'vision_convert_to_aml_coco=vision_datasets.commands.converter_to_aml_coco:main', 50 | 
'vision_list_supported_operations=vision_datasets.commands.list_operations_by_data_type:main', 51 | 'vision_convert_to_line_oriented_format=vision_datasets.commands.converter_to_line_oriented_format:main'] 52 | }) 53 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/vision-datasets/21d86a7219620b790b8a5fdaf9528e6934dd0f18/tests/__init__.py -------------------------------------------------------------------------------- /tests/image_matting_test_data.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/vision-datasets/21d86a7219620b790b8a5fdaf9528e6934dd0f18/tests/image_matting_test_data.zip -------------------------------------------------------------------------------- /tests/resources/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/vision-datasets/21d86a7219620b790b8a5fdaf9528e6934dd0f18/tests/resources/__init__.py -------------------------------------------------------------------------------- /tests/test_box_alteration.py: -------------------------------------------------------------------------------- 1 | import random 2 | import unittest 3 | 4 | from vision_datasets.image_object_detection.detection_as_classification_dataset import BoxAlteration 5 | 6 | 7 | class TestBoxAlteration(unittest.TestCase): 8 | def test_zoom_box_out_of_range(self): 9 | left, t, r, b = BoxAlteration.zoom_box(10, 10, 20, 20, 100, 100, 50, 51, random.Random(0)) 10 | assert left == 0 11 | assert t == 0 12 | assert r == 100 13 | assert b == 100 14 | 15 | def test_zoom_box_shrink_to_pt(self): 16 | left, t, r, b = BoxAlteration.zoom_box(10, 5, 20, 25, 100, 100, 0, 0, random.Random(0)) 17 | assert left == 15 18 | assert t == 15 19 | assert r == 15 20 | 
assert b == 15 21 | 22 | def test_zoom_box_no_change(self): 23 | left, t, r, b = BoxAlteration.zoom_box(10, 5, 20, 25, 100, 100, 1, 1, random.Random(0)) 24 | assert left == 10 25 | assert t == 5 26 | assert r == 20 27 | assert b == 25 28 | 29 | def test_shift_box_no_change(self): 30 | left, t, r, b = BoxAlteration.shift_box(10, 5, 20, 25, 100, 100, 0, 0, random.Random(0)) 31 | assert left == 10 32 | assert t == 5 33 | assert r == 20 34 | assert b == 25 35 | 36 | def test_shift_box_rb_out(self): 37 | left, t, r, b = BoxAlteration.shift_box(10, 5, 20, 25, 100, 100, 50, 50, random.Random(0)) 38 | assert left == 100 39 | assert t == 100 40 | assert r == 100 41 | assert b == 100 42 | 43 | def test_shift_box_lt_out(self): 44 | left, t, r, b = BoxAlteration.shift_box(10, 5, 20, 25, 100, 100, -50, -50, random.Random(0)) 45 | assert left == 0 46 | assert t == 0 47 | assert r == 0 48 | assert b == 0 49 | 50 | def test_shift_box_rb(self): 51 | left, t, r, b = BoxAlteration.shift_box(10, 5, 20, 25, 100, 100, 1, 1, random.Random(0)) 52 | assert left == 20 53 | assert t == 25 54 | assert r == 30 55 | assert b == 45 56 | 57 | def test_shift_box_lt(self): 58 | left, t, r, b = BoxAlteration.shift_box(10, 5, 20, 25, 100, 100, -0.1, -0.1, random.Random(0)) 59 | assert left == 9 60 | assert t == 3 61 | assert r == 19 62 | assert b == 23 63 | -------------------------------------------------------------------------------- /tests/test_coco_adaptor/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/vision-datasets/21d86a7219620b790b8a5fdaf9528e6934dd0f18/tests/test_coco_adaptor/__init__.py -------------------------------------------------------------------------------- /tests/test_coco_adaptor/coco_adaptor_base.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from vision_datasets.common import DatasetManifest, DatasetManifestWithMultiImageLabel 3 
| from ..resources.util import coco_dict_to_manifest 4 | 5 | 6 | class BaseCocoAdaptor: 7 | def test_create_data_manifest(self, coco_dict, schema: dict = None): 8 | manifest = coco_dict_to_manifest(self.TASK, coco_dict, schema) 9 | self.check(manifest, coco_dict) 10 | return manifest 11 | 12 | def test_create_data_manifest_with_additional_info(self, coco_dict, schema: dict = None): 13 | coco_dict = copy.deepcopy(coco_dict) 14 | for img in coco_dict['images']: 15 | img['img_field_1'] = 1 16 | img['img_field_2'] = 2 17 | 18 | for ann in coco_dict['annotations']: 19 | ann['ann_field_1'] = 1 20 | ann['ann_field_2'] = 2 21 | 22 | if 'categories' in coco_dict: 23 | for cat in coco_dict['categories']: 24 | cat['cat_field_1'] = 1 25 | cat['cat_field_2'] = 2 26 | 27 | coco_dict['dataset_field_1'] = 1 28 | coco_dict['dataset_field_2'] = 2 29 | 30 | manifest = coco_dict_to_manifest(self.TASK, coco_dict, schema) 31 | self.check(manifest, coco_dict) 32 | for img in manifest.images: 33 | assert img.additional_info.get('img_field_1') == 1 34 | assert img.additional_info.get('img_field_2') == 2 35 | for ann in img.labels: 36 | assert ann.additional_info.get('ann_field_1') == 1 37 | assert ann.additional_info.get('ann_field_2') == 2 38 | 39 | if isinstance(manifest, DatasetManifestWithMultiImageLabel): 40 | for ann in manifest.annotations: 41 | assert ann.additional_info.get('ann_field_1') == 1 42 | assert ann.additional_info.get('ann_field_2') == 2 43 | 44 | if 'categories' in coco_dict: 45 | for cat in manifest.categories: 46 | assert cat.additional_info.get('cat_field_1') == 1 47 | assert cat.additional_info.get('cat_field_2') == 2 48 | 49 | assert manifest.additional_info.get('dataset_field_1') == 1 50 | assert manifest.additional_info.get('dataset_field_2') == 2 51 | 52 | def check(self, manifest, coco_dict): 53 | assert len(manifest.images) == len(coco_dict['images']) 54 | categories = coco_dict.get('categories') 55 | if categories: 56 | assert manifest.categories and 
len(manifest.categories) == len(categories) 57 | if isinstance(manifest, DatasetManifest): 58 | assert sum([len(img.labels) for img in manifest.images]) == len(coco_dict['annotations']) 59 | elif isinstance(manifest, DatasetManifestWithMultiImageLabel): 60 | assert manifest.categories is None 61 | assert len(manifest.annotations) == len(coco_dict['annotations']) 62 | img_id_set = set(range(len(manifest.images))) 63 | img_id_coco_to_manifest = {im['id']: id for id, im in enumerate(coco_dict['images'])} 64 | for id, ann in enumerate(manifest.annotations): 65 | assert all([img_id in img_id_set for img_id in ann.img_ids]) 66 | assert ann.id == coco_dict['annotations'][id]['id'] 67 | coco_img_ids = coco_dict['annotations'][id]['image_ids'] 68 | assert ann.img_ids == [img_id_coco_to_manifest[coco_img_id] for coco_img_id in coco_img_ids] 69 | else: 70 | raise ValueError(f"Unknown manifest type: {type(manifest)}") 71 | -------------------------------------------------------------------------------- /tests/test_coco_adaptor/test_image_caption_coco_manifest_adaptor.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from vision_datasets.common import DatasetTypes 3 | from .coco_adaptor_base import BaseCocoAdaptor 4 | from ..resources.util import coco_database 5 | 6 | 7 | class TestImageCaption(BaseCocoAdaptor): 8 | TASK = DatasetTypes.IMAGE_CAPTION 9 | 10 | @pytest.mark.parametrize("coco_dict", coco_database[TASK]) 11 | def test_create_data_manifest(self, coco_dict): 12 | super().test_create_data_manifest(coco_dict) 13 | 14 | @pytest.mark.parametrize("coco_dict", coco_database[TASK]) 15 | def test_create_data_manifest_with_additional_info(self, coco_dict): 16 | super().test_create_data_manifest_with_additional_info(coco_dict) 17 | -------------------------------------------------------------------------------- /tests/test_coco_adaptor/test_image_matting_coco_manifest_adaptor.py: 
import pytest

from vision_datasets.common import DatasetTypes

from ..resources.util import coco_database, coco_dict_to_manifest
from .coco_adaptor_base import BaseCocoAdaptor


class TestImageMatting(BaseCocoAdaptor):
    """Shared COCO-adaptor checks bound to the image-matting task."""

    TASK = DatasetTypes.IMAGE_MATTING

    @pytest.mark.parametrize("coco_dict", coco_database[TASK])
    def test_create_data_manifest(self, coco_dict):
        super().test_create_data_manifest(coco_dict)

    @pytest.mark.parametrize("coco_dict", coco_database[TASK])
    def test_create_data_manifest_with_additional_info(self, coco_dict):
        super().test_create_data_manifest_with_additional_info(coco_dict)


class TestImageRegression(BaseCocoAdaptor):
    """Shared COCO-adaptor checks plus a one-annotation-per-image constraint test."""

    TASK = DatasetTypes.IMAGE_REGRESSION

    @pytest.mark.parametrize("coco_dict", coco_database[TASK])
    def test_create_data_manifest(self, coco_dict):
        super().test_create_data_manifest(coco_dict)

    @pytest.mark.parametrize("coco_dict", coco_database[TASK])
    def test_create_data_manifest_with_additional_info(self, coco_dict):
        super().test_create_data_manifest_with_additional_info(coco_dict)

    def test_create_data_manifest_when_multiple_annotation_per_image_should_fail(self):
        # Regression datasets allow exactly one target per image; image 1 below
        # carries two annotations, so manifest creation must raise.
        coco_dict = {
            "images": [{"id": 1, "file_name": "train_images.zip@1.jpg"},
                       {"id": 2, "file_name": "train_images.zip@2.jpg"}],
            "annotations": [
                {"id": 1, "image_id": 1, "target": 1.0},
                {"id": 2, "image_id": 1, "target": 3.0},
                {"id": 3, "image_id": 2, "target": 2.0},
            ]
        }

        with pytest.raises(ValueError, match='image with id 1 will possess unexpected number of annotations 2 for DatasetTypes.IMAGE_REGRESSION dataset.'):
            coco_dict_to_manifest(self.TASK, coco_dict)


class TestImageTextMatching(BaseCocoAdaptor):
    """Shared COCO-adaptor checks bound to the image-text-matching task."""

    TASK = DatasetTypes.IMAGE_TEXT_MATCHING

    @pytest.mark.parametrize("coco_dict", coco_database[TASK])
    def test_create_data_manifest(self, coco_dict):
        super().test_create_data_manifest(coco_dict)

    @pytest.mark.parametrize("coco_dict", coco_database[TASK])
    def test_create_data_manifest_with_additional_info(self, coco_dict):
        super().test_create_data_manifest_with_additional_info(coco_dict)
| super().test_create_data_manifest(coco_dict, schema) 18 | 19 | @pytest.mark.parametrize("coco_dict, schema", zip(coco_database[TASK], schema_database)) 20 | def test_create_data_manifest_with_additional_info(self, coco_dict, schema): 21 | super().test_create_data_manifest_with_additional_info(coco_dict, schema) 22 | 23 | def prepare_schema_and_coco_dict(self): 24 | schema = copy.deepcopy(schema_database[1]) 25 | coco_dict = copy.deepcopy(coco_database[TestKeyValuePair.TASK][1]) 26 | return schema, coco_dict 27 | 28 | def test_create_data_manifest_example(self): 29 | schema, coco_dict = self.prepare_schema_and_coco_dict() 30 | key_value_pair_manifest = coco_dict_to_manifest(TestKeyValuePair.TASK, coco_dict, schema=schema) 31 | 32 | key_value_pair_manifest.images[0].additional_info['meta_data'] = coco_dict['images'][0]['metadata'] 33 | 34 | ann_0 = key_value_pair_manifest.annotations[0] 35 | assert ann_0.id == coco_dict['annotations'][0]['id'] 36 | assert ann_0.img_ids == [0, 1] 37 | assert ann_0.fields == coco_dict['annotations'][0]['fields'] 38 | assert ann_0.text is None 39 | 40 | ann_1 = key_value_pair_manifest.annotations[1] 41 | assert ann_1.id == coco_dict['annotations'][1]['id'] 42 | assert ann_1.img_ids == [1, 0] 43 | assert ann_1.fields == coco_dict['annotations'][1]['fields'] 44 | assert ann_1.text == coco_dict['annotations'][1]['text'] 45 | 46 | def test_schema_mismatch_key_value_pair(self): 47 | schema, coco_dict = self.prepare_schema_and_coco_dict() 48 | # remove a field that defined in schema 49 | del coco_dict['annotations'][0]['fields']['productMatch'] 50 | 51 | adaptor = CocoManifestAdaptorFactory.create(TestKeyValuePair.TASK, schema=schema) 52 | with tempfile.TemporaryDirectory() as temp_dir: 53 | dm1_path = pathlib.Path(temp_dir) / 'coco.json' 54 | dm1_path.write_text(json.dumps(coco_dict)) 55 | with pytest.raises(ValueError): 56 | adaptor.create_dataset_manifest(str(dm1_path)) 57 | 58 | def test_grounding_error_for_multi_image_annotation(self): 
59 | schema = copy.deepcopy(schema_database[2]) 60 | coco_dict = copy.deepcopy(coco_database[TestKeyValuePair.TASK][2]) 61 | # change one annotation to multi-image annotation 62 | coco_dict['annotations'][0]['image_ids'] = [1, 2] 63 | 64 | adaptor = CocoManifestAdaptorFactory.create(TestKeyValuePair.TASK, schema=schema) 65 | with tempfile.TemporaryDirectory() as temp_dir: 66 | dm1_path = pathlib.Path(temp_dir) / 'coco.json' 67 | dm1_path.write_text(json.dumps(coco_dict)) 68 | with pytest.raises(ValueError): 69 | adaptor.create_dataset_manifest(str(dm1_path)) 70 | -------------------------------------------------------------------------------- /tests/test_coco_adaptor/test_multiclass_classification_coco_manifest_adaptor.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import pytest 3 | from vision_datasets.common import DatasetTypes 4 | from .coco_adaptor_base import BaseCocoAdaptor 5 | from ..resources.util import coco_database, coco_dict_to_manifest 6 | 7 | 8 | class TestMultiClassClassification(BaseCocoAdaptor): 9 | TASK = DatasetTypes.IMAGE_CLASSIFICATION_MULTICLASS 10 | 11 | @pytest.mark.parametrize("coco_dict", coco_database[TASK]) 12 | def test_create_data_manifest(self, coco_dict): 13 | super().test_create_data_manifest(coco_dict) 14 | 15 | @pytest.mark.parametrize("coco_dict", coco_database[TASK]) 16 | def test_create_data_manifest_with_additional_info(self, coco_dict): 17 | super().test_create_data_manifest_with_additional_info(coco_dict) 18 | 19 | def test_create_data_manifest_when_multiple_annotation_per_image_should_fail(self): 20 | coco_dict = { 21 | "images": [ 22 | {"id": 1, "width": 224.0, "height": 224.0, "file_name": "train/1.jpg"}, 23 | {"id": 2, "width": 224.0, "height": 224.0, "file_name": "train/3.jpg"}], 24 | "annotations": [ 25 | {"id": 1, "category_id": 1, "image_id": 1}, 26 | {"id": 2, "category_id": 2, "image_id": 1}, 27 | {"id": 3, "category_id": 2, "image_id": 2} 28 | ], 29 | 
import pytest
from vision_datasets.common import DatasetTypes
from .coco_adaptor_base import BaseCocoAdaptor
from ..resources.util import coco_database


class TestMultiLabelClassification(BaseCocoAdaptor):
    """Shared COCO-adaptor checks bound to the multilabel classification task."""

    TASK = DatasetTypes.IMAGE_CLASSIFICATION_MULTILABEL

    @pytest.mark.parametrize("coco_dict", coco_database[TASK])
    def test_create_data_manifest(self, coco_dict):
        # All assertions live in the base class; this override only binds the task.
        super().test_create_data_manifest(coco_dict)

    @pytest.mark.parametrize("coco_dict", coco_database[TASK])
    def test_create_data_manifest_with_additional_info(self, coco_dict):
        super().test_create_data_manifest_with_additional_info(coco_dict)
import pytest
import copy
from vision_datasets.common import DatasetTypes
from .coco_adaptor_base import BaseCocoAdaptor
from ..resources.util import coco_database, coco_dict_to_manifest_multitask


class TestMultiTask:
    """Smoke test: multitask coco dicts convert without raising."""

    TASK = DatasetTypes.MULTITASK

    @pytest.mark.parametrize("tasks, coco_dicts", coco_database[TASK])
    def test_create_data_manifest(self, tasks, coco_dicts):
        coco_dict_to_manifest_multitask(tasks, coco_dicts)

    # TODO: need to implement more checks


class TestObjectDetection(BaseCocoAdaptor):
    """Shared COCO-adaptor checks plus iscrowd propagation for detection."""

    TASK = DatasetTypes.IMAGE_OBJECT_DETECTION

    @pytest.mark.parametrize("coco_dict", coco_database[TASK])
    def test_create_data_manifest(self, coco_dict):
        super().test_create_data_manifest(coco_dict)

    @pytest.mark.parametrize("coco_dict", coco_database[TASK])
    def test_create_data_manifest_with_additional_info(self, coco_dict):
        super().test_create_data_manifest_with_additional_info(coco_dict)

    def test_iscrowd_working(self):
        # Flip one annotation to iscrowd=1 and re-run the standard checks;
        # the overridden check() below verifies the flag count round-trips.
        coco_dict = copy.deepcopy(coco_database[self.TASK][0])
        coco_dict['annotations'][0]['iscrowd'] = 1
        super().test_create_data_manifest(coco_dict)

    def check(self, manifest, coco_dict):
        """Extend the base checks: iscrowd counts must match between coco and manifest."""
        super().check(manifest, coco_dict)
        crowd_in_manifest = sum(
            sum(1 for label in image.labels if label.additional_info.get('iscrowd', 0) != 0)
            for image in manifest.images)
        crowd_in_coco = sum(1 for ann in coco_dict['annotations'] if ann.get('iscrowd', 0) != 0)
        assert crowd_in_coco == crowd_in_manifest
import pytest
import copy
from vision_datasets.common import DatasetTypes
from .coco_adaptor_base import BaseCocoAdaptor
from ..resources.util import coco_database
from vision_datasets.visual_object_grounding import VisualObjectGroundingLabelManifest


class TestText2ImageRetrieval(BaseCocoAdaptor):
    """Shared COCO-adaptor checks bound to the text-to-image-retrieval task."""

    TASK = DatasetTypes.TEXT_2_IMAGE_RETRIEVAL

    @pytest.mark.parametrize("coco_dict", coco_database[TASK])
    def test_create_data_manifest(self, coco_dict):
        super().test_create_data_manifest(coco_dict)

    @pytest.mark.parametrize("coco_dict", coco_database[TASK])
    def test_create_data_manifest_with_additional_info(self, coco_dict):
        super().test_create_data_manifest_with_additional_info(coco_dict)


class TestVisualObjectGrounding(BaseCocoAdaptor):
    """Shared COCO-adaptor checks plus bbox-format normalization for grounding."""

    TASK = DatasetTypes.VISUAL_OBJECT_GROUNDING

    @pytest.mark.parametrize("coco_dict", coco_database[TASK])
    def test_create_data_manifest(self, coco_dict):
        super().test_create_data_manifest(coco_dict)

    @pytest.mark.parametrize("coco_dict", coco_database[TASK])
    def test_create_data_manifest_with_additional_info(self, coco_dict):
        super().test_create_data_manifest_with_additional_info(coco_dict)

    @pytest.mark.parametrize("format", ["ltwh", "ltrb"])
    def test_bbox_format(self, format):
        """Whatever bbox_format the coco dict declares, the manifest must agree in ltwh terms."""
        coco_dict = copy.deepcopy(coco_database[self.TASK][0])
        coco_dict['bbox_format'] = format
        manifest = super().test_create_data_manifest(coco_dict)

        # Bucket the source annotations by (1-based) image id for easy lookup.
        ann_by_image = [[] for _ in range(len(coco_dict['images']))]
        for ann in coco_dict['annotations']:
            ann_by_image[ann['image_id'] - 1].append(ann)

        for i, image in enumerate(manifest.images):
            image_anns = ann_by_image[i]
            for j, label in enumerate(image.labels):
                label: VisualObjectGroundingLabelManifest
                for k, g in enumerate(label.groundings):
                    for t, bbox in enumerate(g.bboxes):
                        gt = image_anns[j]["groundings"][k]['bboxes'][t]
                        # Normalize both sides to left/top/width/height before comparing.
                        gt_ltwh = gt if format == "ltwh" else [gt[0], gt[1], gt[2] - gt[0], gt[3] - gt[1]]
                        manifest_ltwh = [bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1]]
                        assert gt_ltwh == manifest_ltwh


class TestVisualQuestionAnswering(BaseCocoAdaptor):
    """Shared COCO-adaptor checks bound to the VQA task."""

    TASK = DatasetTypes.VISUAL_QUESTION_ANSWERING

    @pytest.mark.parametrize("coco_dict", coco_database[TASK])
    def test_create_data_manifest(self, coco_dict):
        super().test_create_data_manifest(coco_dict)

    @pytest.mark.parametrize("coco_dict", coco_database[TASK])
    def test_create_data_manifest_with_additional_info(self, coco_dict):
        super().test_create_data_manifest_with_additional_info(coco_dict)
import unittest

from tests.test_fixtures import MulticlassClassificationTestFixtures, MultilabelClassificationTestFixtures
from vision_datasets.common import DatasetTypes
from vision_datasets.image_classification import MulticlassClassificationAsKeyValuePairDataset, MultilabelClassificationAsKeyValuePairDataset
from vision_datasets.key_value_pair.manifest import KeyValuePairLabelManifest


class TestClassificationAsKeyValuePairDataset(unittest.TestCase):
    """Wrapping classification datasets as key-value-pair datasets: schema and labels."""

    def test_multiclass_classification(self):
        sample_classification_dataset, tempdir = MulticlassClassificationTestFixtures.create_an_ic_dataset()
        with tempdir:
            kvp_dataset = MulticlassClassificationAsKeyValuePairDataset(sample_classification_dataset)

            self.assertIsInstance(kvp_dataset, MulticlassClassificationAsKeyValuePairDataset)
            self.assertEqual(kvp_dataset.dataset_info.type, DatasetTypes.KEY_VALUE_PAIR)

            # The wrapper must synthesize a complete KVP schema from the categories.
            schema = kvp_dataset.dataset_info.schema
            for schema_key in ("name", "description", "fieldSchema"):
                self.assertIn(schema_key, schema)

            self.assertEqual(schema["fieldSchema"],
                             {"className": {
                                 "type": "string",
                                 "description": "Class name that the image belongs to.",
                                 "classes": {
                                     "1-class": {"description": "A single class name. Only output 1-class as the class name if present."},
                                     "2-class": {"description": "A single class name. Only output 2-class as the class name if present."},
                                     "3-class": {"description": "A single class name. Only output 3-class as the class name if present."},
                                 }
                             }
                             })

            # A sample's target is a KVP label pointing at its single class name.
            _, target, _ = kvp_dataset[0]
            self.assertIsInstance(target, KeyValuePairLabelManifest)
            self.assertEqual(target.label_data,
                             {"fields": {"className": {"value": "1-class"}}}
                             )

    def test_multilabel_classification(self):
        sample_classification_dataset, tempdir = MultilabelClassificationTestFixtures.create_an_ic_dataset(n_images=2, n_categories=2)
        with tempdir:
            kvp_dataset = MultilabelClassificationAsKeyValuePairDataset(sample_classification_dataset)

            self.assertIsInstance(kvp_dataset, MultilabelClassificationAsKeyValuePairDataset)
            self.assertEqual(kvp_dataset.dataset_info.type, DatasetTypes.KEY_VALUE_PAIR)

            schema = kvp_dataset.dataset_info.schema
            for schema_key in ("name", "description", "fieldSchema"):
                self.assertIn(schema_key, schema)

            # Multilabel wraps the class names as an array-typed field.
            self.assertEqual(schema["fieldSchema"],
                             {'classNames': {
                                 'type': 'array',
                                 'description': 'Class names that the image belongs to.',
                                 'items': {
                                     'type': 'string',
                                     'description': 'Single class name.',
                                     'classes': {
                                         '1-class': {'description': 'A single class name. Only output 1-class as the class name if present.'},
                                         '2-class': {'description': 'A single class name. Only output 2-class as the class name if present.'}
                                     }
                                 }
                             }
                             }
                             )

            _, target, _ = kvp_dataset[0]
            self.assertIsInstance(target, KeyValuePairLabelManifest)
            self.assertEqual(target.label_data,
                             {'fields': {
                                 'classNames': {'value': [{'value': '1-class'}, {'value': '2-class'}]}}
                             })


if __name__ == '__main__':
    unittest.main()
import io
import json
import os
import pathlib
import unittest
from unittest.mock import ANY, MagicMock

from vision_datasets.common import DatasetDownloader, DatasetRegistry, DatasetTypes, Usages


class TestDatasetDownloader(unittest.TestCase):
    """DatasetDownloader behavior: version resolution, temp-dir cleanup, URL construction."""

    def test_use_latest_version(self):
        # 'dataset_name' is registered at versions 42, 2 and 4; downloading without
        # an explicit version must fetch the highest version's index files only.
        datasets = [
            {
                'name': 'dataset_name',
                'version': 42,
                'type': DatasetTypes.IMAGE_CLASSIFICATION_MULTICLASS.name,
                'root_folder': '',
                'train': {'index_path': '42.txt', 'files_for_local_usage': []},
                'test': {'index_path': 'val42.txt', 'files_for_local_usage': []}
            },
            {
                'name': 'dataset_name',
                'version': 2,
                'type': DatasetTypes.IMAGE_CLASSIFICATION_MULTICLASS.name,
                'root_folder': './',
                'train': {'index_path': '2.txt', 'files_for_local_usage': []},
                'test': {'index_path': 'val2.txt', 'files_for_local_usage': []}
            },
            {
                'name': 'dataset_name',
                'version': 4,
                'type': DatasetTypes.IMAGE_CLASSIFICATION_MULTICLASS.name,
                'root_folder': './',
                'train': {'index_path': '4.txt', 'files_for_local_usage': []},
                'test': {'index_path': 'val4.txt', 'files_for_local_usage': []}
            },
            {
                'name': 'dataset_name2',
                'version': 43,
                'type': DatasetTypes.IMAGE_CLASSIFICATION_MULTICLASS.name,
                'root_folder': './',
                'train': {'index_path': '43.txt', 'files_for_local_usage': []},
                'test': {'index_path': 'val43.txt', 'files_for_local_usage': []}
            }
        ]
        dataset_info = self._make_reg(datasets).get_dataset_info('dataset_name')
        downloader = self._make_downloader(dataset_info)
        downloader._download_files = MagicMock()
        downloader.download('dataset_name')
        downloader._download_files.assert_called_once_with({pathlib.Path('42.txt'), pathlib.Path('val42.txt')}, ANY)

    def test_delete_temp_dir(self):
        # The context manager must create the temp dirs for the duration of the
        # `with` block and remove them on exit.
        datasets = [{'name': 'dataset_name', 'type': DatasetTypes.IMAGE_CLASSIFICATION_MULTICLASS.name, 'root_folder': './', 'version': 42,
                     'train': {'index_path': '42.txt', 'files_for_local_usage': []},
                     'test': {'index_path': '42.txt', 'files_for_local_usage': []}}]

        dataset_info = self._make_reg(datasets).get_dataset_info('dataset_name')
        downloader = self._make_downloader(dataset_info)
        downloader._download_files = MagicMock()
        with downloader.download() as downloaded:
            for base_dir in downloaded.base_dirs:
                self.assertTrue(os.path.isdir(base_dir))

        for base_dir in downloaded.base_dirs:
            self.assertFalse(os.path.isdir(base_dir))

    def test_concatenate_path(self):
        # root_folder and index_path must be spliced into the base URL *before*
        # the query string (the sas token).
        datasets = [{
            'name': 'dataset_name',
            'type': DatasetTypes.IMAGE_CLASSIFICATION_MULTICLASS.name,
            'root_folder': 'somewhere',
            'version': 1,
            'train': {'index_path': 'dir/42.txt', 'files_for_local_usage': []}}]
        dataset_info = self._make_reg(datasets).get_dataset_info('dataset_name')
        downloader = self._make_downloader(dataset_info)
        with unittest.mock.patch('requests.get') as mock_get:
            mock_get.return_value.__enter__.return_value.raw = io.BytesIO(b'42')
            mock_get.return_value.__enter__.return_value.status_code = 200
            downloader.download()
            mock_get.assert_called_once_with('http://example.com/somewhere/dir/42.txt?sastoken=something', allow_redirects=True, stream=True, timeout=ANY)

    @staticmethod
    def _make_downloader(dataset_info, base_path='http://example.com/?sastoken=something'):
        return DatasetDownloader(base_path, dataset_info)

    @staticmethod
    def _make_reg(datasets) -> DatasetRegistry:
        return DatasetRegistry(json.dumps(datasets))


class TestDatasetRegistry(unittest.TestCase):
    """DatasetRegistry parsing of plain and key-value-pair dataset descriptors."""

    DUMMY_DATA_1 = {
        "name": "dummy1",
        "version": 1,
        "type": "classification_multiclass",
        "root_folder": "dummy",
        "test": {
            "index_path": "test.txt",
            "files_for_local_usage": [
                "Train.zip"
            ]
        }
    }

    # NOTE(review): identical to DUMMY_DATA_2 (same name/version) — presumably the
    # version was meant to differ; confirm before relying on it. Unused by the tests below.
    DUMMY_DATA_1_V2 = {
        "name": "dummy2",
        "version": 1,
        "type": "classification_multiclass",
        "root_folder": "dummy",
        "test": {
            "index_path": "test.txt",
            "files_for_local_usage": [
                "Train.zip"
            ]
        }
    }

    DUMMY_DATA_2 = {
        "name": "dummy2",
        "version": 1,
        "type": "classification_multiclass",
        "root_folder": "dummy",
        "test": {
            "index_path": "test.txt",
            "files_for_local_usage": [
                "Train.zip"
            ]
        }
    }

    DUMMY_DATA_KEY_VALUE_PAIR = {
        "name": "dummykey_value_pair",
        "version": 1,
        "type": "key_value_pair",
        "format": "coco",
        "root_folder": "dummy",
        "test": {
            "index_path": "test.json",
            "files_for_local_usage": [
                "test.zip"
            ]
        },
        "schema": {
            "name": "name_key_value_pair",
            "description": "dummy description",
            "fieldSchema": {
                "key1": "value1",
                "key2": 2
            }
        }
    }

    def test_create_dataset_reg(self):
        dn = self.DUMMY_DATA_1['name']
        dr = DatasetRegistry(json.dumps([self.DUMMY_DATA_1]))
        assert len(dr.list_data_version_and_types()) == 1
        info = dr.get_dataset_info(dn)
        assert info
        assert info.name == dn
        assert info.version == self.DUMMY_DATA_1['version']
        assert info.root_folder == self.DUMMY_DATA_1['root_folder']
        assert info.type == DatasetTypes.IMAGE_CLASSIFICATION_MULTICLASS
        assert info.files_for_local_usage[Usages.TEST] == self.DUMMY_DATA_1['test']['files_for_local_usage']
        assert info.index_files[Usages.TEST] == self.DUMMY_DATA_1['test']['index_path']

    def test_create_dataset_reg_with_two_jsons(self):
        # The registry accepts a list of json strings and merges them.
        dr = DatasetRegistry([json.dumps([self.DUMMY_DATA_1]), json.dumps([self.DUMMY_DATA_2])])
        assert len(dr.list_data_version_and_types()) == 2
        assert dr.get_dataset_info(self.DUMMY_DATA_1['name'])
        assert dr.get_dataset_info(self.DUMMY_DATA_2['name'])

    def test_create_key_value_pair_dataset_reg(self):
        dn = self.DUMMY_DATA_KEY_VALUE_PAIR['name']
        dr = DatasetRegistry(json.dumps([self.DUMMY_DATA_KEY_VALUE_PAIR]))
        assert len(dr.list_data_version_and_types()) == 1
        info = dr.get_dataset_info(dn)
        assert info
        assert info.name == dn
        assert info.version == self.DUMMY_DATA_KEY_VALUE_PAIR['version']
        assert info.root_folder == self.DUMMY_DATA_KEY_VALUE_PAIR['root_folder']
        assert info.type == DatasetTypes.KEY_VALUE_PAIR
        assert info.files_for_local_usage[Usages.TEST] == self.DUMMY_DATA_KEY_VALUE_PAIR['test']['files_for_local_usage']
        assert info.index_files[Usages.TEST] == self.DUMMY_DATA_KEY_VALUE_PAIR['test']['index_path']
        # schema is required
        assert info.schema == self.DUMMY_DATA_KEY_VALUE_PAIR['schema']


if __name__ == '__main__':
    unittest.main()
import contextlib
import multiprocessing
import os
import pathlib
import pickle
import tempfile
import unittest
import zipfile

from vision_datasets.common import FileReader
from vision_datasets.common.data_reader.file_reader import MultiProcessZipFile


def open_zipfile(zip_file, filename, queue):
    """Read one member of the zip and push its bytes onto the queue.

    Lives at module level so it can be pickled into child processes.
    """
    queue.put(zip_file.open(filename).read())


class TestMultiProcessZipFile(unittest.TestCase):
    """MultiProcessZipFile must work in-process, across processes, and through pickle."""

    def test_single_process(self):
        with self._with_test_zip({'test.txt': b'contents'}) as zip_filepath:
            zip_file = MultiProcessZipFile(zip_filepath)
            with zip_file.open('test.txt') as z:
                self.assertEqual(z.read(), b'contents')
            zip_file.close()

    def test_access_from_multiple_process(self):
        n_workers = 5
        with self._with_test_zip({'test.txt': b'contents'}) as zip_filepath:
            zip_file = MultiProcessZipFile(zip_filepath)
            queue = multiprocessing.Queue()
            workers = [multiprocessing.Process(target=open_zipfile, args=(zip_file, 'test.txt', queue))
                       for _ in range(n_workers)]
            for w in workers:
                w.start()
            for w in workers:
                w.join()

            # Every worker must have read the same contents exactly once.
            for _ in range(n_workers):
                self.assertEqual(queue.get(False), b'contents')
            self.assertTrue(queue.empty())

    def test_pickle(self):
        with self._with_test_zip({'test.txt': b'contents'}) as zip_filepath:
            zip_file = MultiProcessZipFile(zip_filepath)
            with zip_file.open('test.txt') as z:
                self.assertEqual(z.read(), b'contents')

            # A pickle round-trip must yield a working, independent handle.
            deserialized = pickle.loads(pickle.dumps(zip_file))

            with deserialized.open('test.txt') as z:
                self.assertEqual(z.read(), b'contents')

            deserialized.close()
            zip_file.close()

    @staticmethod
    @contextlib.contextmanager
    def _with_test_zip(contents):
        """
        Args:
            contents: {filename: binary_contents} Files to be put in the test zip file.
        """

        with tempfile.TemporaryDirectory() as tempdir:
            zip_filepath = pathlib.Path(tempdir) / 'test.zip'
            with zipfile.ZipFile(zip_filepath, 'w') as f:
                for filename, bin_contents in contents.items():
                    f.writestr(filename, bin_contents)

            yield zip_filepath


class TestFileReader(unittest.TestCase):
    """FileReader must read both zip-member paths ('x.zip@member') and plain files."""

    def test_read(self):
        with tempfile.TemporaryDirectory() as tempdir:
            with zipfile.ZipFile(os.path.join(tempdir, 'test.zip'), 'w') as f:
                f.writestr('test.txt', b'zip_contents')
            with open(os.path.join(tempdir, 'test.txt'), 'w') as f:
                f.write('txt_contents')

            reader = FileReader()

            # zip member: opened in binary mode via the '@' path convention
            file = reader.open(os.path.join(tempdir, 'test.zip') + '@test.txt')
            self.assertIsNotNone(file)
            self.assertEqual(file.read(), b'zip_contents')
            file.close()

            # plain file: opened in text mode
            with reader.open(os.path.join(tempdir, 'test.txt')) as f:
                self.assertEqual(f.read(), 'txt_contents')
            reader.close()


if __name__ == '__main__':
    unittest.main()
len(it_dataset) 17 | matches = [label.label_data[1] for x, labels, _ in it_dataset for label in labels] 18 | assert len(matches) == n_images * 2 19 | assert len(set(matches)) == 1 and matches[0] == 1, matches 20 | 21 | def test_od_as_image_text_dataset_with_neg_pairs(self): 22 | n_images = 3 23 | n_categories = 10 24 | dataset, tempdir = DetectionTestFixtures.create_an_od_dataset(n_images, n_categories) 25 | with tempdir: 26 | it_dataset = VisionAsImageTextDataset(dataset, 3, rnd_seed=1) 27 | assert it_dataset.dataset_info.type == DatasetTypes.IMAGE_TEXT_MATCHING, it_dataset.dataset_info.type 28 | assert len(it_dataset) == n_images, len(it_dataset) 29 | matches = [label.label_data[1] for x, labels, _ in it_dataset for label in labels] 30 | assert sum(matches) == 6, matches 31 | assert len(matches) == 24, len(matches) 32 | 33 | def test_od_as_image_text_dataset_with_down_sampling_neg_pairs(self): 34 | n_images = 10 35 | n_categories = 10 36 | dataset, tempdir = DetectionTestFixtures.create_an_od_dataset(n_images, n_categories) 37 | with tempdir: 38 | it_dataset = VisionAsImageTextDataset(dataset, 0.3, rnd_seed=1) 39 | assert it_dataset.dataset_info.type == DatasetTypes.IMAGE_TEXT_MATCHING, it_dataset.dataset_info.type 40 | assert len(it_dataset) == n_images, len(it_dataset) 41 | matches = [label.label_data[1] for x, labels, _ in it_dataset for label in labels] 42 | assert sum(matches) == 20, matches 43 | assert len(matches) == 27, matches 44 | 45 | def test_od_as_image_text_dataset_with_neg_pairs_under_expected_ratio(self): 46 | n_images = 3 47 | n_categories = 3 48 | dataset, tempdir = DetectionTestFixtures.create_an_od_dataset(n_images, n_categories) 49 | with tempdir: 50 | it_dataset = VisionAsImageTextDataset(dataset, 3, rnd_seed=1) 51 | assert it_dataset.dataset_info.type == DatasetTypes.IMAGE_TEXT_MATCHING, it_dataset.dataset_info.type 52 | assert len(it_dataset) == n_images, len(it_dataset) 53 | matches = [label.label_data[1] for x, labels, _ in it_dataset 
class TestClassificationAsImageTextDataset(unittest.TestCase):
    def test_ic_as_image_text_dataset(self):
        """A classification-wrapped OD dataset exposed as image-text: verify type, length and match counts."""
        num_images, num_categories = 10, 10
        od_dataset, tempdir = DetectionTestFixtures.create_an_od_dataset(num_images, num_categories)
        with tempdir:
            ic_dataset = DetectionAsClassificationIgnoreBoxesDataset(od_dataset)
            it_dataset = VisionAsImageTextDataset(ic_dataset, 0.3, rnd_seed=1)
            assert it_dataset.dataset_info.type == DatasetTypes.IMAGE_TEXT_MATCHING, it_dataset.dataset_info.type
            assert len(it_dataset) == num_images, len(it_dataset)
            # Collect the match flag (index 1 of label_data) of every label in the dataset.
            matches = []
            for _, labels, _ in it_dataset:
                matches.extend(label.label_data[1] for label in labels)
            assert sum(matches) == 20, matches
            assert len(matches) == 27, matches
class TestManifestIsPickleable:
    """Manifests of every task type must survive a pickle round trip unchanged."""

    @pytest.mark.parametrize(
        "task, coco_dict",
        [(t, d) for t, dicts in coco_database.items() if t != DatasetTypes.KEY_VALUE_PAIR for d in dicts])
    def test_create_data_manifest(self, task, coco_dict):
        self._check_pickleable(coco_dict_to_manifest(task, coco_dict))

    @pytest.mark.parametrize("coco_dict, schema", zip(coco_database[DatasetTypes.KEY_VALUE_PAIR], schema_database))
    def test_create_key_value_pair_manifest(self, coco_dict, schema):
        self._check_pickleable(coco_dict_to_manifest(DatasetTypes.KEY_VALUE_PAIR, coco_dict, schema))

    @staticmethod
    def _check_pickleable(manifest):
        # Round-trip through pickle and require equality with the original.
        blob = pickle.dumps(manifest)
        restored = pickle.loads(blob)
        assert manifest == restored
TestVisualObjectGrounding(unittest.TestCase): 7 | def test_simple(self): 8 | manifest = VisualObjectGroundingLabelManifest({'question': 'hello', 'answer': 'world', 'groundings': [{'id': 1, 'text_span': [0, 2], 'text': 'he', 'bboxes': [[0, 0, 10, 10]]}]}) 9 | self.assertEqual(manifest.question, 'hello') 10 | self.assertEqual(manifest.answer, 'world') 11 | self.assertEqual(manifest.groundings[0].id, 1) 12 | self.assertEqual(manifest.groundings[0].text_span, [0, 2]) 13 | self.assertEqual(manifest.groundings[0].text, 'he') 14 | self.assertEqual(manifest.groundings[0].bboxes, [[0, 0, 10, 10]]) 15 | -------------------------------------------------------------------------------- /tests/test_manifest_to_coco/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/vision-datasets/21d86a7219620b790b8a5fdaf9528e6934dd0f18/tests/test_manifest_to_coco/__init__.py -------------------------------------------------------------------------------- /tests/test_manifest_to_coco/test_manifest_to_coco.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from vision_datasets.common import CocoDictGeneratorFactory, DatasetTypes 3 | 4 | from ..resources.util import coco_database, schema_database, coco_dict_to_manifest 5 | 6 | 7 | class TestManifestToCoco: 8 | @pytest.mark.parametrize("coco_dict, task", [(coco_dict, task) for task, coco_dicts in coco_database.items() if task not in [DatasetTypes.MULTITASK, DatasetTypes.KEY_VALUE_PAIR] 9 | for coco_dict in coco_dicts]) 10 | def test_manifest_to_coco_dict(self, coco_dict, task): 11 | manifest = coco_dict_to_manifest(task, coco_dict) 12 | coco_generator = CocoDictGeneratorFactory.create(task) 13 | coco_dict = coco_generator.run(manifest) 14 | 15 | @pytest.mark.parametrize("coco_dict, schema", zip(coco_database[DatasetTypes.KEY_VALUE_PAIR], schema_database)) 16 | def 
class TestSampleManifestNumSamples:
    """Sampling by a fixed number of samples must return exactly that many entries."""

    @pytest.mark.parametrize(
        "task, coco_dict, with_replacement",
        [case + (replacement,)
         for case in [(t, d) for t in DATATYPES_N_SAMPLE for d in coco_database[t]]
         for replacement in [True, False]])
    def test_sample_data_manifest_by_n_samples_single_task(self, task, coco_dict, with_replacement):
        manifest = coco_dict_to_manifest(task, coco_dict)
        # Ask for half of the images (at least one).
        target = max(len(manifest.images) // 2, 1)
        strategy = SampleStrategyFactory.create(task, SampleStrategyType.NumSamples, SampleByNumSamplesConfig(0, with_replacement, target))
        sampled = ManifestSampler(strategy).run(manifest)
        assert target == len(sampled.images)

    @pytest.mark.parametrize("test_case_index, with_replacement",
                             list(itertools.product(range(len(coco_database[DatasetTypes.KEY_VALUE_PAIR])), [True, False])))
    def test_sample_data_manifest_by_n_samples_kvp(self, test_case_index, with_replacement):
        task = DatasetTypes.KEY_VALUE_PAIR
        manifest = coco_dict_to_manifest(task, coco_database[task][test_case_index], schema_database[test_case_index])
        # Key-value-pair manifests are sampled by annotation count, not image count.
        target = max(len(manifest.annotations) // 2, 1)
        strategy = SampleStrategyFactory.create(task, SampleStrategyType.NumSamples, SampleByNumSamplesConfig(0, with_replacement=with_replacement, n_samples=target))
        sampled = ManifestSampler(strategy).run(manifest)
        assert target == len(sampled.annotations)
class TestSpawn:
    """Spawning should grow a manifest to exactly the requested number of images."""

    @pytest.mark.parametrize("task, coco_dict", [(t, d) for t in DATA_TYPES for d in coco_database[t]])
    def test_spawn_single_task(self, task, coco_dict):
        manifest = coco_dict_to_manifest(task, coco_dict)
        target_size = 2 * len(manifest.images)
        spawner = Spawn(SpawnConfig(0, target_size))
        result = spawner.run(manifest)
        assert len(result.images) == target_size
class TestManifestToStandAloneImageDict:
    """Standalone image-dict generation: verify item counts and base64-encoding call counts."""

    @pytest.mark.parametrize("coco_dict, task", [(d, t) for t in DATASET_TYPES for d in coco_database[t]])
    def test_manifest_to_standalone_image_dict_flattened(self, coco_dict, task):
        # Flattened output produces one item per annotation.
        self._run_and_verify(coco_dict, task, True, len(coco_dict['annotations']))

    @pytest.mark.parametrize("coco_dict, task", [(d, t) for t in DATASET_TYPES for d in coco_database[t]])
    def test_manifest_to_standalone_image_dict_not_flattened(self, coco_dict, task):
        # Non-flattened output produces one item per image, each carrying a list of labels.
        def labels_are_list(item): return isinstance(item['labels'], list)
        self._run_and_verify(coco_dict, task, False, len(coco_dict['images']), labels_are_list)

    def _run_and_verify(self, coco_dict, task, flatten, expected_num_items, custom_check=lambda _: True):
        # Stub out the expensive file -> base64 conversion and count generated items.
        manifest = coco_dict_to_manifest(task, coco_dict)
        with patch.object(Base64Utils, 'file_to_b64_str', return_value="b64string") as b64_mock:
            generator = StandAloneImageListGeneratorFactory.create(task, flatten)
            produced = 0
            for item in generator.run(manifest):
                assert item['image'] == "b64string"
                assert custom_check(item)
                produced += 1
            assert produced == expected_num_items
            self._verify_b64_call_count(task, b64_mock, coco_dict)

    def _verify_b64_call_count(self, data_type, b64_mock, coco_dict: dict):
        # Matting encodes the label image too, so there is one extra call per annotation.
        expected = len(coco_dict['images'])
        if data_type == DatasetTypes.IMAGE_MATTING:
            expected += len(coco_dict['annotations'])
        assert b64_mock.call_count == expected
'test.json').write_text(json.dumps(coco_dict)) 33 | 34 | dataset_info = DatasetInfo({ 35 | 'name': 'test', 36 | 'type': data_type.name, 37 | 'root_folder': tdir.as_posix(), 38 | 'format': 'coco', 39 | 'train': {'index_path': 'test.json'} 40 | }) 41 | td = TorchDataset(VisionDataset(dataset_info, manifest)) 42 | 43 | for image in manifest.images: 44 | image.img_path = image.img_path.split('@')[1] if '@' in image.img_path else image.img_path 45 | image.img_path = tdir / image.img_path 46 | os.makedirs(image.img_path.parent, exist_ok=True) 47 | image.img_path = image.img_path.as_posix() 48 | Image.new(mode="RGB", size=(20, 20)).save(image.img_path) 49 | 50 | for x in td: 51 | pass 52 | 53 | td[0:-1] 54 | 55 | @pytest.mark.parametrize("coco_dict, schema", zip(coco_database[DatasetTypes.KEY_VALUE_PAIR], schema_database)) 56 | def test_create_key_value_pair_torch_dataset(self, coco_dict, schema): 57 | manifest = coco_dict_to_manifest(DatasetTypes.KEY_VALUE_PAIR, coco_dict, schema) 58 | with tempfile.TemporaryDirectory() as temp_dir: 59 | tdir = pathlib.Path(temp_dir) 60 | (tdir / 'test.json').write_text(json.dumps(coco_dict)) 61 | 62 | dataset_info = DatasetInfo({ 63 | 'name': 'test', 64 | 'type': DatasetTypes.KEY_VALUE_PAIR.name, 65 | 'root_folder': tdir.as_posix(), 66 | 'format': 'coco', 67 | 'train': {'index_path': 'test.json'}, 68 | 'schema': schema 69 | }) 70 | td = TorchDataset(VisionDataset(dataset_info, manifest)) 71 | 72 | for image in manifest.images: 73 | image.img_path = image.img_path.split('@')[1] if '@' in image.img_path else image.img_path 74 | image.img_path = tdir / image.img_path 75 | os.makedirs(image.img_path.parent, exist_ok=True) 76 | image.img_path = image.img_path.as_posix() 77 | Image.new(mode="RGB", size=(20, 20)).save(image.img_path) 78 | 79 | for x in td: 80 | pass 81 | 82 | td[0:-1] 83 | 84 | def test_picklable(self): 85 | dataset = TorchDataset(FakeDataset()) 86 | serialized = pickle.dumps(dataset) 87 | new_dataset = 
pickle.loads(serialized) 88 | assert isinstance(new_dataset, TorchDataset) 89 | 90 | dataset = TorchDataset(FakeDataset(), _one_arg_method) 91 | serialized = pickle.dumps(dataset) 92 | new_dataset = pickle.loads(serialized) 93 | assert isinstance(new_dataset, TorchDataset) 94 | 95 | dataset = TorchDataset(FakeDataset()) 96 | dataset.transform = None 97 | serialized = pickle.dumps(dataset) 98 | new_dataset = pickle.loads(serialized) 99 | assert isinstance(new_dataset, TorchDataset) 100 | 101 | def test_transform(self): 102 | dataset = TorchDataset(FakeDataset(), None) 103 | assert dataset.transform(1, 2) == (1, 2) 104 | dataset = TorchDataset(FakeDataset(), lambda x: x) 105 | assert dataset.transform(1, 2) == (1, 2) 106 | dataset = TorchDataset(FakeDataset(), lambda x, y: (x, y)) 107 | assert dataset.transform(1, 2) == (1, 2) 108 | 109 | dataset.transform = lambda x: x 110 | assert dataset.transform(1, 2) == (1, 2) 111 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import tempfile 3 | import unittest 4 | 5 | import numpy 6 | from PIL import Image, ImageChops 7 | 8 | from vision_datasets.common import Base64Utils 9 | 10 | 11 | class TestBase64Utils(unittest.TestCase): 12 | @staticmethod 13 | def _create_rand_img(size=(100, 100)) -> Image.Image: 14 | assert len(size) == 2 15 | 16 | imarray = numpy.random.rand(size[0], size[1], 3) * 255 17 | im = Image.fromarray(imarray.astype('uint8')).convert('RGB') 18 | return im 19 | 20 | def test_b64_to_file_loses_no_info(self): 21 | with tempfile.TemporaryDirectory() as temp_dir: 22 | temp_dir = pathlib.Path(temp_dir) 23 | img_filepath_1 = temp_dir / 'temp_1.jpg' 24 | TestBase64Utils._create_rand_img().save(img_filepath_1) 25 | 26 | img_filepath_2 = temp_dir / 'temp_2.jpg' 27 | b64str = Base64Utils.file_to_b64_str(img_filepath_1) 28 | Base64Utils.b64_str_to_file(b64str, 
img_filepath_2) 29 | img1 = Image.open(img_filepath_1) 30 | img2 = Image.open(img_filepath_2) 31 | assert not ImageChops.difference(img1, img2).getbbox() 32 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [flake8] 2 | exclude = .git,build,dist,*venv 3 | max-line-length = 200 4 | 5 | [pytest] 6 | junit_family = xunit2 -------------------------------------------------------------------------------- /vision_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .common import AnnotationFormats, BalancedInstanceWeightsFactory, BBoxFormat, CocoDictGeneratorFactory, CocoManifestAdaptorFactory, DataManifestFactory, DatasetHub, DatasetInfo, \ 2 | DatasetManifest, DatasetRegistry, DatasetTypes, ImageDataManifest, ImageLabelManifest, ImageLabelWithCategoryManifest, ManifestMergeStrategyFactory, SampleStrategyFactory, \ 3 | SpawnFactory, SplitFactory, SupportedOperationsByDataType, Usages, VisionDataset 4 | from .image_caption import ImageCaptionLabelManifest 5 | from .image_classification import ImageClassificationLabelManifest 6 | from .image_matting import ImageMattingLabelManifest 7 | from .image_object_detection import ImageObjectDetectionLabelManifest 8 | from .image_regression import ImageRegressionLabelManifest 9 | from .image_text_matching import ImageTextMatchingLabelManifest 10 | from .key_value_pair import KeyValuePairLabelManifest 11 | from .multi_task import MultitaskMerge 12 | from .text_2_image_retrieval import Text2ImageRetrievalLabelManifest 13 | from .visual_question_answering import VisualQuestionAnsweringLabelManifest 14 | from .visual_object_grounding import VisualObjectGroundingLabelManifest 15 | 16 | __all__ = ['Usages', 'DatasetTypes', 'AnnotationFormats', 'BBoxFormat', 'DatasetInfo', 17 | 'DatasetManifest', 'ImageDataManifest', 'ImageLabelManifest', 
'ImageLabelWithCategoryManifest', 18 | 'VisionDataset', 19 | 'DatasetHub', 'DatasetRegistry', 20 | 'CocoManifestAdaptorFactory', 'DataManifestFactory', 21 | 'BalancedInstanceWeightsFactory', 'CocoDictGeneratorFactory', 'ManifestMergeStrategyFactory', 'SampleStrategyFactory', 'SpawnFactory', 'SplitFactory', 'SupportedOperationsByDataType', 22 | 'ImageCaptionLabelManifest', 'ImageClassificationLabelManifest', 'ImageMattingLabelManifest', 'ImageObjectDetectionLabelManifest', 'ImageRegressionLabelManifest', 23 | 'ImageTextMatchingLabelManifest', 'KeyValuePairLabelManifest', 'MultitaskMerge', 'Text2ImageRetrievalLabelManifest', 24 | 'VisualQuestionAnsweringLabelManifest', 'VisualObjectGroundingLabelManifest'] 25 | -------------------------------------------------------------------------------- /vision_datasets/commands/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/vision-datasets/21d86a7219620b790b8a5fdaf9528e6934dd0f18/vision_datasets/commands/__init__.py -------------------------------------------------------------------------------- /vision_datasets/commands/converter_od_to_ic.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert a detection dataset into classification dataset 3 | """ 4 | 5 | import argparse 6 | import multiprocessing 7 | import os 8 | import pathlib 9 | import shutil 10 | 11 | from vision_datasets.common import CocoDictGeneratorFactory, DatasetHub, DatasetTypes 12 | from vision_datasets.image_object_detection import DetectionAsClassificationByCroppingDataset 13 | 14 | from .utils import add_args_to_locate_dataset, get_or_generate_data_reg_json_and_usages, set_up_cmd_logger, write_to_json_file_utf8 15 | 16 | logger = set_up_cmd_logger(__name__) 17 | 18 | 19 | def create_arg_parser(): 20 | parser = argparse.ArgumentParser(description='Convert detection dataset to classification dataset.') 21 | 
add_args_to_locate_dataset(parser) 22 | parser.add_argument('-o', '--output_folder', type=pathlib.Path, required=True, help='target folder of the converted classification dataset') 23 | parser.add_argument('-zb', '--zoom_ratio_bounds', type=str, required=False, 24 | help='lower and bound of the ratio that box height and width can expand (>1) or shrink (0-1), during cropping, e.g, 0.8/1.2') 25 | parser.add_argument('-sb', '--shift_relative_bounds', type=str, required=False, 26 | help='lower/upper bounds of relative ratio wrt box width and height that a box can shift, during cropping, e.g., "-0.3/0.1"') 27 | parser.add_argument('-np', '--n_copies', type=int, required=False, default=1, help='number of copies per bbox') 28 | parser.add_argument('-s', '--rnd_seed', type=int, required=False, help='random see for box expansion/shrink/shifting.', default=0) 29 | 30 | return parser 31 | 32 | 33 | def process_usage(params): 34 | args, data_reg_json, aug_params, usage = params 35 | 36 | logger.info(f'download dataset manifest for {args.name}...') 37 | dataset_resources = DatasetHub(data_reg_json, args.blob_container, args.local_dir.as_posix()) 38 | dataset = dataset_resources.create_vision_dataset(args.name, usage=usage, coordinates='absolute') 39 | if not dataset: 40 | logger.info(f'Skipping non-existent phase {usage}.') 41 | return 42 | 43 | if dataset.dataset_info.type != DatasetTypes.IMAGE_OBJECT_DETECTION: 44 | raise ValueError(f'Data type must be {DatasetTypes.IMAGE_OBJECT_DETECTION}') 45 | logger.info(f'start conversion for {args.name}...') 46 | ic_dataset = DetectionAsClassificationByCroppingDataset(dataset, aug_params) 47 | ic_manifest = ic_dataset.generate_manifest(dir=str(usage), n_copies=args.n_copies) 48 | 49 | coco_gen = CocoDictGeneratorFactory.create(DatasetTypes.IMAGE_CLASSIFICATION_MULTILABEL) 50 | coco = coco_gen.run(ic_manifest) 51 | write_to_json_file_utf8(coco, args.output_folder / f'{usage}.json') 52 | shutil.move(f'{usage}', 
def main():
    """CLI entry: parse args, build box-augmentation params and convert each usage in parallel."""
    args = create_arg_parser().parse_args()

    # Both bounds options share the same 'low/up' string format.
    aug_params = {}
    for raw_bounds, key in ((args.zoom_ratio_bounds, 'zoom_ratio_bounds'), (args.shift_relative_bounds, 'shift_relative_bounds')):
        if raw_bounds:
            lower, upper = raw_bounds.split('/')
            aug_params[key] = (float(lower), float(upper))

    if aug_params:
        aug_params['rnd_seed'] = args.rnd_seed

    if not os.path.exists(args.output_folder):
        os.makedirs(args.output_folder)

    if args.blob_container and args.local_dir:
        args.local_dir.mkdir(parents=True, exist_ok=True)

    data_reg_json, usages = get_or_generate_data_reg_json_and_usages(args)
    params = [(args, data_reg_json, aug_params, phase) for phase in usages]

    # One worker per usage (train/val/test).
    with multiprocessing.Pool(len(usages)) as pool:
        pool.map(process_usage, params)


if __name__ == '__main__':
    main()
def keep_base_url(url_path: str):
    """Return *url_path* with params, query and fragment stripped (scheme + host + path only)."""
    parts = urlparse(url_path)
    return urlunparse((parts.scheme, parts.netloc, parts.path, None, None, None))


def main():
    """Convert an IC/OD dataset to AML COCO format, writing one JSON file per usage."""
    args = create_arg_parser().parse_args()
    assert args.blob_container, '"blob_container" is required for generating "coco_url"'
    assert args.local_dir is None, 'Accessing data from "local_dir" is not supported for now. Data must be present in blob_container.'

    data_reg_json, usages = get_or_generate_data_reg_json_and_usages(args)
    dataset_hub = DatasetHub(data_reg_json, args.blob_container, args.local_dir)
    dataset_info = dataset_hub.dataset_registry.get_dataset_info(args.name, args.version)

    if not dataset_info:
        logger.error(f'dataset {args.name} does not exist.')
        return

    assert dataset_info.type in [DatasetTypes.IMAGE_CLASSIFICATION_MULTICLASS, DatasetTypes.IMAGE_CLASSIFICATION_MULTILABEL, DatasetTypes.IMAGE_OBJECT_DETECTION]

    coco_gen = CocoDictGeneratorFactory.create(dataset_info.type)
    file_reader = FileReader()
    for usage in usages:
        # BUGFIX: use the requested dataset version instead of the hard-coded version=1,
        # consistent with the get_dataset_info(args.name, args.version) lookup above.
        manifest, _, _ = dataset_hub.create_dataset_manifest(args.name, version=args.version, usage=usage)
        if manifest is None:
            logger.info(f'{usage} not exist. Skipping.')
            continue

        coco_dict = coco_gen.run(manifest)
        for image in tqdm(coco_dict['images'], f'{usage}: Processing images...'):
            image['coco_url'] = keep_base_url(image['file_name'])
            if not image.get('width') or not image.get('height'):
                # Missing dimensions are recovered by actually opening the image.
                with file_reader.open(image['file_name'], 'rb') as f:
                    img = PILImageLoader.load_from_stream(f)
                    image['width'], image['height'] = img.size
            # file_name becomes the blob-container-relative path.
            image['file_name'] = image['coco_url'][len(urlunparse(urlparse(keep_base_url(args.blob_container)))):]

        if dataset_info.type == DatasetTypes.IMAGE_OBJECT_DETECTION:
            # AML COCO expects bboxes normalized to [0, 1] by image width/height.
            image_wh_by_id = {x['id']: (x['width'], x['height']) for x in coco_dict['images']}
            for ann in tqdm(coco_dict['annotations'], f'{usage}: Processing bbox...'):
                w, h = image_wh_by_id[ann['image_id']]
                box = ann['bbox']
                ann['bbox'] = [box[0] / w, box[1] / h, box[2] / w, box[3] / h]

        output_dir = pathlib.Path(args.output_dir)
        output_dir.mkdir(exist_ok=True)
        coco_filepath = output_dir / f'{dataset_info.name}_{usage}.json'
        coco_filepath.write_text(json.dumps(coco_dict, ensure_ascii=False, indent=2), encoding='utf-8')


if __name__ == '__main__':
    main()
4 | """ 5 | 6 | import argparse 7 | import pathlib 8 | from enum import Enum 9 | 10 | from vision_datasets.common import DatasetHub 11 | from vision_datasets.commands.utils import add_args_to_locate_dataset, convert_to_tsv, convert_to_jsonl, enum_type, get_or_generate_data_reg_json_and_usages, set_up_cmd_logger 12 | 13 | logger = set_up_cmd_logger(__name__) 14 | 15 | 16 | class LineFormat(Enum): 17 | JSONL = 'jsonl' 18 | TSV = 'tsv' 19 | 20 | 21 | def logging_prefix(dataset_name, version, format): 22 | return f'Dataset {dataset_name} version {version}, convert to {format}:' 23 | 24 | 25 | def main(): 26 | parser = argparse.ArgumentParser('Convert a dataset to TSV(s) or JONSL(s).') 27 | add_args_to_locate_dataset(parser) 28 | parser.add_argument('--format', '-fm', type=enum_type(LineFormat), default=LineFormat.JSONL, help='Format of output data.', choices=list(LineFormat), required=False) 29 | parser.add_argument('--flatten', '-fl', action='store_true', help="If an image has multiple annotations, one image will be flattend in to multiple entries with image being duplicated.") 30 | parser.add_argument('--output_dir', '-o', type=pathlib.Path, required=False, default=pathlib.Path('./'), help='File(s) will be saved here.') 31 | 32 | args = parser.parse_args() 33 | prefix = logging_prefix(args.name, args.version, args.format) 34 | 35 | data_reg_json, usages = get_or_generate_data_reg_json_and_usages(args) 36 | 37 | hub = DatasetHub(data_reg_json, args.blob_container, args.local_dir.as_posix()) 38 | if not hub.dataset_registry.get_dataset_info(args.name, args.version): 39 | raise RuntimeError(f'{prefix} dataset does not exist.') 40 | 41 | if args.blob_container and args.local_dir: 42 | args.local_dir.mkdir(parents=True, exist_ok=True) 43 | 44 | for usage in usages: 45 | logger.info(f'{prefix} Check dataset with usage: {usage}.') 46 | # if args.local_dir is none, then this check will directly try to access data from azure blob. 
Images must be present in uncompressed folder on azure blob. 47 | manifest, _, _ = hub.create_dataset_manifest(name=args.name, version=args.version, usage=usage) 48 | if manifest is None: 49 | logger.info(f'{prefix} No split for {usage} available.') 50 | else: 51 | if args.format == LineFormat.JSONL: 52 | convert_to_jsonl(manifest, args.output_dir / f"{args.name}.{usage}.jsonl", args.flatten) 53 | else: 54 | convert_to_tsv(manifest, args.output_dir / f"{args.name}.{usage}.tsv") 55 | 56 | 57 | if __name__ == '__main__': 58 | main() 59 | -------------------------------------------------------------------------------- /vision_datasets/commands/download_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Download a dataset from shared storage either in original format or converted to TSV 3 | """ 4 | 5 | import argparse 6 | import pathlib 7 | 8 | from vision_datasets.common import DatasetHub, DatasetRegistry, DatasetTypes 9 | from vision_datasets.commands.utils import add_args_to_locate_dataset_from_name_and_reg_json, convert_to_tsv, get_or_generate_data_reg_json_and_usages, set_up_cmd_logger 10 | 11 | logger = set_up_cmd_logger(__name__) 12 | 13 | TSV_SUPPORTED_TYPES = [DatasetTypes.IMAGE_CAPTION, DatasetTypes.IMAGE_OBJECT_DETECTION, DatasetTypes.IMAGE_CLASSIFICATION_MULTICLASS, DatasetTypes.IMAGE_CLASSIFICATION_MULTILABEL] 14 | 15 | 16 | def list_datasets(registry: DatasetRegistry): 17 | for dataset in registry.list_data_version_and_types(): 18 | logger.info(f"Name: {dataset['name']}, version: {dataset['version']}, type: {dataset['type']}") 19 | 20 | 21 | def main(): 22 | parser = argparse.ArgumentParser('Download dataset from the shared storage') 23 | add_args_to_locate_dataset_from_name_and_reg_json(parser) 24 | 25 | parser.add_argument('--to_tsv', '-t', help='to tsv format or not.', action='store_true') 26 | 27 | args = parser.parse_args() 28 | dataset_reg_json, usages = 
get_or_generate_data_reg_json_and_usages(args) 29 | dataset_hub = DatasetHub(dataset_reg_json, args.blob_container, args.local_dir) 30 | name = args.name 31 | dataset_info = dataset_hub.dataset_registry.get_dataset_info(name) 32 | args.local_dir.mkdir(parents=True, exist_ok=True) 33 | 34 | if args.to_tsv: 35 | if dataset_info.type not in TSV_SUPPORTED_TYPES: 36 | logger.error(f'Unsupported data type for converting to TSV: {dataset_info.type}.') 37 | return 38 | 39 | logger.info(f'downloading {name}...') 40 | for usage in usages: 41 | dataset_manifest, _, _ = dataset_hub.create_dataset_manifest(name, usage=usage) 42 | if dataset_manifest is None: 43 | continue 44 | 45 | logger.info(f'converting {name}, usage {usage} to TSV format...') 46 | convert_to_tsv(dataset_manifest, pathlib.Path(args.local_dir) / f'{name}-{usage}.tsv') 47 | else: 48 | for usage in usages: 49 | dataset_hub.create_vision_dataset(name, usage=usage) 50 | 51 | 52 | if __name__ == '__main__': 53 | main() 54 | -------------------------------------------------------------------------------- /vision_datasets/commands/list_operations_by_data_type.py: -------------------------------------------------------------------------------- 1 | """ 2 | Download a dataset from shared storage either in original format or converted to TSV 3 | """ 4 | 5 | import argparse 6 | 7 | from vision_datasets.common import DatasetTypes, SupportedOperationsByDataType 8 | from vision_datasets.commands.utils import set_up_cmd_logger 9 | 10 | from .utils import enum_type 11 | 12 | logger = set_up_cmd_logger(__name__) 13 | 14 | 15 | def main(): 16 | parser = argparse.ArgumentParser('List supported operations by data type') 17 | 18 | parser.add_argument('--data_type', '-d', help='list supported operations by data type.', type=enum_type(DatasetTypes), required=True, choices=list(DatasetTypes)) 19 | args = parser.parse_args() 20 | print(SupportedOperationsByDataType.list(args.data_type)) 21 | 22 | 23 | if __name__ == '__main__': 24 | 
main() 25 | -------------------------------------------------------------------------------- /vision_datasets/common/__init__.py: -------------------------------------------------------------------------------- 1 | from .constants import AnnotationFormats, BBoxFormat, DatasetTypes, Usages 2 | from .data_manifest import BalancedInstanceWeightsGenerator, CategoryManifest, DatasetFilter, DatasetManifest, GenerateCocoDictBase, MultiImageCocoDictGenerator, ImageDataManifest, ImageFilter, \ 3 | ImageLabelManifest, ImageLabelWithCategoryManifest, ImageNoAnnotationFilter, ManifestMerger, ManifestSampler, MergeStrategy, MultiImageDatasetSingleTaskMerge, DatasetManifestWithMultiImageLabel, \ 4 | MultiImageLabelManifest, Operation, RemoveCategories, RemoveCategoriesConfig, SampleBaseConfig, SampleByFewShotConfig, SampleByNumSamples, SampleByNumSamplesConfig, SampleFewShot, \ 5 | SampleStrategy, SampleStrategyType, SingleTaskMerge, Spawn, SpawnConfig, Split, SplitConfig, SplitWithCategories, WeightsGenerationConfig, CocoManifestWithoutCategoriesAdaptor, \ 6 | CocoManifestWithCategoriesAdaptor, CocoManifestWithMultiImageLabelAdaptor, CocoManifestAdaptorBase, GenerateStandAloneImageListBase 7 | from .dataset_info import BaseDatasetInfo, DatasetInfo, DatasetInfoFactory, KeyValuePairDatasetInfo, MultiTaskDatasetInfo 8 | from .data_reader import DatasetDownloader, FileReader, PILImageLoader 9 | from .dataset import VisionDataset 10 | from .factory import CocoManifestAdaptorFactory, CocoDictGeneratorFactory, ManifestMergeStrategyFactory, DataManifestFactory, SampleStrategyFactory, BalancedInstanceWeightsFactory, SpawnFactory, \ 11 | SplitFactory, StandAloneImageListGeneratorFactory, SupportedOperationsByDataType 12 | from .dataset_management import DatasetHub, DatasetRegistry 13 | from .base64_utils import Base64Utils 14 | 15 | __all__ = [ 16 | 'Usages', 'DatasetTypes', 'AnnotationFormats', 'BBoxFormat', 'MultiImageDatasetSingleTaskMerge', 'DatasetManifestWithMultiImageLabel', 
class Base64Utils:
    """Helpers for converting between base64 strings, local files and PIL images."""

    @staticmethod
    def b64_str_to_pil(img_b64_str: str):
        """Decode a base64-encoded image string into a PIL image (EXIF-orientation corrected)."""
        assert img_b64_str

        raw_bytes = base64.b64decode(img_b64_str)
        return PILImageLoader.load_from_stream(io.BytesIO(raw_bytes))

    @staticmethod
    def file_to_b64_str(filepath: pathlib.Path, file_reader=None):
        """Read a file through the given reader (or a default FileReader) and return its content base64-encoded.

        Args:
            filepath: path of the file to read; FileReader also accepts url and zip@entry styles.
            file_reader: optional reader with an open(path, mode) method; a new FileReader is used when falsy.
        """
        assert filepath

        reader = file_reader or FileReader()
        with reader.open(filepath.as_posix(), "rb") as stream:
            payload = stream.read()

        return base64.b64encode(payload).decode('utf-8')

    @staticmethod
    def b64_str_to_file(b64_str: str, file_name: Union[pathlib.Path, str]):
        """Decode a base64 string and write the raw bytes to file_name."""
        assert b64_str
        assert file_name

        decoded = base64.b64decode(b64_str)
        with open(file_name, 'wb') as sink:
            sink.write(decoded)
GenerateStandAloneImageListBase, \ 3 | ImageFilter, ImageNoAnnotationFilter, ManifestMerger, ManifestSampler, MergeStrategy, Operation, RemoveCategories, RemoveCategoriesConfig, \ 4 | SampleBaseConfig, SampleByFewShotConfig, SampleByNumSamples, SampleByNumSamplesConfig, SampleFewShot, SampleStrategy, SampleStrategyType, SingleTaskMerge, \ 5 | Spawn, SpawnConfig, Split, SplitConfig, SplitWithCategories, WeightsGenerationConfig 6 | from .coco_manifest_adaptor import CocoManifestWithCategoriesAdaptor, CocoManifestWithoutCategoriesAdaptor, CocoManifestAdaptorBase, CocoManifestWithMultiImageLabelAdaptor 7 | 8 | __all__ = ["ImageLabelManifest", "ImageLabelWithCategoryManifest", "MultiImageLabelManifest", "ImageDataManifest", "CategoryManifest", "DatasetManifest", "DatasetManifestWithMultiImageLabel", 9 | "BalancedInstanceWeightsGenerator", "WeightsGenerationConfig", "DatasetFilter", "ImageFilter", "ImageNoAnnotationFilter", "GenerateCocoDictBase", "MultiImageCocoDictGenerator", 10 | "GenerateStandAloneImageListBase", "ManifestMerger", "MergeStrategy", "SingleTaskMerge", "MultiImageDatasetSingleTaskMerge", "Operation", 11 | "RemoveCategories", 12 | "RemoveCategoriesConfig", "ManifestSampler", "SampleBaseConfig", "SampleByFewShotConfig", "SampleByNumSamples", "SampleByNumSamplesConfig", "SampleFewShot", "SampleStrategy", 13 | "SampleStrategyType", "Spawn", "SpawnConfig", "Split", "SplitConfig", "SplitWithCategories", 14 | "CocoManifestWithCategoriesAdaptor", "CocoManifestWithoutCategoriesAdaptor", "CocoManifestAdaptorBase", "CocoManifestWithMultiImageLabelAdaptor"] 15 | -------------------------------------------------------------------------------- /vision_datasets/common/data_manifest/operations/__init__.py: -------------------------------------------------------------------------------- 1 | from .balanced_instance_weights_generator import BalancedInstanceWeightsGenerator, WeightsGenerationConfig 2 | from .filter import DatasetFilter, ImageFilter, ImageNoAnnotationFilter 
3 | from .generate_coco import GenerateCocoDictBase, MultiImageCocoDictGenerator 4 | from .generate_stand_alone_image_list_base import GenerateStandAloneImageListBase 5 | from .merge import MultiImageDatasetSingleTaskMerge, ManifestMerger, MergeStrategy, SingleTaskMerge 6 | from .operation import Operation 7 | from .remove_categories import RemoveCategories, RemoveCategoriesConfig 8 | from .sample import ManifestSampler, SampleBaseConfig, SampleByFewShotConfig, SampleByNumSamples, SampleByNumSamplesConfig, SampleFewShot, SampleStrategy, SampleStrategyType 9 | from .spawn import Spawn, SpawnConfig 10 | from .split import Split, SplitConfig, SplitWithCategories 11 | 12 | __all__ = ['Operation', 13 | 'GenerateCocoDictBase', 'MultiImageCocoDictGenerator', 14 | 'GenerateStandAloneImageListBase', 15 | 'MultiImageDatasetSingleTaskMerge', 'MergeStrategy', 'ManifestMerger', 'SingleTaskMerge', 16 | 'ManifestSampler', 'SampleBaseConfig', 'SampleByFewShotConfig', 'SampleByNumSamplesConfig', 'SampleStrategy', 'SampleStrategyType', 'SampleByNumSamples', 'SampleFewShot', 17 | 'Spawn', 'SpawnConfig', 18 | 'Split', 'SplitWithCategories', 'SplitConfig', 19 | 'ImageFilter', 'DatasetFilter', 'ImageNoAnnotationFilter', 20 | 'BalancedInstanceWeightsGenerator', 'WeightsGenerationConfig', 21 | 'RemoveCategories', 'RemoveCategoriesConfig'] 22 | -------------------------------------------------------------------------------- /vision_datasets/common/data_manifest/operations/balanced_instance_weights_generator.py: -------------------------------------------------------------------------------- 1 | """Generate instance weights from DatasetManifest, which can be used for balancing the dataset by sampling instances based on the weights. 
"""Generate instance weights from DatasetManifest, which can be used for balancing the dataset by sampling instances based on the weights.
Only works for classification, detection, multitask."""

import logging
import typing
from collections import Counter
from dataclasses import dataclass

import numpy

from ..data_manifest import DatasetManifest, ImageLabelWithCategoryManifest
from .operation import Operation

logger = logging.getLogger(__name__)


@dataclass
class WeightsGenerationConfig:
    soft: bool = True  # less aggressive in making the dataset balanced
    weight_upper: float = 5.0
    weight_lower: float = 0.2


class BalancedInstanceWeightsGenerator(Operation):
    """
    Generate instance weights, with which sampling can achieve a balanced dataset across different categories.
    """
    # Synthetic class index assigned to images that carry no labels at all.
    _NEG_CLASS_INDEX = -1

    def __init__(self, config: WeightsGenerationConfig) -> None:
        super().__init__()
        self.config = config

    def run(self, *args: DatasetManifest):
        """Compute one weight per image; images of rarer categories get larger weights, clamped to the configured range."""
        data_manifest = args[0]
        if data_manifest is None:
            raise ValueError('data manifest is None.')

        logger.info("Generating instance weights for dataset balancing.")
        tags_per_image = [self._process_labels(image.labels) for image in data_manifest.images]

        tag_counts = Counter()
        for tags in tags_per_image:
            tag_counts.update(tags)

        # A category's multiplier is (mean per-category tag count) / (its own count),
        # optionally softened with sqrt to balance less aggressively.
        mean_tag_count = numpy.mean(list(tag_counts.values()))
        multipliers = {tag: mean_tag_count / count for tag, count in tag_counts.items()}
        if self.config.soft:
            multipliers = {tag: numpy.sqrt(value) for tag, value in multipliers.items()}

        upper, lower = self.config.weight_upper, self.config.weight_lower
        multipliers = {tag: BalancedInstanceWeightsGenerator._scope_multiplier(value, upper, lower) for tag, value in multipliers.items()}

        image_weights = [BalancedInstanceWeightsGenerator._get_instance_multiplier(tags, multipliers, upper, lower) for tags in tags_per_image]

        logger.info(f'instance weights: max {max(image_weights)}, min {min(image_weights)}, len {len(image_weights)}')

        return image_weights

    def _process_labels(self, labels: typing.List[ImageLabelWithCategoryManifest]):
        # Unlabeled images map to the synthetic negative class so they participate in balancing too.
        return [label.category_id for label in labels] or [BalancedInstanceWeightsGenerator._NEG_CLASS_INDEX]

    @staticmethod
    def _get_instance_multiplier(tags, class_wise_multipliers, weight_upper, weight_lower):
        # An image's weight is the product of its categories' multipliers, clamped like the per-class ones.
        combined = numpy.prod([class_wise_multipliers[tag] for tag in tags])

        return BalancedInstanceWeightsGenerator._scope_multiplier(combined, weight_upper, weight_lower)

    @staticmethod
    def _scope_multiplier(value, weight_upper, weight_lower):
        # Clamp value into [weight_lower, weight_upper].
        return min(max(value, weight_lower), weight_upper)
import abc

from ..data_manifest import DatasetManifest, DatasetManifestWithMultiImageLabel
from .operation import Operation


class GenerateCocoDictBase(Operation):
    """
    Base class for generating a COCO dictionary from DatasetManifest that can be serialized
    """

    def _generate_annotations(self, manifest: DatasetManifest):
        """Flatten per-image labels into COCO annotation dicts; ids and image_ids are 1-based."""
        annotations = []
        for img_id, img in enumerate(manifest.images):
            for ann in img.labels:
                coco_ann = {
                    'id': len(annotations) + 1,
                    'image_id': img_id + 1,
                }

                # Subclasses fill in the data-type-specific fields (e.g. bbox, category_id).
                self.process_labels(coco_ann, ann)
                annotations.append(coco_ann)

        return annotations

    def _generate_images(self, manifest):
        """Generate COCO image dicts with 1-based ids from the manifest's images."""
        images = [{'id': i + 1, 'file_name': x.img_path, 'width': x.width, 'height': x.height} for i, x in enumerate(manifest.images)]
        return images

    def run(self, *args):
        """Build the serializable COCO dict ('images', 'categories', 'annotations') from one manifest."""
        if len(args) != 1:
            raise ValueError

        manifest = args[0]
        result = {
            "images": self._generate_images(manifest),
            "categories": self.generate_categories_or_none(manifest),
            "annotations": self._generate_annotations(manifest)
        }

        # Strip None-valued entries (e.g. 'categories' for category-less data types) so they are not serialized.
        GenerateCocoDictBase._filter_none(result)
        return result

    @abc.abstractmethod
    def process_labels(self, coco_ann, label):
        """Populate coco_ann with the data-type-specific fields of label."""
        pass

    def generate_categories_or_none(self, manifest):
        """Return COCO category dicts with 1-based ids, or None when the manifest has no categories."""
        if manifest.categories:
            # Bug fix: the COCO field is 'supercategory'; it was previously misspelled 'supercateogry',
            # producing non-standard COCO files.
            return [{'id': i + 1, 'name': x.name, 'supercategory': x.super_category} for i, x in enumerate(manifest.categories)]

        return None

    @staticmethod
    def _filter_none(dict_val: dict):
        """Recursively remove keys with None values from dict_val (descending into nested dicts and lists of dicts)."""
        to_del = []
        for key in dict_val.keys():
            if dict_val[key] is None:
                to_del.append(key)
            elif isinstance(dict_val[key], dict):
                GenerateCocoDictBase._filter_none(dict_val[key])
            elif isinstance(dict_val[key], list):
                for x in dict_val[key]:
                    if isinstance(x, dict):
                        GenerateCocoDictBase._filter_none(x)

        for key in to_del:
            del dict_val[key]


class MultiImageCocoDictGenerator(GenerateCocoDictBase):
    """
    Base class for generating a COCO dictionary from DatasetManifestWithMultiImageLabel that can be serialized
    """

    def _generate_annotations(self, manifest: DatasetManifestWithMultiImageLabel):
        """Generate annotation dicts carrying 'image_ids' (plural), one per multi-image annotation."""
        annotations = []
        for id, ann in enumerate(manifest.annotations, 1):
            coco_ann = {
                'id': id,
                # Manifest image ids are 0-based; COCO ids are 1-based.
                'image_ids': [img_id + 1 for img_id in ann.img_ids],
            }

            self.process_labels(coco_ann, ann)
            annotations.append(coco_ann)

        return annotations
14 | """ 15 | 16 | def __init__(self, flatten: bool) -> None: 17 | super().__init__() 18 | self._flatten = flatten 19 | 20 | def run(self, *args) -> typing.Generator: 21 | if len(args) != 1: 22 | raise ValueError 23 | 24 | manifest = args[0] 25 | file_reader = FileReader() 26 | if self._flatten: 27 | for i, image in enumerate(manifest.images): 28 | b64_image = Base64Utils.file_to_b64_str(pathlib.Path(image.img_path), file_reader=file_reader) 29 | for label in image.labels: 30 | img = { 31 | 'image_id': i + 1, 32 | } 33 | label = self._generate_label(label, image, manifest) 34 | if isinstance(label, dict): 35 | img.update(label) 36 | else: 37 | img['label'] = label 38 | 39 | img['image'] = b64_image 40 | yield img 41 | else: 42 | for i, x in enumerate(manifest.images): 43 | yield { 44 | 'id': i + 1, 45 | 'labels': list(self._generate_labels(x, manifest)), 46 | 'image': Base64Utils.file_to_b64_str(pathlib.Path(x.img_path), file_reader=file_reader), 47 | } 48 | 49 | def _generate_labels(self, image: ImageDataManifest, manifest: DatasetManifest) -> typing.Generator: 50 | for label in image.labels: 51 | yield self._generate_label(label, image, manifest) 52 | 53 | @abstractmethod 54 | def _generate_label(self, label: ImageLabelManifest, image: ImageDataManifest, manifest: DatasetManifest) -> typing.Dict: 55 | raise NotImplementedError 56 | -------------------------------------------------------------------------------- /vision_datasets/common/data_manifest/operations/merge.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import copy 3 | import logging 4 | import typing 5 | 6 | from ....common.utils import deep_merge 7 | from ..data_manifest import CategoryManifest, DatasetManifest, ImageLabelWithCategoryManifest, DatasetManifestWithMultiImageLabel 8 | from .operation import Operation 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class MergeStrategy(abc.ABC): 14 | def __init__(self) -> None: 15 | pass 
    @abc.abstractmethod
    def merge(self, *args: DatasetManifest):
        pass

    def check(self, *args: DatasetManifest):
        """Validate the inputs common to all strategies: at least one manifest, none of them None."""
        if len(args) < 1:
            raise ValueError('less than one manifest provided.')
        if any([arg is None for arg in args]):
            raise ValueError('"None" manifest found')


class ManifestMerger(Operation):
    def __init__(self, strategy: MergeStrategy) -> None:
        super().__init__()
        self._strategy = strategy

    def run(self, *args: DatasetManifest):
        """
        Merge multiple data manifests of the same data type into one, with the assumptions that the images from different manifests are independent

        Args:
            args: manifests to be merged
        """
        self._strategy.check(*args)
        return self._strategy.merge(*args)


class SingleTaskMerge(MergeStrategy):
    """
    Merge for single task data type
    """

    def merge(self, *args: DatasetManifest):
        """Concatenate the images of all manifests, unifying categories by name and re-indexing image ids/category ids."""
        data_type = args[0].data_type
        images = []

        # Category-less data types (e.g. captioning) keep categories as None.
        categories, category_name_to_idx = self._combine_categories(args) if bool(args[0].categories) else (None, None)

        for manifest in args:
            for image in manifest.images:
                new_image = copy.deepcopy(image)
                # Image ids are re-assigned sequentially across all merged manifests.
                new_image.id = len(images)
                if categories:
                    for label in new_image.labels:
                        label: ImageLabelWithCategoryManifest = label
                        # Re-point the label at the merged category index, matched by category name.
                        label.category_id = category_name_to_idx[manifest.categories[label.category_id].name]
                images.append(new_image)

        additional_info = deep_merge([x.additional_info for x in args])
        return DatasetManifest(images, categories, copy.deepcopy(data_type), additional_info)

    def check(self, *args: typing.Union[DatasetManifest, DatasetManifestWithMultiImageLabel]):
        super().check(*args)

        if any([x.is_multitask for x in args]):
            raise ValueError('All manifests must be of the same data type and single task.')
        if any([x.data_type != args[0].data_type for x in args]):
            raise ValueError('All manifests must be of the same data type.')

    def _combine_categories(self, manifests: DatasetManifest):
        """Union the categories of all manifests by name; returns (merged categories, name -> new index map)."""
        category_name_to_idx = {}
        for manifest in manifests:
            for category in manifest.categories:
                if category.name not in category_name_to_idx:
                    category_name_to_idx[category.name] = len(category_name_to_idx)
        categories = [CategoryManifest(i, x) for i, x in enumerate(category_name_to_idx.keys())]

        return categories, category_name_to_idx


class MultiImageDatasetSingleTaskMerge(MergeStrategy):
    """
    Merge for single task data type with DatasetManifestWithMultiImageLabel.
    """

    def merge(self, *args: DatasetManifestWithMultiImageLabel):
        """Concatenate images and annotations of all manifests, remapping annotation img_ids to the new image ids."""
        data_type = args[0].data_type
        images = []
        annotations = []
        for manifest in args:
            old_to_new_img_ids = {}
            for image in manifest.images:
                new_image = copy.deepcopy(image)
                new_image.id = len(images)
                old_to_new_img_ids[image.id] = new_image.id
                images.append(new_image)

            for annotation in manifest.annotations:
                new_annotation = copy.deepcopy(annotation)
                new_annotation.id = len(annotations)
                # NOTE(review): annotation.img_ids are used here to index manifest.images positionally,
                # then translated through the id map — this assumes image.id equals the image's index
                # within manifest.images; confirm against the manifest construction.
                new_annotation.img_ids = [old_to_new_img_ids[manifest.images[x].id] for x in annotation.img_ids]
                annotations.append(new_annotation)

        additional_info = deep_merge([x.additional_info for x in args])
        return DatasetManifestWithMultiImageLabel(images, annotations, copy.deepcopy(data_type), additional_info)
__init__(self) -> None: 12 | pass 13 | 14 | def run(*args: typing.Union[DatasetManifest, DatasetManifestWithMultiImageLabel]): 15 | pass 16 | -------------------------------------------------------------------------------- /vision_datasets/common/data_manifest/operations/remove_categories.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import typing 3 | from dataclasses import dataclass 4 | 5 | from ..data_manifest import DatasetManifest 6 | from .operation import Operation 7 | 8 | 9 | @dataclass 10 | class RemoveCategoriesConfig: 11 | category_names: typing.List[str] 12 | 13 | 14 | class RemoveCategories(Operation): 15 | """ 16 | Remove categories. 17 | """ 18 | def __init__(self, config: RemoveCategoriesConfig) -> None: 19 | super().__init__() 20 | if not config: 21 | raise ValueError 22 | 23 | self.config = config 24 | 25 | def run(self, *args: DatasetManifest): 26 | if len(args) != 1: 27 | raise ValueError 28 | 29 | manifest = args[0] 30 | if not manifest.categories: 31 | raise ValueError 32 | 33 | result = copy.deepcopy(manifest) 34 | if not self.config.category_names: 35 | return result 36 | 37 | c_name_to_idx = {c.name: i for i, c in enumerate(manifest.categories)} 38 | c_indices_to_remove = sorted([c_name_to_idx[c] for c in self.config.category_names]) 39 | old_c_idx_to_new_idx = {} 40 | j = 0 41 | for i in range(len(manifest.categories)): 42 | if j < len(c_indices_to_remove) and i == c_indices_to_remove[j]: 43 | j += 1 44 | else: 45 | old_c_idx_to_new_idx[i] = i - j 46 | 47 | def alter_cid(label, new_category_id): 48 | label.category_id = new_category_id 49 | return label 50 | 51 | for image in result.images: 52 | image.labels = [alter_cid(label, old_c_idx_to_new_idx[label.category_id]) for label in image.labels if label.category_id in old_c_idx_to_new_idx] 53 | 54 | def alter_category(category, new_idx): 55 | category.id = new_idx 56 | return category 57 | 58 | result.categories = [alter_category(c, 
import copy
import logging
import typing
from dataclasses import dataclass

from ..data_manifest import DatasetManifest
from .merge import SingleTaskMerge
from .operation import Operation
from .sample import SampleByNumSamples, SampleByNumSamplesConfig

logger = logging.getLogger(__name__)


@dataclass
class SpawnConfig:
    random_seed: int
    target_n_samples: int
    instance_weights: typing.List[float] = None


class Spawn(Operation):
    """
    Spawn the dataset (oversample).

    This will be consolidated with sample operation.
    """

    def __init__(self, config: SpawnConfig) -> None:
        super().__init__()
        self.config = config

    def run(self, *args: DatasetManifest):
        """Spawn the manifest to a target size.

        A copy of the original data is always kept so that every class still has samples after
        spawning, then the sampled extras are merged on top. Without instance_weights, spawning
        follows the class distribution; with instance_weights, each instance is duplicated in
        proportion to its weight, in which case the spawned size is not guaranteed to be exactly
        target_n_samples.

        Returns:
            Spawned dataset (DatasetManifest)

        Raises:
            ValueError: if not exactly one manifest is given, or instance_weights has a wrong
                length or contains negative values.
        """
        if len(args) != 1:
            raise ValueError

        manifest = args[0]
        config = self.config
        if config.instance_weights:
            if len(config.instance_weights) != len(manifest) or any(w < 0 for w in config.instance_weights):
                raise ValueError

            weight_total = sum(config.instance_weights)
            # Each image's share of target_n_samples by weight, minus the original copy that is kept anyway.
            copies_per_image = [max(0, round(w / weight_total * config.target_n_samples - 1)) for w in config.instance_weights]
            duplicated_images = []
            for image, n_copies in zip(manifest.images, copies_per_image):
                duplicated_images.extend(copy.deepcopy(image) for _ in range(n_copies))

            sampled_manifest = DatasetManifest(duplicated_images, manifest.categories, manifest.data_type, manifest.additional_info)
        else:
            sample_config = SampleByNumSamplesConfig(config.random_seed, True, config.target_n_samples - len(manifest))
            sampled_manifest = SampleByNumSamples(sample_config).sample(manifest)

        # Merge the untouched original with the sampled extras so each class keeps its samples.
        return SingleTaskMerge().merge(manifest, sampled_manifest)
manifest.additional_info} 25 | 26 | return DatasetManifest([v for v in images_by_id.values()], categories_by_task, dataset_types_by_task, addtional_info=additional_info_by_task) 27 | -------------------------------------------------------------------------------- /vision_datasets/common/data_reader/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset_downloader import DatasetDownloader, DownloadedDatasetsResources 2 | from .file_reader import FileReader 3 | from .image_loader import PILImageLoader 4 | 5 | __all__ = ['DatasetDownloader', 'DownloadedDatasetsResources', 'FileReader', 'PILImageLoader'] 6 | -------------------------------------------------------------------------------- /vision_datasets/common/data_reader/file_reader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import zipfile 4 | from typing import Union 5 | from urllib.parse import quote 6 | from urllib.request import urlopen 7 | 8 | from ..utils import can_be_url 9 | 10 | 11 | class MultiProcessZipFile: 12 | """ZipFile which is readable from multi processes""" 13 | 14 | def __init__(self, filename): 15 | self.filename = filename 16 | self.zipfiles = {} 17 | 18 | def open(self, file): 19 | if os.getpid() not in self.zipfiles: 20 | self.zipfiles[os.getpid()] = zipfile.ZipFile(self.filename) 21 | return self.zipfiles[os.getpid()].open(file) 22 | 23 | def close(self): 24 | for z in self.zipfiles.values(): 25 | z.close() 26 | self.zipfiles = {} 27 | 28 | def __getstate__(self): 29 | return {'filename': self.filename} 30 | 31 | def __setstate__(self, state): 32 | self.filename = state['filename'] 33 | self.zipfiles = {} 34 | 35 | 36 | class FileReader: 37 | """Reader to support files of different path styles. 38 | 1. @ 39 | 2. url 40 | 3. 
class FileReader:
    """Unified reader supporting three path styles:

    1. zip-internal path: "<zip_path>@<entry>", e.g. "images.zip@1.jpg"
    2. URL
    3. regular local file path
    """

    def __init__(self):
        # Cache of MultiProcessZipFile objects keyed by zip file path.
        self.zip_files = {}

    def open(self, name: Union[pathlib.Path, str], mode='r', encoding=None):
        """Open ``name`` and return a readable file-like object."""
        path_str = str(name)

        # URL: fetch over the network (non-ASCII characters percent-encoded).
        if can_be_url(path_str):
            return urlopen(self._encode_non_ascii(path_str))

        # Zip-internal path: "<zip>@<entry>"; zip handles are cached per zip path.
        if '@' in path_str:
            zip_path, entry = path_str.split('@', 1)
            reader = self.zip_files.setdefault(zip_path, MultiProcessZipFile(zip_path))
            return reader.open(entry)

        # Plain local file.
        return open(path_str, mode, encoding=encoding)

    def close(self):
        """Close all cached zip handles and reset the cache."""
        for reader in self.zip_files.values():
            reader.close()
        self.zip_files = {}

    @staticmethod
    def _encode_non_ascii(s):
        # Percent-encode only the non-ASCII characters, leaving ASCII untouched.
        return ''.join(c if ord(c) < 128 else quote(c) for c in s)
class BaseDataset(ABC):
    """Abstract base for vision datasets: indexable, sliceable, closable."""

    def __init__(self, dataset_info: BaseDatasetInfo):
        # Descriptive metadata about this dataset.
        self.dataset_info = dataset_info

    def __getitem__(self, idx):
        """Iterate through the dataset.

        Args:
            idx: a single int index or a slice.

        Returns:
            The requested sample for an int index, or a list of samples
            for a slice.

        Raises:
            IndexError: if an int index is >= len(self).
        """
        if isinstance(idx, int):
            if idx >= len(self):
                raise IndexError
            return self._get_single_item(idx)

        # Slice: tolerate omitted start/stop/step (e.g. ds[:5], ds[::2]);
        # stop is clamped to the dataset size.
        start = 0 if idx.start is None else idx.start
        step = 1 if idx.step is None else idx.step
        stop = len(self) if idx.stop is None else min(len(self), idx.stop)
        return [self[i] for i in range(start, stop, step)]

    @property
    @abstractmethod
    def categories(self):
        """Category/class information for this dataset."""
        pass

    @abstractmethod
    def __len__(self):
        """Number of samples in the dataset."""
        pass

    @abstractmethod
    def _get_single_item(self, index):
        """Fetch the single sample at ``index``."""
        pass

    @abstractmethod
    def close(self):
        """Release resources held by the dataset."""
        pass
class DatasetInfo(BaseDatasetInfo):
    """Dataset info for a single (non-multitask) dataset.

    Collects per-usage index files and, for IRIS-format datasets, the
    labelmap and image metadata paths.
    """

    def __init__(self, dataset_info_dict):
        # Multitask dicts belong to MultiTaskDatasetInfo, not here.
        if _data_type_to_enum(dataset_info_dict.get('type')) == DatasetTypes.MULTITASK:
            raise ValueError
        super(DatasetInfo, self).__init__(dataset_info_dict)

        self.index_files = dict()
        self.files_for_local_usage = dict()
        for usage in Usages:
            usage_key = usage.name.lower()
            if usage_key not in dataset_info_dict:
                continue
            usage_cfg = dataset_info_dict[usage_key]
            self.index_files[usage] = usage_cfg['index_path']
            self.files_for_local_usage[usage] = usage_cfg.get('files_for_local_usage', [])

        if self.data_format == AnnotationFormats.IRIS:
            self.labelmap = dataset_info_dict.get('labelmap')
            self.image_metadata_path = dataset_info_dict.get('image_metadata_path')
        else:
            # Only IRIS datasets carry these; keep attributes present regardless.
            self.labelmap = None
            self.image_metadata_path = None
class DatasetRegistry:
    """
    A central registry of all available datasets
    """

    def __init__(self, datasets_json: Union[str, list]):
        # Accept either one JSON string or a list of JSON strings, each
        # holding a list of dataset info dicts.
        if isinstance(datasets_json, list):
            raw_infos = [d for dj in datasets_json for d in json.loads(dj)]
        else:
            raw_infos = json.loads(datasets_json)
        self.datasets = [DatasetInfoFactory.create(d) for d in raw_infos]

    def get_dataset_info(self, dataset_name, dataset_version=None):
        """Return a deep copy of the matching dataset info, or None.

        Without an explicit version, the highest available version wins.
        """
        candidates = [d for d in self.datasets if d.name == dataset_name and (not dataset_version or d.version == dataset_version)]
        if not candidates:
            return None

        # Stable sort: on duplicate versions the last-listed entry wins.
        return copy.deepcopy(sorted(candidates, key=lambda d: d.version)[-1])

    def list_data_version_and_types(self):
        """Summarize every registered dataset as a small dict."""
        summaries = []
        for d in self.datasets:
            summaries.append({'name': d.name, 'version': d.version, 'type': d.type, 'description': d.description})
        return summaries

    @staticmethod
    def _get_default_dataset_json(json_file_name):
        import sys
        if sys.version_info.minor >= 7:
            # Python 3.7+: read packaged resource via importlib.
            import importlib.resources as pkg_resources

            from vision_datasets.common import resources
            return pkg_resources.read_text(resources, json_file_name)

        # Older interpreters: fall back to pkgutil.
        import pkgutil
        resource_path = '/'.join(('resources', json_file_name))
        return pkgutil.get_data('vision_datasets', resource_path)
class CocoManifestAdaptorFactory:
    """Registry mapping a data type string to its coco manifest adaptor class."""

    _mapping = {}

    @classmethod
    def register(cls, data_type: str):
        """Class decorator: register the decorated adaptor under ``data_type``."""
        def decorator(adaptor_class):
            cls._mapping[data_type] = adaptor_class
            return adaptor_class

        return decorator

    @classmethod
    def create(cls, data_type: str, *args, **kwargs) -> CocoManifestAdaptorBase:
        """Instantiate the adaptor registered for ``data_type``."""
        adaptor_class = cls._mapping[data_type]
        return adaptor_class(*args, **kwargs)
class BalancedInstanceWeightsFactory:
    """Registry of balanced-instance-weights generators keyed by data type."""

    _mapping = {}

    @classmethod
    def direct_register(cls, klass, data_type: DatasetTypes):
        """Register ``klass`` for ``data_type`` and index it in the shared operations table."""
        cls._mapping[data_type] = klass
        SupportedOperationsByDataType.add(data_type, klass)
        return klass

    @classmethod
    def register(cls, data_type: DatasetTypes):
        """Decorator flavor of :meth:`direct_register`."""
        def decorator(klass):
            return BalancedInstanceWeightsFactory.direct_register(klass, data_type)

        return decorator

    @classmethod
    def create(cls, data_type: DatasetTypes, config: WeightsGenerationConfig, *args, **kwargs) -> BalancedInstanceWeightsGenerator:
        """Instantiate the generator registered for ``data_type``."""
        generator_class = cls._mapping[data_type]
        return generator_class(config, *args, **kwargs)

    @classmethod
    def list_data_types(cls) -> typing.Iterable[DatasetTypes]:
        """All data types with a registered generator."""
        return list(cls._mapping)
class SampleStrategyFactory:
    """Registry of sampling strategies keyed by (data type, strategy type)."""

    _mapping = {}

    @classmethod
    def direct_register(cls, klass, data_type: DatasetTypes, strategy_name: SampleStrategyType):
        """Register ``klass`` for (data_type, strategy_name).

        Indexing in SupportedOperationsByDataType happens here — matching the
        sibling operation factories — so classes registered via direct_register
        (not only via the decorator) are indexed too.
        """
        SupportedOperationsByDataType.add(data_type, klass)
        cls._mapping[(data_type, strategy_name)] = klass
        return klass

    @classmethod
    def register(cls, data_type: DatasetTypes, strategy_name: SampleStrategyType):
        """Decorator flavor of :meth:`direct_register`."""
        def decorator(klass):
            return SampleStrategyFactory.direct_register(klass, data_type, strategy_name)
        return decorator

    @classmethod
    def create(cls, data_type: DatasetTypes, strategy_type: SampleStrategyType, config: SampleBaseConfig, *args, **kwargs) -> SampleStrategy:
        """Instantiate the strategy registered for (data_type, strategy_type)."""
        return cls._mapping[(data_type, strategy_type)](config, *args, **kwargs)

    @classmethod
    def list_data_types(cls, strategy_type: SampleStrategyType) -> typing.Iterable[DatasetTypes]:
        """Yield every data type that has a strategy of ``strategy_type``."""
        for key in cls._mapping:
            if key[1] == strategy_type:
                yield key[0]
class SpawnFactory:
    """Registry of Spawn operation classes keyed by data type."""

    _mapping = {}

    @classmethod
    def direct_register(cls, klass, data_type: DatasetTypes):
        """Register ``klass`` for ``data_type`` and index it in the shared operations table."""
        SupportedOperationsByDataType.add(data_type, klass)
        cls._mapping[data_type] = klass
        return klass

    @classmethod
    def register(cls, data_type: DatasetTypes):
        """Decorator flavor of :meth:`direct_register`."""
        def decorator(klass):
            return SpawnFactory.direct_register(klass, data_type)

        return decorator

    @classmethod
    def create(cls, data_type: DatasetTypes, config: SpawnConfig, *args, **kwargs) -> Spawn:
        """Instantiate the spawn class registered for ``data_type``."""
        spawn_class = cls._mapping[data_type]
        return spawn_class(config, *args, **kwargs)

    @classmethod
    def list_data_types(cls) -> typing.Iterable[DatasetTypes]:
        """All data types with a registered spawn class."""
        return list(cls._mapping)
class StandAloneImageListGeneratorFactory:
    """Registry of stand-alone image list generators keyed by data type."""

    _mapping = {}

    @classmethod
    def direct_register(cls, klass, data_type: DatasetTypes):
        """Register ``klass`` for ``data_type`` and index it in the shared operations table."""
        cls._mapping[data_type] = klass
        SupportedOperationsByDataType.add(data_type, klass)
        return klass

    @classmethod
    def register(cls, data_type: DatasetTypes):
        """Decorator flavor of :meth:`direct_register`."""
        def decorator(klass):
            return StandAloneImageListGeneratorFactory.direct_register(klass, data_type)

        return decorator

    @classmethod
    def create(cls, data_type: DatasetTypes, *args, **kwargs) -> GenerateCocoDictBase:
        """Instantiate the generator registered for ``data_type``."""
        generator_class = cls._mapping[data_type]
        return generator_class(*args, **kwargs)

    @classmethod
    def list_data_types(cls) -> typing.Iterable[DatasetTypes]:
        """All data types with a registered generator."""
        return list(cls._mapping)
def deep_merge(*dicts):
    """Recursively merge dicts left-to-right; later values win.

    Non-dict arguments are ignored. When the same key maps to dicts on both
    sides, the values are merged recursively; otherwise the later value
    replaces the earlier one.
    """
    result = {}
    for candidate in dicts:
        if not isinstance(candidate, dict):
            continue

        for key, value in candidate.items():
            existing = result.get(key)
            if isinstance(existing, dict) and isinstance(value, dict):
                result[key] = deep_merge(existing, value)
            else:
                result[key] = value

    return result


def can_be_url(candidate: Union[str, pathlib.Path]):
    """
    necessary conditions for candidate to be a url (not sufficient)
    Args:
        candidate (str):

    Returns:
        whether it could be a url or not

    """
    if not candidate or not isinstance(candidate, str):
        return False

    try:
        parsed = urlparse.urlparse(candidate)
    except ValueError:
        return False

    # NOTE: truthy/falsy result, not a strict bool — callers rely on truthiness.
    return parsed.scheme and parsed.netloc


def unix_path(path: Union[pathlib.Path, str]) -> Union[pathlib.Path, str]:
    """Normalize a path to forward slashes; Path inputs become posix strings."""
    if path is None:
        raise ValueError

    return path.as_posix() if isinstance(path, pathlib.Path) else path.replace('\\', '/')
def _construct_full_url_generator(container_url: str):
    """Build a func that appends a file path to the path component of container_url.

    Without a container_url, falls back to plain path normalization.
    """
    if not container_url:
        return unix_path

    def add_path_to_url(url, path_or_dir):
        if not url:
            raise ValueError

        if not path_or_dir:
            return url

        pieces = urlparse.urlparse(url)
        joined_path = unix_path(os.path.join(pieces[2], path_or_dir))
        return urlparse.urlunparse((pieces[0], pieces[1], joined_path, pieces[3], pieces[4], pieces[5]))

    def func(file_path):
        # cannot read from zip file with path targeting a url
        return add_path_to_url(container_url, file_path.replace('.zip@', '/'))

    return func
class ImageCaptionLabelManifest(ImageLabelManifest):
    """Caption label: ``label_data`` holds the caption as a plain str."""

    @property
    def caption(self):
        """The caption text."""
        return self.label_data

    def _read_label_data(self):
        # Captions are always stored inline; loading from a file is unsupported.
        raise NotImplementedError

    def _check_label(self, label_data):
        # A caption must be present; an empty string is still accepted.
        if label_data is None:
            raise ValueError
@StandAloneImageListGeneratorFactory.register(_DATA_TYPE)
class ImageCaptionStandAloneImageListGenerator(GenerateStandAloneImageListBase):
    """Emit one ``{'caption': ...}`` dict per caption label."""

    def _generate_label(self, label: ImageCaptionLabelManifest, image: ImageDataManifest, manifest: DatasetManifest):
        return dict(caption=label.caption)
'ImageClassificationLabelManifest', 9 | 'MulticlassClassificationAsKeyValuePairDataset', 'MultilabelClassificationAsKeyValuePairDataset'] 10 | -------------------------------------------------------------------------------- /vision_datasets/image_classification/coco_manifest_adaptor.py: -------------------------------------------------------------------------------- 1 | from ..common import DatasetTypes 2 | from ..common.data_manifest import ImageDataManifest 3 | from ..common.data_manifest.coco_manifest_adaptor import CocoManifestWithCategoriesAdaptor 4 | from ..common.factory import CocoManifestAdaptorFactory 5 | from .manifest import ImageClassificationLabelManifest 6 | 7 | 8 | @CocoManifestAdaptorFactory.register(DatasetTypes.IMAGE_CLASSIFICATION_MULTICLASS) 9 | class MultiClassClassificationCocoManifestAdaptor(CocoManifestWithCategoriesAdaptor): 10 | def __init__(self) -> None: 11 | super().__init__(DatasetTypes.IMAGE_CLASSIFICATION_MULTICLASS) 12 | 13 | def process_label(self, image: ImageDataManifest, annotation, coco_manifest, label_id_to_pos): 14 | if len(image.labels) != 0: 15 | raise ValueError( 16 | f"image with id {annotation['image_id']} will possess unexpected number of annotations {len(image.labels) + 1} for {DatasetTypes.IMAGE_CLASSIFICATION_MULTICLASS} dataset.") 17 | 18 | label = ImageClassificationLabelManifest(label_id_to_pos[annotation['category_id']], additional_info=self._get_additional_info(annotation, {'id', 'image_id', 'category_id'})) 19 | image.labels.append(label) 20 | 21 | 22 | @CocoManifestAdaptorFactory.register(DatasetTypes.IMAGE_CLASSIFICATION_MULTILABEL) 23 | class MultiLabelClassificationCocoManifestAdaptor(CocoManifestWithCategoriesAdaptor): 24 | def __init__(self) -> None: 25 | super().__init__(DatasetTypes.IMAGE_CLASSIFICATION_MULTILABEL) 26 | 27 | def process_label(self, image: ImageDataManifest, annotation, coco_manifest, label_id_to_pos): 28 | label = 
from ..common import ImageLabelWithCategoryManifest


class ImageClassificationLabelManifest(ImageLabelWithCategoryManifest):
    """
    c_id: class id starting from zero

    The zero-based class index is stored directly as label_data.
    """

    @property
    def category_id(self) -> int:
        # label_data is the zero-based class index itself.
        return self.label_data

    @category_id.setter
    def category_id(self, value):
        # Validate via the shared base-class check before overwriting.
        self._category_id_check(value)
        self.label_data = value

    def _read_label_data(self):
        # Class ids are inline data; there is never a label file to load.
        raise NotImplementedError

    def _check_label(self, label_data):
        # A class id must be present and non-negative (zero-based indexing).
        if label_data is None or label_data < 0:
            raise ValueError
process_labels(self, coco_ann, label: ImageClassificationLabelManifest): 11 | coco_ann['category_id'] = label.label_data + 1 12 | 13 | 14 | ManifestMergeStrategyFactory.direct_register(SingleTaskMerge, DatasetTypes.IMAGE_CLASSIFICATION_MULTICLASS) 15 | ManifestMergeStrategyFactory.direct_register(SingleTaskMerge, DatasetTypes.IMAGE_CLASSIFICATION_MULTILABEL) 16 | 17 | 18 | SampleStrategyFactory.direct_register(SampleByNumSamples, DatasetTypes.IMAGE_CLASSIFICATION_MULTICLASS, SampleStrategyType.NumSamples) 19 | SampleStrategyFactory.direct_register(SampleFewShot, DatasetTypes.IMAGE_CLASSIFICATION_MULTICLASS, SampleStrategyType.FewShot) 20 | 21 | SampleStrategyFactory.direct_register(SampleByNumSamples, DatasetTypes.IMAGE_CLASSIFICATION_MULTILABEL, SampleStrategyType.NumSamples) 22 | SampleStrategyFactory.direct_register(SampleFewShot, DatasetTypes.IMAGE_CLASSIFICATION_MULTILABEL, SampleStrategyType.FewShot) 23 | 24 | SpawnFactory.direct_register(Spawn, DatasetTypes.IMAGE_CLASSIFICATION_MULTICLASS) 25 | SpawnFactory.direct_register(Spawn, DatasetTypes.IMAGE_CLASSIFICATION_MULTILABEL) 26 | 27 | 28 | BalancedInstanceWeightsFactory.direct_register(BalancedInstanceWeightsGenerator, DatasetTypes.IMAGE_CLASSIFICATION_MULTICLASS) 29 | BalancedInstanceWeightsFactory.direct_register(BalancedInstanceWeightsGenerator, DatasetTypes.IMAGE_CLASSIFICATION_MULTILABEL) 30 | 31 | SplitFactory.direct_register(SplitWithCategories, DatasetTypes.IMAGE_CLASSIFICATION_MULTICLASS) 32 | SplitFactory.direct_register(SplitWithCategories, DatasetTypes.IMAGE_CLASSIFICATION_MULTILABEL) 33 | 34 | 35 | @StandAloneImageListGeneratorFactory.register(DatasetTypes.IMAGE_CLASSIFICATION_MULTICLASS) 36 | @StandAloneImageListGeneratorFactory.register(DatasetTypes.IMAGE_CLASSIFICATION_MULTILABEL) 37 | class ImageClassificationStandAloneImageListGenerator(GenerateStandAloneImageListBase): 38 | def _generate_label(self, label: ImageClassificationLabelManifest, image: ImageDataManifest, manifest: 
import numpy as np
from PIL import Image

from ..common import ImageLabelManifest, FileReader


class ImageMattingLabelManifest(ImageLabelManifest):
    """
    matting: 2D numpy array that has the same width and height with the image
    """

    @property
    def matting_image(self) -> np.ndarray:
        return self.label_data

    def _read_label_data(self):
        """Load the matting mask referenced by label_path into a numpy array.

        Fix: the reader is now closed in a ``finally`` block, so it is released
        even when opening or decoding the label file raises; the original only
        closed it on the success path, leaking the reader's handles on error.
        """
        file_reader = FileReader()
        try:
            with file_reader.open(self.label_path) as f:
                # np.asarray forces PIL to decode while the handle is still open.
                label = np.asarray(Image.open(f))
        finally:
            file_reader.close()

        return label

    def _check_label(self, label_data):
        # Only a missing mask is rejected here; shape agreement with the image
        # is documented above but not enforced at this layer.
        if label_data is None:
            raise ValueError
@StandAloneImageListGeneratorFactory.register(_DATA_TYPE)
class ImageMattingStandAloneImageListGenerator(GenerateStandAloneImageListBase):
    """Emits the matting mask as a base64-encoded string for stand-alone image lists."""

    def __init__(self, flatten: bool) -> None:
        super().__init__(flatten)

    def _generate_label(self, label: ImageLabelManifest, image: ImageDataManifest, manifest: DatasetManifest) -> dict:
        """Encode the label file at ``label.label_path`` to a base64 string.

        Fix: the FileReader created here is now closed after use. The original
        constructed a new reader per call and never released it, leaking any
        handles it holds open (the sibling manifest code closes its reader,
        establishing that FileReader owns resources needing explicit close).
        """
        file_reader = FileReader()
        try:
            return {"matting_image": Base64Utils.file_to_b64_str(label.label_path, file_reader)}
        finally:
            file_reader.close()
from ..common import ImageLabelWithCategoryManifest


class ImageObjectDetectionLabelManifest(ImageLabelWithCategoryManifest):
    """
    [c_id, left, top, right, bottom] (absolute coordinates)

    One box per label: a zero-based category id followed by the LTRB
    corner coordinates in pixels.
    """

    @property
    def category_id(self) -> int:
        # First element is the zero-based class index.
        return self.label_data[0]

    @property
    def left(self):
        return self.label_data[1]

    @property
    def top(self):
        return self.label_data[2]

    @property
    def right(self):
        return self.label_data[3]

    @property
    def bottom(self):
        return self.label_data[4]

    @category_id.setter
    def category_id(self, value):
        # Validate via the shared base-class check before overwriting in place.
        self._category_id_check(value)
        self.label_data[0] = value

    def _read_label_data(self):
        # Boxes are inline data; there is never a label file to load.
        raise NotImplementedError

    def _check_label(self, label_data):
        # Exactly five elements required: [c_id, l, t, r, b]. Coordinate
        # ordering (l <= r, t <= b) is not enforced at this layer.
        if not label_data or len(label_data) != 5:
            raise ValueError
from ..common import DatasetTypes, BalancedInstanceWeightsGenerator, GenerateCocoDictBase, SampleByNumSamples, SampleFewShot, SampleStrategyType, \
    SingleTaskMerge, Spawn, Split, BalancedInstanceWeightsFactory, CocoDictGeneratorFactory, ManifestMergeStrategyFactory, SampleStrategyFactory, \
    SpawnFactory, SplitFactory, StandAloneImageListGeneratorFactory, GenerateStandAloneImageListBase, \
    ImageDataManifest, DatasetManifest
from .manifest import ImageObjectDetectionLabelManifest

_DATA_TYPE = DatasetTypes.IMAGE_OBJECT_DETECTION


@CocoDictGeneratorFactory.register(_DATA_TYPE)
class ImageObjectDetectionCocoDictGenerator(GenerateCocoDictBase):
    """Converts detection labels back to COCO annotations (1-based ids, LTWH boxes)."""

    def process_labels(self, coco_ann, label: ImageObjectDetectionLabelManifest):
        ann = label.label_data
        coco_ann['category_id'] = ann[0] + 1  # manifest ids are 0-based, COCO is 1-based
        coco_ann['bbox'] = [ann[1], ann[2], ann[3] - ann[1], ann[4] - ann[2]]  # LTRB -> LTWH


ManifestMergeStrategyFactory.direct_register(SingleTaskMerge, _DATA_TYPE)


SampleStrategyFactory.direct_register(SampleByNumSamples, _DATA_TYPE, SampleStrategyType.NumSamples)
SampleStrategyFactory.direct_register(SampleFewShot, _DATA_TYPE, SampleStrategyType.FewShot)


# BUG FIX: this line previously registered under DatasetTypes.IMAGE_CLASSIFICATION_MULTILABEL —
# a copy-paste from image_classification/operations.py, which already performs that
# registration itself. In this module the registration must be for object detection.
BalancedInstanceWeightsFactory.direct_register(BalancedInstanceWeightsGenerator, _DATA_TYPE)

SpawnFactory.direct_register(Spawn, _DATA_TYPE)
SplitFactory.direct_register(Split, _DATA_TYPE)


@StandAloneImageListGeneratorFactory.register(_DATA_TYPE)
class ImageObjectDetectionStandAloneImageListGenerator(GenerateStandAloneImageListBase):
    """Emits {'category_name', 'bbox'} dicts for stand-alone image lists."""

    def _generate_label(self, label: ImageObjectDetectionLabelManifest, image: ImageDataManifest, manifest: DatasetManifest) -> dict:
        return {'category_name': manifest.categories[label.category_id], 'bbox': label.label_data[1:]}
from ..common import ImageLabelManifest


class ImageRegressionLabelManifest(ImageLabelManifest):
    """
    value: regression target in float

    The numeric target is stored directly as label_data; each image carries
    exactly one such label (enforced by the COCO adaptor, not here).
    """

    @property
    def target(self) -> float:
        # The regression value is the label payload itself.
        return self.label_data

    def _read_label_data(self):
        # Targets are inline data; there is never a label file to load.
        raise NotImplementedError

    def _check_label(self, label_data):
        # Only a missing target is rejected; zero and negative values are valid.
        if label_data is None:
            raise ValueError
coco_ann, label: ImageLabelManifest): 12 | coco_ann['target'] = label.label_data 13 | 14 | 15 | ManifestMergeStrategyFactory.direct_register(SingleTaskMerge, _DATA_TYPE) 16 | 17 | SampleStrategyFactory.direct_register(SampleByNumSamples, _DATA_TYPE, SampleStrategyType.NumSamples) 18 | 19 | SpawnFactory.direct_register(Spawn, _DATA_TYPE) 20 | SplitFactory.direct_register(Split, _DATA_TYPE) 21 | 22 | 23 | @StandAloneImageListGeneratorFactory.register(_DATA_TYPE) 24 | class ImageRegressionStandAloneImageListGenerator(GenerateStandAloneImageListBase): 25 | def _generate_label(self, label: ImageRegressionLabelManifest, image: ImageDataManifest, manifest: DatasetManifest) -> dict: 26 | return {'target': label.target} 27 | -------------------------------------------------------------------------------- /vision_datasets/image_text_matching/__init__.py: -------------------------------------------------------------------------------- 1 | from .coco_manifest_adaptor import ImageTextMatchingCocoManifestAdaptor 2 | from .manifest import ImageTextMatchingLabelManifest 3 | from .operations import ImageTextMatchingCocoDictGenerator 4 | from .vision_as_image_text_dataset import VisionAsImageTextDataset 5 | 6 | __all__ = ['ImageTextMatchingCocoManifestAdaptor', 'ImageTextMatchingLabelManifest', 'ImageTextMatchingCocoDictGenerator', 'VisionAsImageTextDataset'] 7 | -------------------------------------------------------------------------------- /vision_datasets/image_text_matching/coco_manifest_adaptor.py: -------------------------------------------------------------------------------- 1 | from ..common import DatasetTypes, ImageDataManifest, CocoManifestWithoutCategoriesAdaptor, CocoManifestAdaptorFactory 2 | from .manifest import ImageTextMatchingLabelManifest 3 | 4 | 5 | @CocoManifestAdaptorFactory.register(DatasetTypes.IMAGE_TEXT_MATCHING) 6 | class ImageTextMatchingCocoManifestAdaptor(CocoManifestWithoutCategoriesAdaptor): 7 | def __init__(self) -> None: 8 | 
from ..common import ImageLabelManifest


class ImageTextMatchingLabelManifest(ImageLabelManifest):
    """
    (text, match): where text is str, and match is between [0, 1], where 0 means not match at all, 1 means perfect match
    """

    @property
    def text(self) -> str:
        return self.label_data[0]

    @property
    def match(self) -> float:
        return self.label_data[1]

    def _read_label_data(self):
        # (text, match) pairs are inline data; never read from a file.
        raise NotImplementedError

    def _check_label(self, label_data):
        """Validate a (text, match) pair.

        Fix: the class docstring specifies match as a score in [0, 1], but the
        original check (`label_data[1] not in [0, 1]`) only accepted the exact
        values 0 and 1, rejecting fractional scores such as 0.5. Accept the
        full documented range — a strict superset of previously accepted
        values, so existing data still validates.
        """
        if not label_data or len(label_data) != 2 or label_data[0] is None:
            raise ValueError
        match = label_data[1]
        if not isinstance(match, (int, float)) or not 0 <= match <= 1:
            raise ValueError
ImageTextMatchingCocoDictGenerator(GenerateCocoDictBase): 10 | def process_labels(self, coco_ann, label: ImageLabelManifest): 11 | coco_ann['text'] = label.label_data[0] 12 | coco_ann['match'] = label.label_data[1] 13 | 14 | 15 | ManifestMergeStrategyFactory.direct_register(SingleTaskMerge, _DATA_TYPE) 16 | 17 | SampleStrategyFactory.direct_register(SampleByNumSamples, _DATA_TYPE, SampleStrategyType.NumSamples) 18 | 19 | SpawnFactory.direct_register(Spawn, _DATA_TYPE) 20 | SplitFactory.direct_register(Split, _DATA_TYPE) 21 | 22 | 23 | @StandAloneImageListGeneratorFactory.register(_DATA_TYPE) 24 | class ImageTextMatchingStandAloneImageListGenerator(GenerateStandAloneImageListBase): 25 | def _generate_label(self, label: ImageTextMatchingLabelManifest, image: ImageDataManifest, manifest: DatasetManifest) -> dict: 26 | return {'text': label.text, 'match': label.match} 27 | -------------------------------------------------------------------------------- /vision_datasets/image_text_matching/vision_as_image_text_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | from copy import deepcopy 3 | 4 | from ..common import DatasetTypes 5 | from ..common.dataset.base_dataset import BaseDataset 6 | from .manifest import ImageTextMatchingLabelManifest 7 | 8 | 9 | class VisionAsImageTextDataset(BaseDataset): 10 | """ 11 | Consume traditional vision datasets of type 12 | [DatasetTypes.IMAGE_CLASSIFICATION_MULTICLASS, DatasetTypes.IMAGE_CLASSIFICATION_MULTILABEL, DatasetTypes.IMAGE_OBJECT_DETECTION], as DatasetTypes.IMAGE_TEXT_MATCHING dataset. 13 | For a certain image, negative image-text pairs are generated from the labels that this image does not possess. 
14 | """ 15 | 16 | def __init__(self, dataset: BaseDataset, neg_to_pos_ratio=0, text_aug=None, rnd_seed=0): 17 | """ 18 | Args: 19 | dataset: dataset of expected type 20 | neg_to_pos_ratio: ratio of negative against positive image text pairs 21 | text_aug: a func that augments a string, i.e., a class name, e.g. dog => a photo of dog 22 | rnd_seed: random seed for choosing negative class names for negative image text pairs 23 | """ 24 | if dataset is None: 25 | raise ValueError 26 | if dataset.dataset_info.type not in [DatasetTypes.IMAGE_CLASSIFICATION_MULTICLASS, DatasetTypes.IMAGE_CLASSIFICATION_MULTILABEL, DatasetTypes.IMAGE_OBJECT_DETECTION]: 27 | raise ValueError 28 | if neg_to_pos_ratio < 0: 29 | raise ValueError 30 | dataset_info = deepcopy(dataset.dataset_info) 31 | dataset_info.type = DatasetTypes.IMAGE_TEXT_MATCHING 32 | 33 | super().__init__(dataset_info) 34 | self._dataset = dataset 35 | self._negative_pair_ratio = neg_to_pos_ratio 36 | self._text_aug = text_aug or (lambda x: x) 37 | self._rand_seed = rnd_seed 38 | 39 | @property 40 | def categories(self): 41 | return None 42 | 43 | def __len__(self): 44 | return len(self._dataset) 45 | 46 | def _get_single_item(self, index): 47 | img, target, _ = self._dataset[index] 48 | pos_class_indices = [x.category_id for x in target] 49 | pos_class_names = [self._dataset.categories[x].name for x in pos_class_indices] 50 | labels = [ImageTextMatchingLabelManifest((self._text_aug(class_name), 1)) for class_name in pos_class_names] 51 | if self._negative_pair_ratio > 0: 52 | neg_class_indices = set(range(len(self._dataset.categories))) - set(pos_class_indices) 53 | neg_class_names = [self._dataset.categories[x].name for x in neg_class_indices] 54 | if neg_class_names: 55 | down_sample_ratio = self._negative_pair_ratio * len(pos_class_names) / len(neg_class_names) 56 | if down_sample_ratio < 1: 57 | neg_class_names = [ncn for j, ncn in enumerate(neg_class_names) if random.Random(self._rand_seed+(index) * 
len(neg_class_names) + j).random() < down_sample_ratio] 58 | 59 | neg_labels = [ImageTextMatchingLabelManifest((self._text_aug(class_name), 0)) for class_name in neg_class_names] 60 | labels += neg_labels 61 | return img, labels, str(index) 62 | 63 | def close(self): 64 | self._dataset.close() 65 | -------------------------------------------------------------------------------- /vision_datasets/key_value_pair/__init__.py: -------------------------------------------------------------------------------- 1 | from .coco_manifest_adaptor import KeyValuePairCocoManifestAdaptor 2 | from .manifest import KeyValuePairLabelManifest, KeyValuePairDatasetManifest, KeyValuePairSchema 3 | from .operations import KeyValuePairCocoDictGenerator, KeyValuePairDatasetSampleByNumSamples 4 | 5 | __all__ = ['KeyValuePairCocoManifestAdaptor', 'KeyValuePairCocoDictGenerator', 'KeyValuePairDatasetManifest', 6 | 'KeyValuePairLabelManifest', 'KeyValuePairSchema', 'KeyValuePairDatasetSampleByNumSamples'] 7 | -------------------------------------------------------------------------------- /vision_datasets/key_value_pair/coco_manifest_adaptor.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from ..common import BBoxFormat, DatasetTypes, CocoManifestAdaptorFactory, CocoManifestWithMultiImageLabelAdaptor 3 | from .manifest import KeyValuePairLabelManifest, KeyValuePairDatasetManifest, KeyValuePairSchema 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | @CocoManifestAdaptorFactory.register(DatasetTypes.KEY_VALUE_PAIR) 9 | class KeyValuePairCocoManifestAdaptor(CocoManifestWithMultiImageLabelAdaptor): 10 | def __init__(self, schema: dict) -> None: 11 | self.schema_dict = schema 12 | self.schema = KeyValuePairSchema(schema['name'], schema['fieldSchema'], schema.get('description', None)) 13 | super().__init__(DatasetTypes.KEY_VALUE_PAIR) 14 | 15 | def _construct_label_manifest(self, img_ids, ann, coco_manifest): 16 | label_data = 
self.process_label(ann, coco_manifest) 17 | return KeyValuePairLabelManifest(ann['id'], img_ids, label_data, self._get_additional_info(ann, {'id', KeyValuePairLabelManifest.IMAGES_INPUT_KEY, KeyValuePairLabelManifest.LABEL_KEY, 18 | KeyValuePairLabelManifest.TEXT_INPUT_KEY})) 19 | 20 | def _construct_manifest(self, images_by_id, coco_manifest, data_type, additional_info): 21 | images, annotations = self.get_images_and_annotations(images_by_id, coco_manifest) 22 | return KeyValuePairDatasetManifest(images, annotations, self.schema_dict, additional_info) 23 | 24 | def convert_bbox_ltwh_to_ltrb(self, value): 25 | if isinstance(value, list): 26 | for i in range(len(value)): 27 | self.convert_bbox_ltwh_to_ltrb(value[i]) 28 | elif isinstance(value, dict): 29 | for k in value.keys(): 30 | if k == KeyValuePairLabelManifest.LABEL_GROUNDINGS_KEY: 31 | for i, grounding in enumerate(value[k]): 32 | value[k][i] = [grounding[0], grounding[1], grounding[0] + grounding[2], grounding[1] + grounding[3]] 33 | else: 34 | self.convert_bbox_ltwh_to_ltrb(value[k]) 35 | 36 | def check_no_groundings_for_multi_image_annotation(self, value: dict): 37 | if isinstance(value, list): 38 | for v in value: 39 | self.check_no_groundings_for_multi_image_annotation(v) 40 | elif isinstance(value, dict): 41 | if KeyValuePairLabelManifest.LABEL_GROUNDINGS_KEY in value: 42 | raise ValueError('Groundings are not allowed for multi-image annotations') 43 | for v in value.values(): 44 | self.check_no_groundings_for_multi_image_annotation(v) 45 | 46 | def process_label(self, annotation: dict, coco_manifest: dict): 47 | if KeyValuePairLabelManifest.LABEL_KEY not in annotation: 48 | raise ValueError(f'{KeyValuePairLabelManifest.LABEL_KEY} not found in annotation {annotation}') 49 | 50 | bbox_format = coco_manifest.get('bbox_format') 51 | bbox_format = BBoxFormat[bbox_format.upper()] if bbox_format else BBoxFormat.LTWH 52 | if bbox_format == BBoxFormat.LTWH: 53 | logger.info('Provided bounding box format is 
LTWH, converting bounding boxes (if any) to LTRB') 54 | for k in annotation[KeyValuePairLabelManifest.LABEL_KEY]: 55 | self.convert_bbox_ltwh_to_ltrb(annotation[KeyValuePairLabelManifest.LABEL_KEY][k]) 56 | 57 | # If the annotation is for multiple images, groundings should be disabled 58 | if len(annotation['image_ids']) > 1: 59 | for field in annotation[KeyValuePairLabelManifest.LABEL_KEY].values(): 60 | self.check_no_groundings_for_multi_image_annotation(field) 61 | 62 | label_data = {KeyValuePairLabelManifest.LABEL_KEY: annotation[KeyValuePairLabelManifest.LABEL_KEY], 63 | KeyValuePairLabelManifest.TEXT_INPUT_KEY: annotation.get(KeyValuePairLabelManifest.TEXT_INPUT_KEY, None)} 64 | KeyValuePairLabelManifest.check_schema_match(label_data[KeyValuePairLabelManifest.LABEL_KEY], self.schema) 65 | return label_data 66 | -------------------------------------------------------------------------------- /vision_datasets/key_value_pair/operations.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | 4 | from ..common import DatasetTypes, \ 5 | MultiImageCocoDictGenerator, \ 6 | MultiImageDatasetSingleTaskMerge, CocoDictGeneratorFactory, ManifestMergeStrategyFactory, \ 7 | SampleStrategyFactory, SampleStrategyType, SampleStrategy 8 | from .manifest import KeyValuePairLabelManifest, KeyValuePairDatasetManifest 9 | 10 | _DATA_TYPE = DatasetTypes.KEY_VALUE_PAIR 11 | 12 | 13 | @CocoDictGeneratorFactory.register(_DATA_TYPE) 14 | class KeyValuePairCocoDictGenerator(MultiImageCocoDictGenerator): 15 | def process_labels(self, coco_ann, label: KeyValuePairLabelManifest): 16 | coco_ann[KeyValuePairLabelManifest.LABEL_KEY] = label.fields 17 | if label.text is not None: 18 | coco_ann[KeyValuePairLabelManifest.TEXT_INPUT_KEY] = label.text 19 | 20 | def _generate_images(self, manifest: KeyValuePairDatasetManifest): 21 | images = super()._generate_images(manifest) 22 | # add metadata field if exists 23 | for img, 
@ManifestMergeStrategyFactory.register(_DATA_TYPE)
class KeyValuePairDatasetMerge(MultiImageDatasetSingleTaskMerge):
    """Merge strategy for key-value-pair datasets; all inputs must share one schema."""

    def merge(self, *args: KeyValuePairDatasetManifest):
        """Merge manifests after verifying every schema matches the first one.

        Raises:
            ValueError: if any manifest's schema differs from the first.
        """
        schema = args[0].schema
        for manifest in args[1:]:
            if manifest.schema != schema:
                raise ValueError('Schema mismatch')
        return super().merge(*args)


@SampleStrategyFactory.register(_DATA_TYPE, SampleStrategyType.NumSamples)
class KeyValuePairDatasetSampleByNumSamples(SampleStrategy):
    """Sample a fixed number of annotations (optionally weighted, with or without replacement)."""

    def __init__(self, config):
        super().__init__(config)
        if config.n_samples <= 0:
            raise ValueError('n samples must be greater than zero.')

    def sample(self, manifest: KeyValuePairDatasetManifest) -> KeyValuePairDatasetManifest:
        """Return a new manifest containing `n_samples` randomly chosen annotations.

        All images are kept (annotations index into the image list); images,
        annotations, schema and additional info are deep-copied.

        Raises:
            ValueError: when sampling without replacement and n_samples exceeds
                the number of annotations.
        """
        if not self.config.with_replacement and self.config.n_samples > len(manifest.annotations):
            raise ValueError('When with_replacement is disabled, n_samples must be less than or equal to the number of annotations in the dataset.')

        rng = np.random.default_rng(self.config.random_seed)
        # Fix: compute the weight total once instead of re-summing inside the
        # comprehension for every element (was O(n^2)).
        if self.config.weights:
            total_weight = sum(self.config.weights)
            normalized_weights = [w / total_weight for w in self.config.weights]
        else:
            normalized_weights = None
        sampled_indices = rng.choice(len(manifest.annotations), size=self.config.n_samples, replace=self.config.with_replacement, p=normalized_weights)
        sampled_annotations = [manifest.annotations[i] for i in sampled_indices]

        return KeyValuePairDatasetManifest(copy.deepcopy(manifest.images),
                                           copy.deepcopy(sampled_annotations),
                                           copy.deepcopy(manifest.schema),
                                           copy.deepcopy(manifest.additional_info))
@CocoManifestAdaptorFactory.register(DatasetTypes.MULTITASK)
class MultiTaskCocoManifestAdaptor(CocoManifestAdaptorBase):
    """Adaptor that builds a multitask dataset manifest from one coco file per task."""

    def create_dataset_manifest(self, coco_file_path_or_url: typing.Union[str, dict, pathlib.Path], container_sas_or_root_dir: str = None):
        """Construct a dataset manifest out of coco file(s).

        Args:
            coco_file_path_or_url (str or pathlib.Path or dict): path or url to coco file;
                must be a dict of task name => coco file for multitask datasets.
            container_sas_or_root_dir (str): container sas if resources are stored in a
                blob container, or a local dir.
        """
        if not coco_file_path_or_url:
            return None

        # Fix: bare `raise ValueError` gave no diagnostic; report what was expected.
        if not isinstance(coco_file_path_or_url, dict):
            raise ValueError('coco_file_path_or_url must be a dict of task name => coco file for multitask datasets.')
        if not isinstance(self.data_type, dict):
            raise ValueError('data_type must be a dict of task name => dataset type for multitask datasets.')
        dataset_manifest_by_task = {k: CocoManifestAdaptorFactory.create(self.data_type[k]).create_dataset_manifest(coco_file_path_or_url[k], container_sas_or_root_dir)
                                    for k in coco_file_path_or_url}

        return generate_multitask_dataset_manifest(dataset_manifest_by_task)

    def get_images_and_categories(self, images_by_id, coco_manifest):
        # Not used: each per-task adaptor parses its own images/categories.
        pass
@ManifestMergeStrategyFactory.register(_DATA_TYPE)
class MultitaskMerge(MergeStrategy):
    """Merge strategy for multitask dataset manifests.

    check() guarantees all inputs share identical data types and categories, so
    merge() can take those from the first manifest and only concatenate images.
    """

    def merge(self, *args: DatasetManifest):
        """Concatenate the images of all manifests into one multitask manifest.

        Fixes over the previous version: removed dead code that deep-copied
        every manifest's categories and data types into dicts that were never
        used, and corrected the misspelled local `addional_info`.
        """
        additional_info = deep_merge([x.additional_info for x in args])
        if not additional_info:
            additional_info = None
        return DatasetManifest([img for manifest in args for img in manifest.images], args[0].categories, args[0].data_type, additional_info)

    def check(self, *args: DatasetManifest):
        """Validate that every manifest is multitask with matching tasks/categories."""
        super().check(*args)
        if not all(x.is_multitask for x in args):
            raise ValueError('all manifests must be multitask.')

        # Fix: bare `raise ValueError` gave no diagnostic on mismatch.
        for m1, m2 in zip(args, args[1:]):
            if m1.data_type != m2.data_type:
                raise ValueError('data types of all manifests must be the same.')
            if m1.categories != m2.categories:
                raise ValueError('categories of all manifests must be the same.')
class Text2ImageRetrievalLabelManifest(ImageLabelManifest):
    """
    query: in str
    """

    @property
    def query(self) -> str:
        # Fix: `query` was a plain method, but consumers read it as an attribute
        # (e.g. `label.query` in the stand-alone image list generator), which
        # yielded a bound method instead of the query string. Sibling manifests
        # (VQA, visual object grounding) all expose their fields as properties.
        return self.label_data

    def _read_label_data(self):
        # Labels are always provided inline; lazy file-backed reads are unsupported.
        raise NotImplementedError

    def _check_label(self, label_data):
        if label_data is None:
            raise ValueError
def _identity(*args):
    """No-op transform: return the inputs unchanged."""
    return args


class _ImageOnlyTransform:
    """Adapt a 1-argument image transform to the (image, targets) call signature."""

    def __init__(self, transform):
        self._transform = transform

    def __call__(self, image, targets):
        return self._transform(image), targets


class Dataset(torch.utils.data.Dataset, ABC):
    """Abstract base for torch-compatible vision datasets with a pluggable transform."""

    def __init__(self, transform=None):
        super().__init__()
        self.transform = transform
        # Work around for corrupted files in datasets. Imported lazily so merely
        # importing this module does not require PIL to be installed.
        from PIL import ImageFile
        ImageFile.LOAD_TRUNCATED_IMAGES = True

    @property
    @abstractmethod
    def categories(self):
        """Returns a list of categories."""
        # Fix: was @abstractclassmethod, which is deprecated (since Python 3.3)
        # and wrong under @property; @abstractmethod expresses the intent.
        raise NotImplementedError

    @property
    def transform(self):
        return self._transform

    @transform.setter
    def transform(self, val):
        # None/falsy -> identity; 1-arg callable -> wrapped as image-only;
        # otherwise assumed to accept (image, targets).
        if not val:
            self._transform = _identity
        elif len(signature(val).parameters) == 1:
            self._transform = _ImageOnlyTransform(val)
        else:
            self._transform = val

    def close(self):
        """Release the resources allocated for this dataset."""
        pass

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError


class TorchDataset(Dataset):
    """
    Dataset class used for pytorch training
    """

    # Annotation quoted so the class can be defined without importing VisionDataset eagerly.
    def __init__(self, manifest_dataset: "VisionDataset", transform=None):
        Dataset.__init__(self, transform)
        self.dataset = manifest_dataset

    @property
    def categories(self):
        return self.dataset.categories

    @property
    def dataset_resources(self):
        return self.dataset.dataset_resources

    @property
    def dataset_info(self):
        return self.dataset.dataset_info

    def __getitem__(self, index):
        # int index -> one transformed (image, target, idx_str) triplet;
        # slice -> list of transformed triplets.
        if isinstance(index, int):
            image, target, idx_str = self.dataset[index]
            image, target = self.transform(image, target)
            return image, target, idx_str
        else:
            return [self.transform(img, target) + (idx,) for img, target, idx in self.dataset[index]]

    def __len__(self):
        return len(self.dataset)

    def close(self):
        self.dataset.close()
@CocoManifestAdaptorFactory.register(DatasetTypes.VISUAL_OBJECT_GROUNDING)
class VisualObjectGroundingCocoManifestAdaptor(CocoManifestWithoutCategoriesAdaptor):
    """Adaptor that reads visual-object-grounding annotations from a coco manifest."""

    def __init__(self) -> None:
        super().__init__(DatasetTypes.VISUAL_OBJECT_GROUNDING)

    def process_label(self, image: ImageDataManifest, annotation: dict, coco_manifest: dict):
        """Normalize grounding boxes to LTRB and attach the label to the image."""
        declared_format = coco_manifest.get('bbox_format')
        box_format = BBoxFormat[declared_format.upper()] if declared_format else BBoxFormat.LTWH
        groundings = annotation['groundings']
        if box_format == BBoxFormat.LTWH:
            # Convert [left, top, width, height] -> [left, top, right, bottom] in place.
            for grounding in groundings:
                for box in grounding['bboxes']:
                    box[2] += box[0]
                    box[3] += box[1]

        label_payload = {'question': annotation['question'], 'answer': annotation['answer'], 'groundings': annotation['groundings']}
        extra = self._get_additional_info(annotation, {'id', 'question', 'answer', 'groundings'})
        image.labels.append(VisualObjectGroundingLabelManifest(label_payload, additional_info=extra))
class Grounding:
    """One grounding: a span of the answer text plus its LTRB bounding boxes."""

    def __init__(self, label_data: dict, answer_len: int):
        self.check_label(label_data, answer_len)

        self._label_data = label_data

    @staticmethod
    def check_label(label_data, answer_len):
        """Validate a grounding dict against the answer length.

        Raises:
            ValueError: when required keys are missing/None, a bbox does not
                have exactly 4 coordinates, or the text span is malformed or
                out of the answer's range.
        """
        # Fix: all ValueErrors previously carried no message, making failures
        # on real data impossible to diagnose.
        if not label_data or any(label_data.get(key) is None for key in ['id', 'text_span', 'text', 'bboxes']):
            raise ValueError('grounding must provide non-null id, text_span, text and bboxes.')

        for bbox in label_data['bboxes']:
            if len(bbox) != 4:
                raise ValueError('each bbox must be [left, top, right, bottom].')

        if len(label_data['text_span']) != 2:
            raise ValueError('text_span must be [start, end].')

        start = label_data['text_span'][0]
        end = label_data['text_span'][1]
        if start < 0 or end < 0 or start >= end or start >= answer_len or end > answer_len:
            raise ValueError(f'invalid text_span [{start}, {end}] for answer of length {answer_len}.')

    @property
    def id(self):
        return self._label_data['id']

    @property
    def text_span(self):
        return self._label_data['text_span']

    @property
    def text(self) -> str:
        return self._label_data['text']

    @property
    def bboxes(self) -> List[List[Union[int, float]]]:
        """ returns a list of bounding boxes in the format of [[left, top, right, bottom], ...]

        Returns:
            List[List[Union[int, float]]]: list of boxes in the format of [[left, top, right, bottom], ...]
        """
        return self._label_data['bboxes']
class VisualObjectGroundingLabelManifest(ImageLabelManifest):
    """
    {
        "question": "a question about the image",
        "answer": "generic caption or answer to the question",
        "groundings": [{"text": "....", "text_span": [start, end], "bboxes": [[left, top, right, bottom], ...]}, ...]
    }
    """

    def _read_label_data(self):
        # Labels are always provided inline; lazy file-backed reads are unsupported.
        raise NotImplementedError

    def _check_label(self, label_data):
        # Every top-level field must be present and non-null, and each grounding
        # must validate against the answer's length.
        if not label_data or any(label_data.get(key, None) is None for key in ['question', 'answer', 'groundings']):
            raise ValueError(str(label_data.keys()))

        answer_length = len(label_data["answer"])
        for grounding in label_data["groundings"]:
            Grounding.check_label(grounding, answer_length)

    @property
    def question(self) -> str:
        return self.label_data["question"]

    @property
    def answer(self) -> str:
        return self.label_data["answer"]

    @property
    def groundings(self) -> List[Grounding]:
        answer_length = len(self.label_data["answer"])
        return [Grounding(entry, answer_length) for entry in self.label_data["groundings"]]


@CocoDictGeneratorFactory.register(_DATA_TYPE)
class VisualObjectGroundingCocoDictGenerator(GenerateCocoDictBase):
    """Copies every label field (question/answer/groundings) into the coco annotation."""

    def process_labels(self, coco_ann, label: ImageLabelManifest):
        coco_ann.update(label.label_data)


ManifestMergeStrategyFactory.direct_register(SingleTaskMerge, _DATA_TYPE)

SampleStrategyFactory.direct_register(SampleByNumSamples, _DATA_TYPE, SampleStrategyType.NumSamples)
SpawnFactory.direct_register(Spawn, _DATA_TYPE)
SplitFactory.direct_register(Split, _DATA_TYPE)


@StandAloneImageListGeneratorFactory.register(_DATA_TYPE)
class VisualObjectGroundingStandAloneImageListGenerator(GenerateStandAloneImageListBase):
    """Stand-alone image dict generator for visual-object-grounding labels.

    Fix: renamed from `Text2ImageRetrievalStandAloneImageListGenerator`, a
    copy/paste mistake from the text_2_image_retrieval module that duplicated
    that module's class name. The class is registered with the factory by data
    type and is not re-exported in __all__, so the rename is safe for callers.
    """

    def _generate_label(self, label: VisualObjectGroundingLabelManifest, image: ImageDataManifest, manifest: DatasetManifest) -> dict:
        return {'question': label.question, 'answer': label.answer, 'groundings': label.label_data['groundings']}
@CocoManifestAdaptorFactory.register(DatasetTypes.VISUAL_QUESTION_ANSWERING)
class VisualQuestionAnswerinCocoManifestAdaptor(CocoManifestWithoutCategoriesAdaptor):
    """Adaptor that reads visual-question-answering annotations from a coco manifest."""

    def __init__(self) -> None:
        super().__init__(DatasetTypes.VISUAL_QUESTION_ANSWERING)

    def process_label(self, image: ImageDataManifest, annotation: dict, coco_manifest: dict):
        """Attach a question/answer label to the image.

        Fix: the additional-info skip-set previously contained an empty string
        instead of 'answer' ({'id', 'question', ''}), so the consumed 'answer'
        field was duplicated into additional_info. Sibling adaptors (e.g.
        visual object grounding) skip every consumed key.
        """
        image.labels.append(VisualQuestionAnsweringLabelManifest({"question": annotation['question'], "answer": annotation['answer']},
                                                                 additional_info=self._get_additional_info(annotation, {'id', 'question', 'answer'})))


class VisualQuestionAnsweringLabelManifest(ImageLabelManifest):
    """
    {"question": "a question about the image", "answer": "answer to the question"}
    """

    def _read_label_data(self):
        # Labels are always provided inline; lazy file-backed reads are unsupported.
        raise NotImplementedError

    def _check_label(self, label_data):
        if label_data is None or 'question' not in label_data or 'answer' not in label_data:
            raise ValueError

    @property
    def question(self) -> str:
        return self.label_data["question"]

    @property
    def answer(self) -> str:
        return self.label_data["answer"]
@CocoDictGeneratorFactory.register(_DATA_TYPE)
class VisualQuestionAnsweringCocoDictGenerator(GenerateCocoDictBase):
    """Copies every label field (question/answer) into the coco annotation."""

    def process_labels(self, coco_ann, label: ImageLabelManifest):
        for key, val in label.label_data.items():
            coco_ann[key] = val


ManifestMergeStrategyFactory.direct_register(SingleTaskMerge, _DATA_TYPE)

SampleStrategyFactory.direct_register(SampleByNumSamples, _DATA_TYPE, SampleStrategyType.NumSamples)

SpawnFactory.direct_register(Spawn, _DATA_TYPE)
SplitFactory.direct_register(Split, _DATA_TYPE)


@StandAloneImageListGeneratorFactory.register(_DATA_TYPE)
class VisualQuestionAnsweringStandAloneImageListGenerator(GenerateStandAloneImageListBase):
    """Stand-alone image dict generator for VQA labels."""

    def _generate_label(self, label: VisualQuestionAnsweringLabelManifest, image: ImageDataManifest, manifest: DatasetManifest) -> dict:
        return {'question': label.question, 'answer': label.answer}


class VQAAsKeyValuePairDataset(VisionDataset):
    """Dataset class that accesses a Visual Question Answering (VQA) dataset as a KeyValuePair dataset."""

    ANSWER_KEY = "answer"
    RATIONALE_KEY = "rationale"
    QUESTION_KEY = "question"

    def __init__(self, vqa_dataset: VisionDataset):
        """
        Initializes an instance of the VQAAsKeyValuePairDataset class.
        Args:
            vqa_dataset (VisionDataset): The VQA dataset to convert to key-value pair dataset.

        Raises:
            ValueError: if vqa_dataset is None or not a VQA dataset.
        """
        if vqa_dataset is None or vqa_dataset.dataset_info.type is not DatasetTypes.VISUAL_QUESTION_ANSWERING:
            raise ValueError("Input dataset must be a Visual Question Answering dataset.")

        # Deep-copy so the wrapped dataset is never mutated.
        vqa_dataset = deepcopy(vqa_dataset)

        dataset_info_dict = deepcopy(vqa_dataset.dataset_info.__dict__)
        dataset_info_dict["type"] = DatasetTypes.KEY_VALUE_PAIR.name.lower()

        schema = self._schema
        # Update dataset_info with schema
        dataset_info = KeyValuePairDatasetInfo({**dataset_info_dict, "schema": schema})

        dataset_manifest = vqa_dataset.dataset_manifest
        self.img_id_to_pos = {x.id: i for i, x in enumerate(dataset_manifest.images)}

        # Construct KeyValuePairDatasetManifest.
        # Fixes: annotation counter renamed from `id` (shadowed the builtin) and
        # the pointless `for _, img in enumerate(images, 1)` with an unused index
        # replaced by a plain loop.
        annotations = []
        ann_id = 1
        for img in dataset_manifest.images:
            label_data = [label.label_data for label in img.labels]

            for label in label_data:
                kvp_label_data = self.construct_kvp_label_data(label)
                img_ids = [self.img_id_to_pos[img.id]]  # 0-based index
                kvp_annotation = KeyValuePairLabelManifest(ann_id, img_ids, label_data=kvp_label_data)
                ann_id += 1

                # KVPDatasetManifest expects img.labels to be empty. Labels are instead stored in KVP annotation
                img.labels = []
                annotations.append(kvp_annotation)

        dataset_manifest = KeyValuePairDatasetManifest(deepcopy(dataset_manifest.images), annotations, schema, additional_info=deepcopy(vqa_dataset.dataset_manifest.additional_info))
        super().__init__(dataset_info, dataset_manifest, dataset_resources=vqa_dataset.dataset_resources)

    @property
    def _schema(self) -> Dict[str, Any]:
        """Key-value-pair field schema: answer plus an (optional) rationale."""
        return {
            "name": "Visual Question Answering",
            "description": "Answer questions on given images and provide rationale for the answer.",
            "fieldSchema": {
                self.ANSWER_KEY: {
                    "type": "string",
                    "description": "Answer to the question.",
                },
                self.RATIONALE_KEY: {
                    "type": "string",
                    "description": "Rationale for the answer.",
                },
            }
        }

    def construct_kvp_label_data(self, label: Dict[str, str]) -> Dict[str, Dict[str, str]]:
        """
        Convert the VQA dataset label to the desired format for KVP annotation as defined by the SCHEMA_BASE.
        E.g. {"fields":
                {"answer": {"value": "yes"}},
              "text": {"question": "Is there a dog in the image?"}
             }

        Raises:
            KeyError: if the label is missing the question or answer key.
        """
        if self.QUESTION_KEY not in label:
            raise KeyError(f"Question key '{self.QUESTION_KEY}' not found in label.")
        if self.ANSWER_KEY not in label:
            raise KeyError(f"Answer key '{self.ANSWER_KEY}' not found in label.")

        kvp_label_data = {
            KeyValuePairLabelManifest.LABEL_KEY: {
                self.ANSWER_KEY: {KeyValuePairLabelManifest.LABEL_VALUE_KEY: label[self.ANSWER_KEY]},
            },
            KeyValuePairLabelManifest.TEXT_INPUT_KEY: {self.QUESTION_KEY: label[self.QUESTION_KEY]},
        }

        return kvp_label_data