├── LICENSE
├── README.md
└── wds_create_shards.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 LAION AI

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# dataset-spec
Check out this video: https://www.youtube.com/watch?v=cmVBaShtygA

## Metadata for image/text datasets

If the dataset can be downloaded from image URLs, prefer distributing the data publicly (e.g. on Hugging Face) as parquet files with these columns:
* URL
* TEXT

If other information is available, feel free to provide it as well in additional columns.

## Image-text datasets

The format is a collection of tar files (this dataset format is called WebDataset) containing images, captions and metadata:

* 00000.tar containing 10k samples
  * 0.jpg
  * 0.txt containing the caption
  * 0.json containing metadata such as the URL, the original width, the EXIF data, and whether the image is NSFW

If the biggest dimension of an image is bigger than 512, resize it so that the biggest dimension is 512, keeping the aspect ratio.

If the smallest dimension of an image is below 64, drop the sample.

Do not increase the resolution of an image that is below 512 but above 64; just keep it as is.

Save the image data in the WebDataset format as JPEG at the highest quality.

Use "jpg" as the field for the image and "txt" as the field for the caption.
A "json" field should contain at least the height and width of the image, and more if additional data is available.

If you have a VQA dataset, put the prefix "Q: " before the question and the prefix "A: " before the answer, then concatenate both texts and put the result into the "txt" field. Also put an entry into the "json" field with the key "question" and the question as its value, and another entry with the key "answer" and the answer as its value.
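A minimal sketch of these sizing rules and the VQA "txt"/"json" convention, assuming Pillow; the function names and the exact separator used when concatenating question and answer are illustrative assumptions, not part of the spec:

```python
import json
from PIL import Image

def preprocess_image(path):
    """Apply the sizing rules above; return a PIL image, or None to drop the sample."""
    img = Image.open(path).convert("RGB")
    w, h = img.size
    if min(w, h) < 64:
        return None  # smallest dimension below 64: drop the sample
    if max(w, h) > 512:
        scale = 512 / max(w, h)  # shrink so the biggest dimension becomes 512
        img = img.resize((round(w * scale), round(h * scale)), Image.LANCZOS)
    return img  # between 64 and 512: keep unchanged, never upscale

def vqa_fields(question, answer, width, height):
    """Build the "txt" and "json" fields for one VQA sample."""
    return {
        # the spec only says to concatenate; the single space is an assumption
        "txt": "Q: " + question + " A: " + answer,
        "json": json.dumps({
            "question": question,
            "answer": answer,
            "width": width,
            "height": height,
        }),
    }

# e.g. preprocess_image("raw/0.png").save("0.jpg", "JPEG", quality=100)  # highest quality
```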
A script that can help create WebDataset tar files from jpg and txt files is [wds_create_shards.py](wds_create_shards.py) (**json support added; you can use the --json flag to automatically read json files and add them to the tar files, 04-19-2022**).

| Dataset info | Who works on it |
|---|---|
| | |
| | |
| | |
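Once shards are written with `--image_text_keys jpg,txt` (matching the field names required above), they can be loaded back with the same webdataset library the script uses. A minimal reading sketch; the shard range and directory are illustrative:

```python
import webdataset as wds

# Shard names follow the script's defaults: --shard_prefix "ds_" plus a %06d counter.
dataset = (
    wds.WebDataset("shards/ds_{000000..000009}.tar")
    .decode("pil")            # "jpg" fields become PIL images, "txt" fields become str
    .to_tuple("jpg", "txt")
)

for image, caption in dataset:
    print(image.size, caption)
    break
```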
--------------------------------------------------------------------------------
/wds_create_shards.py:
--------------------------------------------------------------------------------
import sys
import os
import os.path
import random
import argparse
import json

from pathlib import Path
from PIL import Image

import webdataset as wds

parser = argparse.ArgumentParser(description="Generate a sharded WebDataset from an image-text dataset.")
parser.add_argument("--maxsize", type=float, default=1e9)
parser.add_argument("--maxcount", type=float, default=100000)
parser.add_argument(
    "--compression",
    dest="compression",
    action="store_true",
    help="Creates compressed .tar.gz files instead of uncompressed .tar files."
)
parser.add_argument(
    "--json",
    dest="json",
    action="store_true",
    help="Reads json files and adds them to the .tar files."
)
parser.add_argument(
    "--image_text_keys",
    type=str,
    default="img,cap",
    help="Comma-separated WebDataset dictionary keys for images (first argument) and texts (second argument). \
        The same argument has to be provided to train_dalle.py, e.g. python train_dalle.py --wds img,cap --image_text_folder ../shards"
)
parser.add_argument(
    "--shards",
    default="./shards",
    help="directory where shards are written"
)
parser.add_argument(
    "--shard_prefix",
    default="ds_",
    help="prefix of the shard filenames created in the shards folder"
)
parser.add_argument(
    "--data",
    default="./data",
    help="directory path containing data suitable for DALLE-pytorch training",
)
args = parser.parse_args()

assert len(args.image_text_keys.split(',')) == 2, 'Expected exactly two comma-separated keys'
assert args.maxsize > 10000000
assert args.maxcount < 1000000

image_key, caption_key = tuple(args.image_text_keys.split(','))

if not os.path.isdir(args.data):
    print(f"{args.data}: should be a directory containing image-text pairs", file=sys.stderr)
    print("or subfolders containing image-text pairs", file=sys.stderr)
    sys.exit(1)

os.makedirs(Path(args.shards), exist_ok=True)

def readfile(fname):
    "Read a binary file from disk."
    with open(fname, "rb") as stream:
        return stream.read()

# Index caption files by stem so they can be matched with images of the same name.
path = Path(args.data)
text_files = {text_file.stem: text_file for text_file in path.glob('**/*.txt')}
text_total = len(text_files)

if args.json:
    json_files = {json_file.stem: json_file for json_file in path.glob('**/*.json')}
    json_dicts = {}

    # Parse every json file once; files that fail to parse are counted as corrupt
    # and left out of json_dicts, so they are never written to a shard.
    for key in json_files:
        try:
            with open(json_files[key], "r") as f:
                json_dicts[key] = json.dumps(json.load(f))
        except (OSError, json.JSONDecodeError):
            pass
    print("Found {} corrupt json file(s).".format(len(json_files) - len(json_dicts)))
    json_keys = json_dicts.keys()

image_files = [
    *path.glob('**/*.png'), *path.glob('**/*.jpg'),
    *path.glob('**/*.jpeg'), *path.glob('**/*.bmp')
]
image_files = {image_file.stem: image_file for image_file in image_files}
image_total = len(image_files)

print('Found {:,} text files and {:,} images.'.format(text_total, image_total))

# Keep only stems that have both an image and a caption.
keys = image_files.keys() & text_files.keys()

text_files = {k: v for k, v in text_files.items() if k in keys}
image_files = {k: v for k, v in image_files.items() if k in keys}

# Drop samples whose image cannot be opened or fails PIL's integrity check.
for key in image_files:
    try:
        img = Image.open(image_files[key])
        img.verify()
    except Exception:
        print('Invalid image on path {}'.format(image_files[key]))
        keys.remove(key)

print("Remaining keys after image sanity check: {:,}".format(len(keys)))

total_pairs = len(keys)
keys = list(keys)

indexes = list(range(total_pairs))
random.shuffle(indexes)

# This is the output pattern under which we write shards.
pattern = os.path.join(args.shards, args.shard_prefix + "%06d.tar" + (".gz" if args.compression else ''))

with wds.ShardWriter(pattern, maxsize=int(args.maxsize), maxcount=int(args.maxcount)) as sink:
    for i in indexes:
        image = readfile(image_files[keys[i]])
        text = readfile(text_files[keys[i]])

        ds_key = "%09d" % i

        sample = {
            "__key__": ds_key,
            image_key: image,
            caption_key: text
        }
        if args.json and keys[i] in json_keys:
            sample["json"] = json_dicts[keys[i]]
        sink.write(sample)
--------------------------------------------------------------------------------