├── .gitignore ├── LICENSE ├── README.md ├── data ├── w00000 │ ├── w00000000.jpg │ ├── w00000000.json │ ├── w00000000.txt │ ├── w00000001.jpg │ ├── w00000001.txt │ ├── w00000002.jpg │ ├── w00000002.txt │ ├── w00000003.jpg │ ├── w00000003.txt │ ├── w00000004.jpg │ └── w00000004.txt ├── w00001 │ ├── w00000005.jpg │ ├── w00000005.txt │ ├── w00000006.jpg │ ├── w00000006.txt │ ├── w00000007.jpg │ ├── w00000007.txt │ ├── w00000008.jpg │ └── w00000008.txt ├── w00002 │ ├── w00000009.jpg │ ├── w00000009.txt │ ├── w00000010.jpg │ ├── w00000010.txt │ ├── w00000011.jpg │ └── w00000011.txt └── w00003 │ ├── w00000013.jpg │ ├── w00000013.txt │ ├── w00000014.jpg │ ├── w00000014.txt │ ├── w00000015.jpg │ └── w00000015.txt ├── download_open_images.txt ├── general ├── cc12m.py ├── cc3m.py ├── filtered_yfcc100m.py ├── helper_scripts │ ├── wit_clip_class.py │ ├── wit_dtype.py │ ├── wit_image_downloader.py │ └── wit_url_downloader.py ├── openimages_labels.py ├── openimages_narrative.py ├── wit.py ├── wit_clip.py └── wit_old.py ├── setup.py └── utilities ├── clip_wit.py ├── dataset_sanitycheck.py ├── tokenizer_from_wds_or_text.py ├── wds_create_legacy.py ├── wds_create_shards.py ├── wds_from_tfrecords.py ├── wds_from_tfrecords_alternative.py ├── wds_pytorchread.py └── wds_read.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | witurls/ 3 | datasets 4 | general/helper_scripts/__pycache__ 5 | general/wit_old.py 6 | general/clip_wit.py 7 | general/wit.py 8 | general/wit_clip copy.py 9 | utilities/clip_wit.py 10 | general/wit_old.py 11 | general/wit_clip copy2.py 12 | tfrecords 13 | tfr 14 | tartest.py 15 | test.py 16 | skips 17 | shards 18 | output 19 | openimages 20 | openimages_old.py 21 | .DS_Store 22 | .gitignore 23 | build 24 | dist 25 | .ipynb_checkpoints 26 | dalle_datasets.egg-info 27 | captions_train.json 28 | openimages-train-000000.tar 29 | downsampled-open-images-v4 30 | downsampled-open-images-v4-9208d33aceb2ca3eb2beb70a192600c9c41efba1.torrent 31 | downsampled-open-images-v4.aria2 32 | wit_urls 33 | wds_create_shards_backup.py 34 | testfolder 35 | testfolder_backup 36 | dataset.tar.gz 37 | dataset_sanitycheck_backup.py 38 | incomplete_files.csv 39 | wit 40 | mytest.py -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 robvanvolt 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## DALLE-datasets 2 | This is a summary of easily available, high-quality datasets consisting of captioned image files for generalized DALLE-pytorch training (https://github.com/lucidrains/DALLE-pytorch). 3 | 4 | The scripts help you download and resize the files from the given sources. 5 | 6 | * general datasets 7 | * Conceptual Captions (CC3M and CC12M) 8 | * Wikipedia Image-Text (WIT) 9 | * Filtered YFCC100M 10 | * Open Images 11 | * specific datasets 12 | * None yet 13 | 14 | 15 | ## Helper scripts 16 | 17 | All helper scripts are now located in the utilities folder: 18 | * TFRecords to WebDataset converter 19 | * Image-Text-Folder to WebDataset converter 20 | * Dataset sanity check for image-text files 21 | * Example reader for WebDataset files 22 | 23 | 24 | ### Sanity check for downloaded datasets 25 | 26 | The following command looks for image-text pairs (.jpg / .png / .bmp) and returns a CSV table listing incomplete data. 27 | When you add the optional argument -DEL, the incomplete files are deleted. The Python script checks one folder and its first-level subdirectories. 28 | 29 | ```python utilities/dataset_sanitycheck.py --dataset_folder my-dataset-folder``` 30 | 31 | 32 | ## Pretrained models 33 | 34 | If you want to continue training on pretrained models or even upload your own DALL-E model, head over to https://github.com/robvanvolt/DALLE-models 35 | 36 | ## Credits 37 | 38 | Special thanks go to Romaine, who improved the download scripts and made the great WebDataset format more accessible with his continuous coding efforts! 🙏 39 | 40 | A lot of inspiration was taken from https://github.com/yashbonde/dall-e-baby - unfortunately, that repo is no longer updated. 41 | Also, the shard creator was inspired by https://github.com/tmbdev-archive/webdataset-examples/blob/master/makeshards.py. 42 | The custom tokenizer was inspired by afiaka87, who showed a simple way to generate custom tokenizers with youtokentome.
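### Reading the generated WebDataset shards

A minimal sketch of consuming the shards produced by the converters above with the `webdataset` package — the shard path below is only an example; see `utilities/wds_read.py` and `utilities/wds_pytorchread.py` for the actual readers:

```python
# Minimal sketch (assumes `pip install webdataset`); the shard name is an example path.
import webdataset as wds

dataset = (
    wds.WebDataset("shards/dataset-000000.tar")  # a .tar shard with paired .jpg/.txt entries
    .decode("pil")                               # decode image bytes into PIL images, .txt into str
    .to_tuple("jpg;png", "txt")                  # yield (image, caption) pairs
)

for image, caption in dataset:
    print(image.size, caption[:80])
    break
```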
43 | -------------------------------------------------------------------------------- /data/w00000/w00000000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robvanvolt/DALLE-datasets/bb983e3abe99d76a6fefbb173dcb640a6d8deb17/data/w00000/w00000000.jpg -------------------------------------------------------------------------------- /data/w00000/w00000000.json: -------------------------------------------------------------------------------- 1 | { 2 | "A": 12, 3 | "B": "Test" 4 | } -------------------------------------------------------------------------------- /data/w00000/w00000000.txt: -------------------------------------------------------------------------------- 1 | Galego: Logo do Movemento Galego ao Socialismo -------------------------------------------------------------------------------- /data/w00000/w00000001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robvanvolt/DALLE-datasets/bb983e3abe99d76a6fefbb173dcb640a6d8deb17/data/w00000/w00000001.jpg -------------------------------------------------------------------------------- /data/w00000/w00000001.txt: -------------------------------------------------------------------------------- 1 | Lesser bulldog bat (Noctilio albiventris) -------------------------------------------------------------------------------- /data/w00000/w00000002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robvanvolt/DALLE-datasets/bb983e3abe99d76a6fefbb173dcb640a6d8deb17/data/w00000/w00000002.jpg -------------------------------------------------------------------------------- /data/w00000/w00000002.txt: -------------------------------------------------------------------------------- 1 | Coin of Ukraine Русский: Юбилейная монета Украины -------------------------------------------------------------------------------- /data/w00000/w00000003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robvanvolt/DALLE-datasets/bb983e3abe99d76a6fefbb173dcb640a6d8deb17/data/w00000/w00000003.jpg -------------------------------------------------------------------------------- /data/w00000/w00000003.txt: -------------------------------------------------------------------------------- 1 | mendeleevo -------------------------------------------------------------------------------- /data/w00000/w00000004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robvanvolt/DALLE-datasets/bb983e3abe99d76a6fefbb173dcb640a6d8deb17/data/w00000/w00000004.jpg -------------------------------------------------------------------------------- /data/w00000/w00000004.txt: -------------------------------------------------------------------------------- 1 | Sehemu za Mji wa Brookline, Massachusetts 2 | Brookline MA August 2015 Photo Collage 2 -------------------------------------------------------------------------------- /data/w00001/w00000005.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robvanvolt/DALLE-datasets/bb983e3abe99d76a6fefbb173dcb640a6d8deb17/data/w00001/w00000005.jpg -------------------------------------------------------------------------------- /data/w00001/w00000005.txt: -------------------------------------------------------------------------------- 1 | Hay Street 
中文(繁體)‎: 禧街 -------------------------------------------------------------------------------- /data/w00001/w00000006.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robvanvolt/DALLE-datasets/bb983e3abe99d76a6fefbb173dcb640a6d8deb17/data/w00001/w00000006.jpg -------------------------------------------------------------------------------- /data/w00001/w00000006.txt: -------------------------------------------------------------------------------- 1 | Jayson Musson on October 29, 2007 2 | Jayson Scott Musson on October 29, 2007 -------------------------------------------------------------------------------- /data/w00001/w00000007.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robvanvolt/DALLE-datasets/bb983e3abe99d76a6fefbb173dcb640a6d8deb17/data/w00001/w00000007.jpg -------------------------------------------------------------------------------- /data/w00001/w00000007.txt: -------------------------------------------------------------------------------- 1 | Չիբո կապելլա 2 | Photo of the Cybo Chapel of Santa Maria del Popolo, Rome, Italy. -------------------------------------------------------------------------------- /data/w00001/w00000008.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robvanvolt/DALLE-datasets/bb983e3abe99d76a6fefbb173dcb640a6d8deb17/data/w00001/w00000008.jpg -------------------------------------------------------------------------------- /data/w00001/w00000008.txt: -------------------------------------------------------------------------------- 1 | Euodynerus megaera -------------------------------------------------------------------------------- /data/w00002/w00000009.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robvanvolt/DALLE-datasets/bb983e3abe99d76a6fefbb173dcb640a6d8deb17/data/w00002/w00000009.jpg -------------------------------------------------------------------------------- /data/w00002/w00000009.txt: -------------------------------------------------------------------------------- 1 | Simon Wolfe Rosendale (June 23, 1842 - April 22, 1937) was an American lawyer and politician. 
Rosendale was the first Jew elected to a statewide elective office in New York -------------------------------------------------------------------------------- /data/w00002/w00000010.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robvanvolt/DALLE-datasets/bb983e3abe99d76a6fefbb173dcb640a6d8deb17/data/w00002/w00000010.jpg -------------------------------------------------------------------------------- /data/w00002/w00000010.txt: -------------------------------------------------------------------------------- 1 | ქართული: არქიმანდრიტი ადამი (ერისკაცობაში ვახტანგ მიხეილის ძე ახალაძე) -------------------------------------------------------------------------------- /data/w00002/w00000011.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robvanvolt/DALLE-datasets/bb983e3abe99d76a6fefbb173dcb640a6d8deb17/data/w00002/w00000011.jpg -------------------------------------------------------------------------------- /data/w00002/w00000011.txt: -------------------------------------------------------------------------------- 1 | Photograph of Rainbow Springs in Marion County, Florida (2005). -------------------------------------------------------------------------------- /data/w00003/w00000013.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robvanvolt/DALLE-datasets/bb983e3abe99d76a6fefbb173dcb640a6d8deb17/data/w00003/w00000013.jpg -------------------------------------------------------------------------------- /data/w00003/w00000013.txt: -------------------------------------------------------------------------------- 1 | Die California Clipper in der Bucht von Manila, ca. 1940 2 | The Boeing 314 California Clipper (civil registration NC18602) off the Cavite Navy Yard, Philippines, 1939-1941. Delivered on 27 January 1939, it went to the USAAF as C-98 42-88632, 18 December 1941; then to the US Navy as BuNo 99084. It was sold to Universal Airways in 1946, and to American International in 1948. It was finally scrapped in 1950. -------------------------------------------------------------------------------- /data/w00003/w00000014.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robvanvolt/DALLE-datasets/bb983e3abe99d76a6fefbb173dcb640a6d8deb17/data/w00003/w00000014.jpg -------------------------------------------------------------------------------- /data/w00003/w00000014.txt: -------------------------------------------------------------------------------- 1 | Rozdělení kantonu Basilej v letech 1832-1833 2 | Deutsch: Karte zur Basler Kantonstrennung 1832/33 -------------------------------------------------------------------------------- /data/w00003/w00000015.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robvanvolt/DALLE-datasets/bb983e3abe99d76a6fefbb173dcb640a6d8deb17/data/w00003/w00000015.jpg -------------------------------------------------------------------------------- /data/w00003/w00000015.txt: -------------------------------------------------------------------------------- 1 | 展示在芝加哥菲爾德自然史博物館的兩頭食人獅 2 | The maneless male Lions of Tsavo. 
-------------------------------------------------------------------------------- /download_open_images.txt: -------------------------------------------------------------------------------- 1 | aria2c --bt-metadata-only=true --bt-save-metadata=true https://academictorrents.com/download/9208d33aceb2ca3eb2beb70a192600c9c41efba1.torrent; 2 | aria2c --show-files downsampled-open-images-v4-9208d33aceb2ca3eb2beb70a192600c9c41efba1.torrent; 3 | aria2c --select-file=9,11,15 downsampled-open-images-v4-9208d33aceb2ca3eb2beb70a192600c9c41efba1.torrent; 4 | 5 | # next step is to go to each folder and unzip the files 6 | echo "Gathering downsampled-open-images-v4"; 7 | rm downsampled-open-images-v4*; 8 | cd downsampled-open-images-v4/; 9 | rm -rf 512px/; 10 | cd 256px/; 11 | for i in test-256.tar.gz test_challenge_2018-256.tar.gz train-256.tar.gz validation-256.tar.gz 12 | do 13 | echo "Untarring: $i"; 14 | tar -xf $i; 15 | done 16 | -------------------------------------------------------------------------------- /general/cc12m.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import os 3 | import requests 4 | from pathlib import Path 5 | from PIL import Image 6 | from tqdm import tqdm 7 | from multiprocessing import Pool 8 | import gc 9 | import glob 10 | 11 | cc_url = 'https://storage.googleapis.com/conceptual_12m/cc12m.tsv' 12 | root_folder = './' 13 | total = 12423374 14 | maxwidth = 256 15 | maxheight = 256 16 | thread_count = 16 17 | batch = 10000 18 | 19 | def load_caption(x): 20 | name, caption, text_folder = x 21 | fid = str(int(int(name) / 10000 )) 22 | subdir = "0"*(5-len(fid)) + fid 23 | os.makedirs(Path(text_folder+"/"+subdir), exist_ok=True) 24 | fp = text_folder + '/' + subdir + "/" + "0"*(9-len(str(name))) + str(name) + '.txt' 25 | with open(fp, 'w') as f: 26 | f.write(caption) 27 | 28 | def download_file(url): 29 | response = requests.get(url, stream=True) 30 | total_size_in_bytes= int(response.headers.get('content-length', 0)) 31 | block_size = 1024 32 | progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) 33 | with open(Path(root_folder + '/cc12m.tsv'), 'wb') as file: 34 | for data in response.iter_content(block_size): 35 | progress_bar.update(len(data)) 36 | file.write(data) 37 | progress_bar.close() 38 | if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: 39 | print("Error, something went wrong...") 40 | 41 | def load_image(x): 42 | name, url, image_folder, skip_folder = x 43 | fid = str(int(int(name) / 10000 )) 44 | subdir = "0"*(5-len(fid)) + fid 45 | os.makedirs(Path(image_folder+"/"+subdir), exist_ok=True) 46 | id = subdir + "/" + "0"*(9-len(str(name))) + str(name) 47 | try: 48 | with Image.open(requests.get(url, 49 | headers={'User-Agent': 'Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:72.0) Gecko/20100101 Firefox/72.0'}, 50 | stream=True, timeout=3).raw) as foo: 51 | a = max(maxwidth/foo.size[0], maxheight/foo.size[1]) 52 | foo = foo.resize((int(foo.size[0] * a), int(foo.size[1] * a)), Image.ANTIALIAS) 53 | with open(Path(image_folder + "/" + id + '.jpg'), 'wb') as file: 54 | foo.save(file, optimize=True, quality=85) 55 | except Exception: 56 | os.makedirs(Path(skip_folder+"/"+subdir), exist_ok=True) 57 | open(Path(skip_folder + '/' + id), 'a').close()  # create an empty marker file so this id is skipped on the next run 58 | pass 59 | 60 | if __name__ == '__main__': 61 | if not os.path.isfile(Path(root_folder + '/cc12m.tsv')): 62 | print('Missing cc12m url-caption-dataset. 
Downloading...') 63 | download_file(cc_url) 64 | else: 65 | print('cc12m.tsv already downloaded. Proceeding with downloading images!') 66 | 67 | dfc = pd.read_csv(root_folder + "cc12m.tsv", sep='\t', names=["url", "caption"]) 68 | 69 | image_folder = root_folder + '/images' 70 | text_folder = root_folder + '/texts' 71 | skip_folder = root_folder + '/skip' 72 | 73 | paths = [image_folder, text_folder, skip_folder] 74 | 75 | for path in paths: 76 | os.makedirs(path, exist_ok=True) 77 | 78 | def list_ids(path): 79 | return [int(os.path.splitext(os.path.basename(a))[0]) for a in glob.glob(path+"/**/*")] 80 | 81 | skiplist = list_ids(text_folder) 82 | remaining = total - len(skiplist) 83 | percent_remaining = 100 * (total - remaining) / total 84 | df = dfc.loc[~dfc.index.isin(skiplist)] 85 | 86 | print('Remaining {} captions to be written - {} ({:.5f} %) already written.'.format(remaining, len(skiplist), percent_remaining)) 87 | 88 | if len(df) > 0: 89 | captions = zip(df.index, df["caption"], [text_folder]*len(df)) 90 | pool = Pool(thread_count) 91 | for _ in tqdm(pool.imap_unordered(load_caption, captions), total=len(df)): 92 | pass 93 | pool.close() 94 | print('Done with captions!') 95 | 96 | skiplist = list_ids(skip_folder) + list_ids(image_folder) 97 | remaining = total - len(skiplist) 98 | percent_remaining = 100 * (total - remaining) / total 99 | 100 | df = dfc.loc[~dfc.index.isin(skiplist)] 101 | print('Remaining {} images to be downloaded - {} ({:.5f} %) already downloaded.'.format(remaining, len(skiplist), percent_remaining)) 102 | images = list(zip(df.index, df["url"], [image_folder]*len(df), [skip_folder]*len(df))) 103 | 104 | for i in tqdm(range(0, len(df), batch)): 105 | pool = Pool(thread_count) 106 | for _ in tqdm(pool.imap_unordered(load_image, images[i:i+batch]), total=batch): 107 | pass 108 | pool.terminate() 109 | pool.join() 110 | del pool 111 | gc.collect() 112 | 113 | print('Finished downloading available images from conceptual images!') 114 | -------------------------------------------------------------------------------- /general/cc3m.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from pathlib import Path 3 | from PIL import Image 4 | from tqdm import tqdm 5 | import requests 6 | import os 7 | from pandarallel import pandarallel 8 | 9 | ### seperator = | 10 | 11 | ##### https://ai.google.com/research/ConceptualCaptions/download 12 | ##### download url-caption dataset from 13 | ##### https://storage.cloud.google.com/gcc-data/Train/GCC-training.tsv?_ga=2.191230122.-1896153081.1529438250 14 | 15 | DATASETFOLDER = 'content' 16 | DATASET = 'Train_GCC-training.tsv' 17 | FILEID = '1edNr-GEYz69RWcsSgskNzjtM--Qxepdz' 18 | URL = 'https://storage.cloud.google.com/gcc-data/Train/GCC-training.tsv?_ga=2.191230122.-1896153081.1529438250' 19 | 20 | ##### download location of image-caption pairs 21 | PARENTPATH = 'output' 22 | TEXTFOLDER = 'texts' 23 | IMAGEFOLDER = 'images' 24 | PREFIX = "" 25 | CHECKALLFOLDERS = True 26 | 27 | KEEPTHESECOLS = ['caption', 'url'] 28 | IMAGEFORMATS = ['jpg', 'jpeg', 'bmp', 'png'] 29 | MAXWIDTH = 320 30 | MAXHEIGHT = 320 31 | CHUNKS = 500000 32 | THREAD_COUNT = 16 33 | HIDE_ERRORS = False 34 | 35 | os.makedirs(Path(DATASETFOLDER), exist_ok=True) 36 | 37 | #### Helper scripts to download url-caption dataset 38 | def download_file_from_google_drive(id, destination): 39 | URL = "https://docs.google.com/uc?export=download" 40 | session = requests.Session() 41 | response = session.get(URL, 
params = { 'id' : id }, stream = True) 42 | token = get_confirm_token(response) 43 | if token: 44 | params = { 'id' : id, 'confirm' : token } 45 | response = session.get(URL, params = params, stream = True) 46 | 47 | save_response_content(response, destination) 48 | 49 | def download_file(url, root_folder): 50 | response = requests.get(url, stream=True) 51 | total_size_in_bytes= int(response.headers.get('content-length', 0)) 52 | block_size = 1024 53 | progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) 54 | with open(Path(root_folder + '/cc3.tsv'), 'wb') as file: 55 | for data in response.iter_content(block_size): 56 | progress_bar.update(len(data)) 57 | file.write(data) 58 | progress_bar.close() 59 | if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: 60 | print("Error, something went wrong...") 61 | 62 | def get_confirm_token(response): 63 | for key, value in response.cookies.items(): 64 | if key.startswith('download_warning'): 65 | return value 66 | return None 67 | 68 | def save_response_content(response, destination): 69 | CHUNK_SIZE = 32768 70 | with open(destination, "wb") as f: 71 | for chunk in response.iter_content(CHUNK_SIZE): 72 | if chunk: # filter out keep-alive new chunks 73 | f.write(chunk) 74 | 75 | if __name__ == '__main__': 76 | assert os.path.isfile(Path(DATASETFOLDER + '/' + DATASET)), print(''' 77 | ################################################################################################################# 78 | Missing cc3m url-caption-dataset. Automatic downloading not supported yet. 79 | Download https://storage.cloud.google.com/gcc-data/Train/GCC-training.tsv?_ga=2.191230122.-1896153081.1529438250 80 | And put it into following folder: {} 81 | ################################################################################################################# 82 | '''.format(DATASETFOLDER)) 83 | 84 | pandarallel.initialize(nb_workers=THREAD_COUNT) 85 | 86 | ### downloading dataset and resizsing images in parallel 87 | def write_files(x, folderpath): 88 | id = PREFIX + "0"*(8-len(str(x.name))) + str(x.name) 89 | try: 90 | foo = Image.open(requests.get(x.url, stream=True, timeout=4).raw) 91 | a = max(MAXWIDTH/foo.size[0], MAXHEIGHT/foo.size[1]) 92 | foo = foo.resize((int(foo.size[0] * a), int(foo.size[1] * a)), Image.ANTIALIAS) 93 | foo.save(Path(folderpath + '/' + id + '.jpg'), optimize=True, quality=85) 94 | except Exception as exc: 95 | if not HIDE_ERRORS: 96 | print('Failed downloading {} with url {}'.format(id, x.url)) 97 | print(exc) 98 | pass 99 | else: 100 | with open(Path(folderpath + '/' + id + '.txt'), 'w') as f: 101 | f.write(x.caption) 102 | 103 | os.makedirs(Path(PARENTPATH), exist_ok=True) 104 | 105 | keep_downloading = True 106 | if CHECKALLFOLDERS: 107 | batch = 0 108 | else: 109 | batch = len(os.listdir(Path(PARENTPATH))) - 1 110 | batch = 0 if batch == -1 else batch 111 | 112 | while keep_downloading: 113 | try: 114 | df = pd.read_csv(Path(DATASETFOLDER + '/' + DATASET), sep="\t", skiprows=range(0, batch * CHUNKS), nrows=CHUNKS, names=KEEPTHESECOLS) 115 | # df = pd.read_csv(Path(DATASETFOLDER + '/' + DATASET), sep="\t", skiprows=range(0, batch * CHUNKS), nrows=CHUNKS, names=KEEPTHESECOLS) 116 | df.index = [x + batch * CHUNKS for x in list(df.index)] 117 | folderid = str(PREFIX) + "0"*(4-len(str(batch))) + str(batch) 118 | folderpath = PARENTPATH + '/' + folderid 119 | os.makedirs(folderpath, exist_ok=True) 120 | skip = list(set([int(x[1:-4]) for x in os.listdir(folderpath)])) 121 | df = 
df[~df.index.isin(skip)] 122 | print('Saving {} images to {}.'.format(len(df), folderpath)) 123 | print('Skipping {} already downloaded urls.'.format(len(skip))) 124 | df.apply(lambda x: write_files(x, folderpath), axis=1) 125 | # df.parallel_apply(lambda x: write_files(x, folderpath), axis=1) 126 | except Exception as excp: 127 | print('An error occurred trying to download the filtered dataframe.') 128 | print(excp) 129 | keep_downloading = False 130 | pass 131 | else: 132 | if len(df) == 0: 133 | print('Alredy finished downloading images of batch {}!'.format(batch)) 134 | batch += 1 135 | 136 | print('Finished downloading dataset to {}.'.format(PARENTPATH)) 137 | -------------------------------------------------------------------------------- /general/filtered_yfcc100m.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from pathlib import Path 3 | from PIL import Image 4 | import requests 5 | import zipfile 6 | import os 7 | from pandarallel import pandarallel 8 | 9 | ##### url-caption dataset from https://github.com/christophschuhmann/4MC-4M-Image-Text-Pairs-with-CLIP-embeddings 10 | DATASETFOLDER = 'content' 11 | DATASETZIP = 'yfcc_filtered.zip' 12 | DATASET = 'yfcc_filtered.csv' 13 | FILEID = '1edNr-GEYz69RWcsSgskNzjtM--Qxepdz' 14 | 15 | ##### download location of image-caption pairs 16 | PARENTPATH = 'output' 17 | TEXTFOLDER = 'texts' 18 | IMAGEFOLDER = 'images' 19 | PREFIX = "F" 20 | CHECKALLFOLDERS = True 21 | 22 | KEEPTHESECOLS = ['final_caption', 'url'] 23 | IMAGEFORMATS = ['jpg', 'jpeg', 'bmp', 'png'] 24 | MAXWIDTH = 320 25 | MAXHEIGHT = 320 26 | CHUNKS = 100000 27 | 28 | os.makedirs(Path(DATASETFOLDER), exist_ok=True) 29 | 30 | #### Helper scripts to download url-caption dataset 31 | def download_file_from_google_drive(id, destination): 32 | URL = "https://docs.google.com/uc?export=download" 33 | session = requests.Session() 34 | response = session.get(URL, params = { 'id' : id }, stream = True) 35 | token = get_confirm_token(response) 36 | if token: 37 | params = { 'id' : id, 'confirm' : token } 38 | response = session.get(URL, params = params, stream = True) 39 | 40 | save_response_content(response, destination) 41 | 42 | def get_confirm_token(response): 43 | for key, value in response.cookies.items(): 44 | if key.startswith('download_warning'): 45 | return value 46 | return None 47 | 48 | def save_response_content(response, destination): 49 | CHUNK_SIZE = 32768 50 | with open(destination, "wb") as f: 51 | for chunk in response.iter_content(CHUNK_SIZE): 52 | if chunk: # filter out keep-alive new chunks 53 | f.write(chunk) 54 | 55 | if not os.path.isfile(Path(DATASETFOLDER + '/' + DATASET)): 56 | if not os.path.isfile(Path(DATASETFOLDER + '/' + DATASETZIP)): 57 | download_file_from_google_drive(FILEID, Path(DATASETFOLDER + '/' + DATASETZIP)) 58 | 59 | with zipfile.ZipFile(Path(DATASETFOLDER + '/' + DATASETZIP), 'r') as zip_ref: 60 | zipname = zip_ref.namelist()[0].split('/')[-1] 61 | 62 | with zipfile.ZipFile(Path(DATASETFOLDER + '/' + DATASETZIP), 'r') as zip_ref: 63 | zip_ref.extractall() 64 | os.rename(Path(DATASETFOLDER + '/' + zipname), Path(DATASETFOLDER + '/' + DATASET)) 65 | 66 | pandarallel.initialize() 67 | 68 | ### downloading dataset and resizsing images in parallel 69 | def write_files(x, folderpath): 70 | id = PREFIX + "0"*(8-len(str(x.name))) + str(x.name) 71 | try: 72 | foo = Image.open(requests.get(x.url, stream=True, timeout=4).raw) 73 | a = max(MAXWIDTH/foo.size[0], MAXHEIGHT/foo.size[1]) 74 | foo 
= foo.resize((int(foo.size[0] * a), int(foo.size[1] * a)), Image.ANTIALIAS) 75 | foo.save(Path(folderpath + '/' + id + '.jpg'), optimize=True, quality=85) 76 | except Exception as exc: 77 | print('Failed downloading {} with url {}'.format(id, x.url)) 78 | print(exc) 79 | pass 80 | else: 81 | with open(Path(folderpath + '/' + id + '.txt'), 'w') as f: 82 | f.write(x.final_caption) 83 | 84 | os.makedirs(Path(PARENTPATH), exist_ok=True) 85 | 86 | keep_downloading = True 87 | if CHECKALLFOLDERS: 88 | batch = 0 89 | else: 90 | batch = len(os.listdir(Path(PARENTPATH))) - 1 91 | batch = 0 if batch == -1 else batch 92 | 93 | while keep_downloading: 94 | try: 95 | df = pd.read_csv(Path(DATASETFOLDER + '/' + DATASET), sep="|", skiprows=range(1, batch * CHUNKS + 1), nrows=CHUNKS, header=0, usecols=KEEPTHESECOLS) 96 | df.index = [x + batch * CHUNKS for x in list(df.index)] 97 | folderid = PREFIX + "0"*(4-len(str(batch))) + str(batch) 98 | folderpath = PARENTPATH + '/' + folderid 99 | os.makedirs(folderpath, exist_ok=True) 100 | skip = list(set([int(x[1:-4]) for x in os.listdir(folderpath)])) 101 | df = df[~df.index.isin(skip)] 102 | print('Saving {} images to {}.'.format(len(df), folderpath)) 103 | print('Skipping {} already downloaded urls.'.format(len(skip))) 104 | df.parallel_apply(lambda x: write_files(x, folderpath), axis=1) 105 | except Exception as excp: 106 | print('An error occurred trying to download the filtered dataframe.') 107 | print(excp) 108 | keep_downloading = False 109 | pass 110 | else: 111 | if len(df) == 0: 112 | print('Alredy finished downloading images of batch {}!'.format(batch)) 113 | batch += 1 114 | 115 | print('Finished downloading dataset to {}.'.format(PARENTPATH)) 116 | -------------------------------------------------------------------------------- /general/helper_scripts/wit_clip_class.py: -------------------------------------------------------------------------------- 1 | import os 2 | import clip 3 | import torch 4 | from PIL import Image 5 | from multiprocessing import cpu_count 6 | from multiprocessing.queues import JoinableQueue 7 | from svglib.svglib import svg2rlg 8 | from reportlab.graphics import renderPM 9 | 10 | device = "cpu" # "cuda" if torch.cuda.is_available() else "cpu" 11 | use_jit = False # torch.cuda.is_available() 12 | 13 | class CLIP: 14 | def __init__(self): 15 | self.model, self.preprocess = clip.load("ViT-B/32", device=device, jit=use_jit) 16 | self.tokenizer = clip.tokenize 17 | 18 | def return_similarities(self, image, captions, image_url): 19 | if '.svg' in image_url: 20 | svgname = image_url.split('/')[-1] 21 | pngname = svgname[:-4] + '.png' 22 | with open(svgname, 'wb') as f: 23 | f.write(image.content) 24 | svg_image = svg2rlg(svgname) 25 | renderPM.drawToFile(svg_image, pngname, fmt="PNG") 26 | openedImage = Image.open(pngname) 27 | image_tokens = self.preprocess(openedImage).unsqueeze(0).to(device) 28 | os.remove(svgname) 29 | os.remove(pngname) 30 | else: 31 | openedImage = Image.open(image.raw) 32 | image_tokens = self.preprocess(openedImage).unsqueeze(0).to(device) 33 | openedImage.close() 34 | logits = [] 35 | for caption in captions: 36 | text_tokens = self.tokenizer(caption, context_length=77, truncate=True).to(device) 37 | with torch.no_grad(): 38 | logits_per_image, _ = self.model(image_tokens, text_tokens) 39 | logits.append(list(torch.flatten(logits_per_image))[0].item()) 40 | return logits, image_tokens -------------------------------------------------------------------------------- /general/helper_scripts/wit_dtype.py: 
-------------------------------------------------------------------------------- 1 | DTYPE = { 2 | 'language': str, 3 | 'page_url': str, 4 | 'image_url': str, 5 | 'page_title': str, 6 | 'section_title': str, 7 | 'hierarchical_section_title': str, 8 | 'caption_reference_description': str, 9 | 'caption_attribution_description': str, 10 | 'caption_alt_text_description': str, 11 | 'mime_type': str, 12 | 'original_height': int, 13 | 'original_width': int, 14 | 'is_main_image': bool, 15 | 'attribution_passes_lang_id': bool, 16 | 'page_changed_recently': str, 17 | 'context_page_description': str, 18 | 'context_section_description': str 19 | } 20 | 21 | DFLENGTH = { 22 | 'wit_v1.train.all-00004-of-00010.tsv.gz': 3701161, 23 | 'wit_v1.train.all-00001-of-00010.tsv.gz': 3702075, 24 | 'wit_v1.train.all-00005-of-00010.tsv.gz': 3708106, 25 | 'wit_v1.train.all-00006-of-00010.tsv.gz': 3704684, 26 | 'wit_v1.train.all-00002-of-00010.tsv.gz': 3701785, 27 | 'wit_v1.train.all-00007-of-00010.tsv.gz': 3703736, 28 | 'wit_v1.train.all-00008-of-00010.tsv.gz': 3705646, 29 | 'wit_v1.train.all-00000-of-00010.tsv.gz': 3708026, 30 | 'wit_v1.train.all-1percent_sample.tsv.gz': 370373, 31 | 'wit_v1.train.all-00003-of-00010.tsv.gz': 3706924 32 | } 33 | 34 | DFLENGTH_ENGLISH = { 35 | 'wit_v1.train.all-00004-of-00010.tsv.gz': 540463, 36 | 'wit_v1.train.all-00001-of-00010.tsv.gz': 542006, 37 | 'wit_v1.train.all-00005-of-00010.tsv.gz': 540982, 38 | 'wit_v1.train.all-00006-of-00010.tsv.gz': 540387, 39 | 'wit_v1.train.all-00002-of-00010.tsv.gz': 540499, 40 | 'wit_v1.train.all-00007-of-00010.tsv.gz': 541728, 41 | 'wit_v1.train.all-00008-of-00010.tsv.gz': 540557, 42 | 'wit_v1.train.all-00000-of-00010.tsv.gz': 542593, 43 | 'wit_v1.train.all-1percent_sample.tsv.gz': 54071, 44 | 'wit_v1.train.all-00003-of-00010.tsv.gz': 541391 45 | } -------------------------------------------------------------------------------- /general/helper_scripts/wit_image_downloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | from PIL import Image 4 | 5 | maxwidth = 256 6 | maxheight = 256 7 | 8 | def wit_download_image(url, saveimages=False): 9 | foo = requests.get( 10 | url, 11 | headers={'User-Agent': 'Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:72.0) Gecko/20100101 Firefox/72.0'}, 12 | stream=True, 13 | timeout=3) 14 | if saveimages: 15 | with Image.open(foo.raw) as fooimage: 16 | a = max(maxwidth/fooimage.size[0], maxheight/fooimage.size[1]) 17 | fooimage = fooimage.resize((int(fooimage.size[0] * a), int(fooimage.size[1] * a)), Image.ANTIALIAS) 18 | os.makedirs('./wit_images/', exist_ok=True) 19 | image_id = os.path.splitext(url.split('/')[-1])[0]  # derive a file name from the image URL 20 | with open(os.path.join('./wit_images/', image_id + '.jpg'), 'wb') as file: 21 | fooimage.save(file, optimize=True, quality=85) 22 | return foo -------------------------------------------------------------------------------- /general/helper_scripts/wit_url_downloader.py: -------------------------------------------------------------------------------- 1 | import urllib, os 2 | from tqdm import tqdm 3 | import urllib.request 4 | 5 | def download_wit_urls(urlfolder='../wit_urls', onepercentsample=True): 6 | links = ["https://storage.googleapis.com/gresearch/wit/wit_v1.train.all-0000{}-of-00010.tsv.gz".format(i) for i in range(9)] 7 | if onepercentsample: 8 | links = ["https://storage.googleapis.com/gresearch/wit/wit_v1.train.all-1percent_sample.tsv.gz"] 9 | filenames = [link.split('/')[-1] for link in links] 10 | os.makedirs(urlfolder, exist_ok=True) 11 | 12 | class TqdmUpTo(tqdm): 13 | def update_to(self, b=1, bsize=1, tsize=None): 14 | if 
tsize is not None: 15 | self.total = tsize 16 | return self.update(b * bsize - self.n) 17 | 18 | for witurl, filename in zip(links, filenames): 19 | filepath = os.path.join(urlfolder, filename) 20 | if not os.path.exists(filepath): 21 | with TqdmUpTo(unit='B', unit_scale=True, unit_divisor=1024, miniters=1, 22 | desc=witurl.split('/')[-1]) as t: # all optional kwargs 23 | urllib.request.urlretrieve(witurl, filename=filepath, 24 | reporthook=t.update_to, data=None) 25 | t.total = t.n 26 | else: 27 | print('{} already downloaded.'.format(filename)) -------------------------------------------------------------------------------- /general/openimages_labels.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import re 4 | import h5py 5 | import json 6 | from tqdm import trange 7 | import numpy as np 8 | import pandas as pd 9 | from tabulate import tabulate 10 | 11 | ############################################################################################# 12 | ###### ATTENTION ############################################################################ 13 | ###### You need to download class-descriptions-boxable.csv from the following website ####### 14 | ###### https://storage.googleapis.com/openimages/v5/class-descriptions-boxable.csv ########## 15 | ############################################################################################# 16 | 17 | def get_open_images_label_names(): 18 | with open("./downsampled-open-images-v4/class-descriptions-boxable.csv", "r") as f: 19 | open_image_labels = {x.split(",")[0]: x.split(",")[1] for x in f.read().split("\n") if len(x)} 20 | return open_image_labels 21 | 22 | def get_open_images_labels(annotations_path): 23 | open_image_labels = get_open_images_label_names() 24 | df = pd.read_csv(annotations_path) 25 | image_to_labels = {} 26 | dropped = [] 27 | pbar = trange(len(df.ImageID.unique())) 28 | path_f = "./downsampled-open-images-v4/256px/" 29 | if "validation" in annotations_path: 30 | path_f += "validation/" 31 | elif "train" in annotations_path: 32 | path_f += "train-256/" 33 | elif "test" in annotations_path: 34 | path_f += "test/" 35 | for _, (img_id, df_sub) in zip(pbar, df.groupby("ImageID")): 36 | path = f"{path_f}{img_id}.jpg" 37 | pbar.set_description(f"Loading {path[::-1][:40][::-1]}") 38 | high_conf = df_sub[df_sub.Confidence == 1].LabelName.values.tolist() 39 | low_conf = df_sub[df_sub.Confidence != 1].LabelName.values.tolist() 40 | if not high_conf or not os.path.exists(path): 41 | dropped.append(img_id) 42 | image_to_labels["open_images_" + img_id] = { 43 | "label": [ 44 | [open_image_labels[x] for x in high_conf], 45 | [open_image_labels[x] for x in low_conf] 46 | ], 47 | "path": path 48 | } 49 | return image_to_labels, dropped 50 | 51 | # ---- Captions are generated using CaptionsGenerator 52 | 53 | class CaptionGenerator(): 54 | templates_labels = [ 55 | "a picture of {}", 56 | "a photo that has {}", 57 | "photo consisting of {}", 58 | "a low resolution photo of {}", 59 | "small photo of {}", 60 | "high resolution picture of {}", 61 | "low resolution picture of {}", 62 | "high res photo that has {}", 63 | "low res photo of {}", 64 | "{} in a photo", 65 | "{} in a picture", 66 | "rendered picture of {}", 67 | "jpeg photo of {}", 68 | "a cool photo of {}", 69 | "{} rendered in a picture", 70 | ] 71 | 72 | templates_maybe = [ 73 | *[x + " and maybe containing {}" for x in templates_labels], 74 | *[x + " and possibly containing {}" for x in templates_labels], 75 | *[x + " and 
{} but not sure" for x in templates_labels], 76 | *[x + " also roughly {}" for x in templates_labels], 77 | ] 78 | 79 | captions_templates = { 80 | "open_images": [templates_labels, templates_maybe], 81 | } 82 | 83 | def __init__(self): 84 | self.ds_names = list(self.captions_templates.keys()) 85 | 86 | def generate_open_images_caption(self, ds): 87 | temps_high, temps_low = self.captions_templates["open_images"] 88 | captions = {} 89 | for i,k in enumerate(ds): 90 | high_conf = ", ".join(ds[k]["label"][0]) 91 | if np.random.random() > 0.5: 92 | low_conf = ", ".join(ds[k]["label"][1]) 93 | temp = np.random.choice(temps_low, size=1)[0] 94 | cap = temp.format(high_conf, low_conf) 95 | else: 96 | temp = np.random.choice(temps_high, size = 1)[0] 97 | cap = temp.format(high_conf) 98 | cap = re.sub(r"\s+", " ", cap).strip().lower() 99 | captions["open_images_" + str(k)] = { 100 | "path": ds[k]["path"], 101 | "caption": cap 102 | } 103 | return captions 104 | 105 | def generate_captions(self, ds, ds_name): 106 | print("Generating captions for", ds_name) 107 | if ds_name not in self.ds_names: 108 | raise ValueError(f"{ds_name} not in {self.ds_names}") 109 | 110 | if ds_name == "open_images": 111 | return self.generate_open_images_caption(ds) 112 | 113 | temps = [] 114 | for temp in self.captions_templates[ds_name]: 115 | temps.extend(temp) 116 | 117 | # each ds: {: {"path": , "label": []}} 118 | captions = {} 119 | temps_ordered = np.random.randint(low = 0, high = len(temps), size = (len(ds))) 120 | for i,k in enumerate(ds): 121 | lbs_string = ", ".join(ds[k]["label"]) 122 | cap = temps[temps_ordered[i]].format(lbs_string) 123 | cap = re.sub(r"\s+", " ", cap).strip().lower() 124 | captions[ds_name + "_" + str(k)] = { 125 | "path": ds[k]["path"], 126 | "caption": cap 127 | } 128 | return captions 129 | 130 | 131 | # ---- Script 132 | if __name__ == "__main__": 133 | print("-"*70 + "\n:: Loading OpenImages Dataset") 134 | open_images_img2lab_val, oi_dropped_val = get_open_images_labels( 135 | "./downsampled-open-images-v4/validation-annotations-human-imagelabels-boxable.csv" 136 | ) 137 | open_images_img2lab_train, oi_dropped_train = get_open_images_labels( 138 | "./downsampled-open-images-v4/train-annotations-human-imagelabels-boxable.csv" 139 | ) 140 | open_images_img2lab_test, oi_dropped_test = get_open_images_labels( 141 | "./downsampled-open-images-v4/test-annotations-human-imagelabels-boxable.csv" 142 | ) 143 | 144 | # define table for tabulate 145 | headers = ["name", "num_samples", "dropped"] 146 | table = [ 147 | ["open images (train)", len(open_images_img2lab_train), len(oi_dropped_train)], 148 | ["open images (val)", len(open_images_img2lab_val), len(oi_dropped_val)], 149 | ["open images (test)", len(open_images_img2lab_test), len(oi_dropped_test)], 150 | ] 151 | table_arr = np.asarray(table) 152 | total_samples = sum([ 153 | len(open_images_img2lab_train), 154 | len(open_images_img2lab_val), 155 | len(open_images_img2lab_test), 156 | ]) 157 | total_dropped = sum([ 158 | len(oi_dropped_train), 159 | len(oi_dropped_val), 160 | len(oi_dropped_test), 161 | ]) 162 | table.append(["total", total_samples, total_dropped]) 163 | print("\n", "-"*70, "\n") 164 | print(tabulate(table, headers, tablefmt="psql")) 165 | 166 | print("\n:: Generating captions for labels") 167 | 168 | capgen = CaptionGenerator() 169 | capgen_oi_val = capgen.generate_captions(open_images_img2lab_val, "open_images") 170 | capgen_oi_train = capgen.generate_captions(open_images_img2lab_train, "open_images") 171 | 
capgen_oi_test = capgen.generate_captions(open_images_img2lab_test, "open_images") 172 | 173 | # make the master captions list 174 | common_captions = {} 175 | common_captions.update(capgen_oi_val) 176 | common_captions.update(capgen_oi_train) 177 | common_captions.update(capgen_oi_test) 178 | 179 | print(len(common_captions), table[-1][1]) 180 | with open("captions_train.json", "w") as f: 181 | f.write(json.dumps(common_captions)) -------------------------------------------------------------------------------- /general/openimages_narrative.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import json 3 | import numpy as np 4 | import shutil 5 | from pathlib import Path 6 | import os 7 | from pandarallel import pandarallel 8 | 9 | ############################################################################################# 10 | ##### ATTENTION: run this file after openimages_labels.py !!! ############################### 11 | ##### you need to have captions_tran.json generated from the openimages_labels.py !!! ####### 12 | ############################################################################################# 13 | 14 | # Settings 15 | CHUNKS = 500000 16 | OUTPUTFOLDER = 'openimages' 17 | 18 | ################################################π############################################# 19 | ###### ATTENTION ############################################################################ 20 | ###### You need to download the following 3 files from the conceptual captions website ###### 21 | ###### https://google.github.io/localized-narratives/ ####################################### 22 | ############################################################################################# 23 | 24 | os.makedirs(OUTPUTFOLDER, exist_ok=True) 25 | 26 | dicts = [] 27 | 28 | for d in open('downsampled-open-images-v4/open_images_validation_captions.jsonl'): 29 | dicts.append(json.loads(d)) 30 | 31 | for d in open('downsampled-open-images-v4/open_images_test_captions.jsonl'): 32 | dicts.append(json.loads(d)) 33 | 34 | for d in open('downsampled-open-images-v4/open_images_train_v6_captions.jsonl'): 35 | dicts.append(json.loads(d)) 36 | 37 | narrative = pd.DataFrame.from_dict(dicts) 38 | narrative.index = narrative['image_id'] 39 | narrative = narrative.rename({'caption': 'narrative'}) 40 | narrative.columns = [x.replace('caption', 'narrative') for x in narrative.columns] 41 | 42 | print('Found {} narrative image-text pairs.'.format(len(narrative))) 43 | 44 | with open('captions_train.json') as json_file: 45 | data = json.load(json_file) 46 | 47 | df = pd.DataFrame.from_dict(data, orient='index') 48 | df.index = [x.split('_')[-1] for x in list(df.index)] 49 | 50 | narrative_ids = list(narrative.index) 51 | 52 | df = df.join(narrative) 53 | 54 | df['caption'] = np.where(~df['narrative'].isna(),df['narrative'],df['caption']) 55 | 56 | def save_files(x, folder_id): 57 | shutil.copyfile(x.path, Path('./' + OUTPUTFOLDER + '/' + folder_id + '/' + x.path.split('/')[-1])) 58 | with open(Path('./' + OUTPUTFOLDER + '/' + folder_id + '/' + x.name + '.txt'), 'w') as f: 59 | f.write(x.caption) 60 | 61 | maxiter = int(len(df) / CHUNKS) + 1 62 | 63 | pandarallel.initialize() 64 | 65 | for batch in range(maxiter): 66 | print('Copying image-text-pairs into {}/{} folders.'.format(batch + 1, maxiter)) 67 | folder_id = (4 - len(str(batch)))*'0' + str(batch) 68 | sdf = df[batch*CHUNKS:(batch+1)*CHUNKS] 69 | os.makedirs(Path(OUTPUTFOLDER, folder_id), exist_ok=True) 70 
| sdf.parallel_apply(lambda x: save_files(x, folder_id), axis=1) 71 | 72 | print('Done copying image-text-pairs to {} folders in outputfolder {}!'.format(maxiter, OUTPUTFOLDER)) -------------------------------------------------------------------------------- /general/wit.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from pathlib import Path 4 | from pandarallel import pandarallel 5 | import os 6 | import requests 7 | from io import BytesIO 8 | from PIL import Image 9 | from cairosvg import svg2png 10 | 11 | pandarallel.initialize() 12 | 13 | ### download urls from here https://github.com/google-research-datasets/wit/blob/main/DATA.md 14 | # FILENAME = 'wit_url_captions/wit_v1.train.all-00000-of-00010.tsv.gz' 15 | URL_FOLDER = 'wit_urls' 16 | URLS = sorted(os.listdir(URL_FOLDER)) 17 | # FILENAME = 'wit_v1.train.all-1percent_sample.tsv.gz' 18 | # FILENAME = 'wit_v1.train.all-1percent_sample1.tsv.gz' 19 | CHUNKS = 50000 20 | DATAPARENTFOLDER = 'wit' 21 | MAXWIDTH = 320 22 | MAXHEIGHT = 320 23 | 24 | LANGUAGEFILTER = True 25 | LANGUAGES = ['en'] 26 | 27 | ##### For a future version with auomatic tsv.gz folder reader 28 | # os.listdir(Path('wit_url_captions')) 29 | # MAINFOLDERNAME = 'wit_url_captions' 30 | # FILENAME = 'wit1percent.tsv.gz' 31 | # FILENAMES = os.listdir(Path(MAINFOLDERNAME)) 32 | 33 | def eng_cap(x): 34 | if 'English:' in x: 35 | eng = [y.replace('English:', '').strip() for y in x.split('\n') if 'English:' in y] 36 | return '\n'.join(eng) 37 | else: 38 | return x 39 | 40 | def write_files(x, datafolderpath): 41 | id = "%09d" % x.name 42 | try: 43 | if '.svg' in x.image_url.lower(): 44 | output = svg2png(url=x.image_url) 45 | foo_t = Image.open(BytesIO(output)).convert("RGBA") 46 | foo = Image.new("RGBA", foo_t.size, "WHITE") 47 | foo.paste(foo_t, (0, 0), foo_t) 48 | a = max(MAXWIDTH/foo.size[0], MAXHEIGHT/foo.size[1]) 49 | foo = foo.resize((int(foo.size[0] * a), int(foo.size[1] * a)), Image.ANTIALIAS).convert('RGB') 50 | foo.save(Path(datafolderpath + "/" + id + '.jpg'), optimize=True, quality=85) 51 | else: 52 | with Image.open(requests.get(x.image_url, stream=True, timeout=4).raw) as foo: 53 | if foo.mode == "RGBA": 54 | foo = Image.new("RGBA", foo.size, "WHITE") 55 | foo.paste(foo, (0, 0), foo) 56 | a = max(MAXWIDTH/foo.size[0], MAXHEIGHT/foo.size[1]) 57 | foo = foo.resize((int(foo.size[0] * a), int(foo.size[1] * a)), Image.ANTIALIAS).convert('RGB') 58 | foo.save(Path(datafolderpath + "/" + id + '.jpg'), optimize=True, quality=85) 59 | except Exception as e: 60 | print(e) 61 | print(x.image_url) 62 | pass 63 | else: 64 | with open(Path(datafolderpath + "/" + id + '.txt'), 'w') as f: 65 | f.write(x.caption) 66 | 67 | def get_df(batch): 68 | df = pd.read_csv( 69 | Path(URL_FOLDER + '/' + FILENAME), 70 | compression='gzip', 71 | header=0, 72 | sep='\t', 73 | skiprows=range(1, batch * CHUNKS), 74 | nrows=CHUNKS, 75 | quotechar='"', 76 | dtype={ 77 | 'language': str, 78 | 'page_url': str, 79 | 'image_url': str, 80 | 'page_title': str, 81 | 'section_title': str, 82 | 'hierarchical_section_title': str, 83 | 'caption_reference_description': str, 84 | 'caption_attribution_description': str, 85 | 'caption_alt_text_description': str, 86 | 'mime_type': str, 87 | 'original_height': int, 88 | 'original_width': int, 89 | 'is_main_image': bool, 90 | 'attribution_passes_lang_id': bool, 91 | 'page_changed_recently': str, 92 | 'context_page_description': str, 93 | 'context_section_description': str 94 | 
}, 95 | # sep='\t', 96 | error_bad_lines=False) 97 | 98 | df.index = [x + batch * CHUNKS for x in list(df.index)] 99 | 100 | if LANGUAGEFILTER: 101 | df = df[df['language'].isin(LANGUAGES)] 102 | 103 | print('Preprocessing captions...') 104 | df = df.replace(np.nan, '') 105 | df['caption'] = df['caption_reference_description'] + '\n' + df['caption_attribution_description'] + '\n' + df['caption_alt_text_description'] + '\n' + df['hierarchical_section_title'] 106 | df['caption'] = '\n' + df['caption_reference_description'] + '\n' + df['caption_attribution_description'] + '\n' + df['caption_alt_text_description'] + '\n' 107 | df['caption'] = df['caption'].str.replace(r'^\.+\s+', '', regex=True) 108 | df['caption'] = df['caption'].str.replace(r'\.\.+', '.', regex=True) 109 | df['caption'] = df['caption'].str.replace(r'\s\s+', ' ', regex=True) 110 | df['caption'] = df['caption'].str.replace(r'\s+\.+', '', regex=True) 111 | df['caption'] = df['caption'].str.replace(r'&', 'and', regex=True) 112 | df['caption'] = df['caption'].str.strip() 113 | df['caption'] = df['caption'].parallel_apply(lambda x: eng_cap(x)) 114 | df['index'] = df.index 115 | return df 116 | 117 | print('Found {} files containing Image-URLs.'.format(len(URLS))) 118 | 119 | for i, FILENAME in enumerate(URLS): 120 | DATAFOLDER = DATAPARENTFOLDER + '/' + FILENAME 121 | os.makedirs(DATAFOLDER, exist_ok=True) 122 | 123 | print('{} - Starting downloading image-text-pairs from {} to {}'.format(i + 1, FILENAME, DATAFOLDER)) 124 | 125 | batch = 0 126 | remaining_df_length = 1 127 | df = get_df(batch) 128 | 129 | while remaining_df_length > 0: 130 | print('Reading url-caption file...') 131 | 132 | totallength = len(df) 133 | foldername = "%05d" % batch 134 | 135 | pathstring = DATAFOLDER + '/' + foldername 136 | path = Path(DATAFOLDER + '/' + foldername) 137 | text_files = [*path.glob('**/*.txt')] 138 | text_files = {text_file.stem: text_file for text_file in text_files} # str(text_file.parents[0]) + 139 | text_total = len(text_files) 140 | 141 | image_files = [ 142 | *path.glob('**/*.png'), *path.glob('**/*.jpg'), 143 | *path.glob('**/*.jpeg'), *path.glob('**/*.bmp') 144 | ] 145 | image_files = {image_file.stem: image_file for image_file in image_files} # str(image_file.parents[0]) + 146 | image_total = len(image_files) 147 | 148 | print('Found {:,} textfiles and {:,} images already downloaded for batch {}.'.format(text_total, image_total, batch)) 149 | 150 | keys = (image_files.keys() & text_files.keys()) 151 | 152 | 153 | os.makedirs(path, exist_ok=True) 154 | print('Downloading texts and images into {}.'.format(pathstring)) 155 | filteredsplitdf = df[~df.index.isin(keys)] 156 | # filteredsplitdf = filteredsplitdf[~filteredsplitdf.index.isin(skipfilenumbers)] 157 | dflength = len(filteredsplitdf) 158 | 159 | print('Total length batch {}: {:,}'.format(batch, totallength)) 160 | print('Remaining batch length: {:,}'.format(dflength)) 161 | 162 | if dflength > 0: 163 | filteredsplitdf.parallel_apply(lambda x: write_files(x, pathstring), axis=1) 164 | 165 | batch += 1 166 | df = get_df(batch) 167 | remaining_df_length = len(df) 168 | 169 | print('Finished downloading WIT.') 170 | 171 | # text_files = [*path.glob('**/*.txt')] 172 | # text_files = {text_file.stem: text_file for text_file in text_files} # str(text_file.parents[0]) + 173 | # text_total = len(text_files) 174 | 175 | # image_files = [ 176 | # *path.glob('**/*.png'), *path.glob('**/*.jpg'), 177 | # *path.glob('**/*.jpeg'), *path.glob('**/*.bmp') 178 | # ] 179 | # image_files = 
{image_file.stem: image_file for image_file in image_files} # str(image_file.parents[0]) + 180 | # image_total = len(image_files) 181 | 182 | # print('Found {:,} textfiles and {:,} images already downloaded.'.format(text_total, image_total)) 183 | 184 | # print('Finished downloading {:,} images and {:,} texts.'.format(image_total, text_total)) 185 | -------------------------------------------------------------------------------- /general/wit_clip.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import time 4 | import pickle 5 | from tqdm import tqdm 6 | import pandas as pd 7 | from multiprocessing import cpu_count #, get_context 8 | from helper_scripts.wit_url_downloader import download_wit_urls 9 | from helper_scripts.wit_clip_class import CLIP 10 | from helper_scripts.wit_dtype import DTYPE, DFLENGTH, DFLENGTH_ENGLISH 11 | from helper_scripts.wit_image_downloader import wit_download_image 12 | from concurrent.futures import ThreadPoolExecutor 13 | 14 | os.environ['KMP_DUPLICATE_LIB_OK']='True' 15 | 16 | ONLYENGLISH = True 17 | MULTIPROCESSING = True 18 | THREAD_COUNT = 2*cpu_count()+1 19 | CHUNKSIZE = 10000 20 | EMBEDDINGS_PER_PICKLE = 5000 21 | SIMILARITIESFOLDER = './wit/witsimilarities' 22 | EMBEDDINGSFOLDER = './wit/witembeddings' 23 | WITURLFOLDER = './wit/witurls' 24 | 25 | parser = argparse.ArgumentParser() 26 | 27 | parser.add_argument('--wit_url_folder', type=str, 28 | help='Download location for WIT urls.') 29 | 30 | parser.add_argument('--onepercentsample', 31 | dest='onepercentsample', 32 | action='store_true', 33 | help='Only download 1% sample file.') 34 | 35 | parser.add_argument('--saveimages', 36 | dest='saveimages', 37 | action='store_true', 38 | help='Save the images on the local drive.') 39 | 40 | parser.add_argument('--saveembeddings', 41 | dest='saveembeddings', 42 | action='store_true', 43 | help='Save the image embeddings on the local drive.') 44 | 45 | parser.add_argument('--savewds', 46 | dest='savewds', 47 | action='store_true', 48 | help='Save the images and best matching caption as WebDataset') 49 | 50 | args = parser.parse_args() 51 | 52 | wit_url_folder = args.wit_url_folder if args.wit_url_folder else WITURLFOLDER 53 | 54 | clipper = CLIP() 55 | 56 | os.makedirs(SIMILARITIESFOLDER, exist_ok=True) 57 | if args.saveembeddings: 58 | os.makedirs(EMBEDDINGSFOLDER, exist_ok=True) 59 | 60 | dtv = list(DTYPE.keys()) 61 | caption_dict = {0:dtv[4], 1:dtv[5], 2:dtv[6], 3:dtv[7], 4:dtv[8], 5:dtv[15], 6:dtv[16]} 62 | 63 | def task_done(future): 64 | try: 65 | result = future.result() 66 | except: 67 | return False 68 | else: 69 | return result 70 | 71 | def process_row(row): 72 | saveembeddings = row[18] 73 | saveimages = row[19] 74 | image_url = row[3] 75 | captions = [ 76 | row[5], # row.page_title, 77 | row[6], # row.section_title, 78 | row[7], # row.hierarchical_section_title, 79 | row[8], # row.caption_attribution_description, 80 | row[9], # row.caption_alt_text_description, 81 | row[16], # row.context_page_description, 82 | row[17] # row.context_section_description 83 | ] 84 | available_captions = [True if isinstance(x, str) else False for x in captions] 85 | caption_tuples = [(i, x) for i, x in enumerate(captions) if available_captions[i]] 86 | available_ids, captions = list(zip(*caption_tuples)) 87 | 88 | try: 89 | image_request = wit_download_image(image_url, saveimages) 90 | similarities, embeddings = clipper.return_similarities(image_request, captions, image_url) 91 | similarities = 
{caption_dict[j]: round(similarities[i], 4) for i, j in enumerate(available_ids) } 92 | except Exception as e: 93 | print('Exception while trying to download {}'.format(image_url)) 94 | print(e) 95 | return False, False, False 96 | else: 97 | if not saveembeddings: 98 | embeddings = None 99 | return row[0], similarities, embeddings 100 | 101 | if __name__ == '__main__': 102 | start = time.time() 103 | global_counter = 0 104 | download_wit_urls(urlfolder=wit_url_folder, onepercentsample=args.onepercentsample) 105 | 106 | fns = sorted([x for x in os.listdir(wit_url_folder) if x[0] != '.' and '.tsv.gz' in x]) 107 | if not args.onepercentsample: 108 | fns = [x for x in fns if '1percent' not in x] 109 | 110 | for i, wit_filename in enumerate(fns): 111 | print('Processing {}. file: {}...'.format(i+1, wit_filename)) 112 | if ONLYENGLISH: 113 | dflen = DFLENGTH_ENGLISH[wit_filename] 114 | else: 115 | dflen = DFLENGTH[wit_filename] 116 | pbar = tqdm(total=dflen) 117 | similarities_dict = {} 118 | embeddings_dict_counter = 0 119 | if args.saveembeddings: 120 | embeddings_dict = {} 121 | if '1percent' in wit_filename: 122 | prefix = "onepercent" 123 | else: 124 | prefix = 'main' + (wit_filename[-17]) 125 | with pd.read_csv( 126 | os.path.join(wit_url_folder, wit_filename), 127 | sep="\t", 128 | compression="gzip", 129 | chunksize=CHUNKSIZE, 130 | quotechar='"', 131 | dtype=DTYPE, 132 | error_bad_lines=False 133 | ) as reader: 134 | for i, df in enumerate(reader): 135 | if ONLYENGLISH: 136 | df = df[df['language'] == 'en'] 137 | # dflen = dflen - i*CHUNKSIZE 138 | df['saveembeddings'] = args.saveembeddings 139 | df['saveimages'] = args.saveimages 140 | embeddings_dict = {} 141 | results = [] 142 | 143 | if MULTIPROCESSING: 144 | with ThreadPoolExecutor() as executor: 145 | for res in executor.map(process_row, df.itertuples(name=None)): 146 | results.append(res) 147 | pbar.update() 148 | else: 149 | for row in tqdm(df.itertuples(name=None), total=dflen): 150 | result = process_row(row) 151 | results.append(result) 152 | pbar.update() 153 | 154 | for result in results: 155 | if result[0] != False: 156 | index, sim, emb = result 157 | similarities_dict[index] = sim 158 | if args.saveembeddings: 159 | embeddings_dict[index] = emb 160 | if len(embeddings_dict.keys()) >= EMBEDDINGS_PER_PICKLE: 161 | with open(os.path.join( 162 | EMBEDDINGSFOLDER, 163 | '{}_{:05d}_image_embeddings.pkl'.format(prefix, embeddings_dict_counter) 164 | ), 'wb') as f: 165 | pickle.dump(embeddings_dict, f) 166 | embeddings_dict_counter += 1 167 | embeddings_dict = {} 168 | 169 | if len(embeddings_dict) > 0: 170 | with open(os.path.join( 171 | EMBEDDINGSFOLDER, 172 | '{}_{:05d}_image_embeddings.pkl'.format(prefix, embeddings_dict_counter) 173 | ), 'wb') as f: 174 | pickle.dump(embeddings_dict, f) 175 | embeddings_dict_counter += 1 176 | 177 | similarity_df = pd.DataFrame.from_dict(similarities_dict, orient='index') 178 | similarity_df.index.name = 'index' 179 | similarity_df.index = similarity_df.index.astype(int) 180 | similarity_df = similarity_df.sort_index() 181 | similarity_df.to_csv( 182 | os.path.join( 183 | SIMILARITIESFOLDER, 184 | wit_filename.replace('.tsv.gz', '') + '_with_similarities_{:05d}'.format(i) + '.tsv' 185 | ), sep="\t") 186 | 187 | global_counter += DFLENGTH_ENGLISH[wit_filename] if ONLYENGLISH else DFLENGTH[wit_filename] 188 | pbar.close() 189 | 190 | end = time.time() 191 | elapsed = end - start 192 | print('Finished processing {} WIT-rows in {:.2f} hours!'.format(global_counter, elapsed/(60*60))) 
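# ----------------------------------------------------------------------------------------
# Example invocation (all flags are defined above):
#   python wit_clip.py --onepercentsample --saveimages --saveembeddings
#
# Illustrative sketch, not called anywhere in this script: one possible way to reduce a
# similarity TSV written to SIMILARITIESFOLDER to its best-scoring caption column per row.
# The column names come from caption_dict above; the 0.2 threshold is only an assumption,
# not a value used elsewhere in this repository. Relies on the pandas import at the top.
def best_caption_per_row(similarity_tsv, threshold=0.2):
    """Return rows whose best caption column clears the (assumed) similarity threshold."""
    sim = pd.read_csv(similarity_tsv, sep='\t', index_col='index')
    best = pd.DataFrame({
        'best_caption_column': sim.idxmax(axis=1),  # caption column with the highest score
        'best_similarity': sim.max(axis=1),
    })
    return best[best['best_similarity'] >= threshold]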
-------------------------------------------------------------------------------- /general/wit_old.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from pathlib import Path 4 | from pandarallel import pandarallel 5 | import os 6 | import requests 7 | from PIL import Image 8 | 9 | pandarallel.initialize() 10 | 11 | ### download urls from here https://github.com/google-research-datasets/wit/blob/main/DATA.md 12 | # FILENAME = 'wit_url_captions/wit_v1.train.all-00000-of-00010.tsv.gz' 13 | FILENAME = 'wit_v1.train.all-1percent_sample.tsv.gz' 14 | GROUP = FILENAME[38] 15 | CHUNKS = 50000 16 | TEXTFOLDER = 'texts' 17 | IMAGEFOLDER = 'images' 18 | SKIPFOLDER = 'skips' 19 | IMAGEFORMATS = ['jpg', 'jpeg', 'bmp', 'png'] 20 | MAXWIDTH = 320 21 | MAXHEIGHT = 320 22 | 23 | LANGUAGEFILTER = True 24 | LANGUAGES = ['en'] 25 | 26 | ##### For a future version with auomatic tsv.gz folder reader 27 | # os.listdir(Path('wit_url_captions')) 28 | # MAINFOLDERNAME = 'wit_url_captions' 29 | # FILENAME = 'wit1percent.tsv.gz' 30 | # FILENAMES = os.listdir(Path(MAINFOLDERNAME)) 31 | 32 | os.makedirs(TEXTFOLDER, exist_ok=True) 33 | os.makedirs(IMAGEFOLDER, exist_ok=True) 34 | os.makedirs(SKIPFOLDER, exist_ok=True) 35 | 36 | imagefolders = os.listdir(Path(IMAGEFOLDER)) 37 | textfolders = os.listdir(Path(TEXTFOLDER)) 38 | skipfiles = os.listdir(Path(SKIPFOLDER)) 39 | 40 | ####### Progress calculation based on image files 41 | def return_total_downloaded_images(): 42 | images = [] 43 | for subimagefolder in os.listdir(Path(IMAGEFOLDER)): 44 | images += os.listdir(Path(IMAGEFOLDER + '/' + subimagefolder)) 45 | return len(images) 46 | 47 | # def return_total_downloaded_texts(): 48 | # texts = [] 49 | # for subtextfolder in os.listdir(Path(TEXTFOLDER)): 50 | # texts += os.listdir(Path(TEXTFOLDER + '/' + subtextfolder)) 51 | # return len(texts) 52 | 53 | imagefiles = [] 54 | textfiles = [] 55 | skipfilenumbers = [] 56 | 57 | ####### Extracting content of subfolders 58 | for subtextfolder in textfolders: 59 | textfiles += os.listdir(Path(TEXTFOLDER + '/' + subtextfolder)) 60 | 61 | for subimagefolder in imagefolders: 62 | imagefiles += os.listdir(Path(IMAGEFOLDER + '/' + subimagefolder)) 63 | 64 | ######## Calculating downloaded files 65 | if len(imagefiles) > 0: 66 | imagefilenumbers = [int(x[1:-4]) for x in imagefiles] 67 | else: 68 | imagefilenumbers = [] 69 | 70 | if len(textfiles) > 0: 71 | textfilenumbers = [int(x[1:-4]) for x in textfiles] 72 | else: 73 | textfilenumbers = [] 74 | 75 | missing_images = [x for x in textfilenumbers if x not in imagefilenumbers] 76 | missing_texts = [x for x in imagefilenumbers if x not in textfilenumbers] 77 | 78 | print('Missing {:,} images and {:,} texts.'.format(len(missing_images), len(missing_texts))) 79 | 80 | ####### Note skip folder does not have subfolders, size does not matter 81 | ####### because it is not needed for training 82 | # if len(skipfiles) > 0: 83 | # skipfilenumbers = [int(x[1:]) for x in skipfiles] 84 | 85 | print('Already downloaded {:,} images and {:,} texts.'.format(len(imagefilenumbers), len(textfilenumbers))) 86 | 87 | def load_missing_captions(x, textfolderpath, imagefolderpath, i, itotal, totallength): 88 | id = "w" + "0"*(9-len(str(x.name))) + str(x.name) 89 | with open(Path(textfolderpath + '/' + id + '.txt'), 'w') as f: 90 | f.write(x.caption) 91 | 92 | def write_files(x, textfolderpath, imagefolderpath, i, itotal, totallength): 93 | id = "w" + "0"*(9-len(str(x.name))) 
+ str(x.name) 94 | try: 95 | foo = Image.open(requests.get(x.image_url, stream=True, timeout=4).raw) 96 | a = max(MAXWIDTH/foo.size[0], MAXHEIGHT/foo.size[1]) 97 | foo = foo.resize((int(foo.size[0] * a), int(foo.size[1] * a)), Image.ANTIALIAS) 98 | foo.save(Path(imagefolderpath + "/" + id + '.jpg'), optimize=True, quality=85) 99 | except Exception: 100 | # open(Path(SKIPFOLDER + '/' + id), 'a').close 101 | pass 102 | else: 103 | with open(Path(textfolderpath + '/' + id + '.txt'), 'w') as f: 104 | f.write(x.caption) 105 | 106 | 107 | print('Reading url-caption file...') 108 | df = pd.read_csv( 109 | Path(FILENAME), 110 | compression='gzip', 111 | header=0, 112 | sep='\t', 113 | quotechar='"', 114 | dtype={ 115 | 'language': str, 116 | 'page_url': str, 117 | 'image_url': str, 118 | 'page_title': str, 119 | 'section_title': str, 120 | 'hierarchical_section_title': str, 121 | 'caption_reference_description': str, 122 | 'caption_attribution_description': str, 123 | 'caption_alt_text_description': str, 124 | 'mime_type': str, 125 | 'original_height': int, 126 | 'original_width': int, 127 | 'is_main_image': bool, 128 | 'attribution_passes_lang_id': bool, 129 | 'page_changed_recently': str, 130 | 'context_page_description': str, 131 | 'context_section_description': str 132 | }, 133 | # sep='\t', 134 | error_bad_lines=False) 135 | 136 | if LANGUAGEFILTER: 137 | df = df[df['language'].isin(LANGUAGES)] 138 | 139 | print('Preprocessing captions...') 140 | df = df.replace(np.nan, '') 141 | df['caption'] = df['caption_reference_description'] + '\n' + df['caption_attribution_description'] + '\n' + df['caption_alt_text_description'] + '\n' + df['hierarchical_section_title'] 142 | df['caption'] = '\n' + df['caption_reference_description'] + '\n' + df['caption_attribution_description'] + '\n' + df['caption_alt_text_description'] + '\n' 143 | df['caption'] = df['caption'].str.replace(r'^\.+\s+', '', regex=True) 144 | df['caption'] = df['caption'].str.replace(r'\.\.+', '.', regex=True) 145 | df['caption'] = df['caption'].str.replace(r'\s\s+', ' ', regex=True) 146 | df['caption'] = df['caption'].str.replace(r'\s+\.+', '', regex=True) 147 | df['caption'] = df['caption'].str.replace(r'&', 'and', regex=True) 148 | df['caption'] = df['caption'].str.strip() 149 | df['index'] = df.index 150 | 151 | totallength = len(df) 152 | 153 | parts = totallength / CHUNKS 154 | 155 | splits = np.array_split(df, parts) 156 | 157 | # print(totallength) 158 | # print(len(splits)) 159 | itotal = len(splits) 160 | print('The whole dataframe was divided into {} part(s).'.format(itotal)) 161 | 162 | for i, splitdf in enumerate(splits): 163 | foldername = "w" + str(GROUP) + '_' + "0"*(3-len(str(i))) + str(i) 164 | textfolderpath = TEXTFOLDER + '/' + foldername 165 | imagefolderpath = IMAGEFOLDER + '/' + foldername 166 | os.makedirs(Path(textfolderpath), exist_ok=True) 167 | os.makedirs(Path(imagefolderpath), exist_ok=True) 168 | textfolderfiles = os.listdir(textfolderpath) 169 | folderfiles = len(os.listdir(imagefolderpath)) 170 | if len(missing_texts) > 0: 171 | print('Trying do generate missing text files...') 172 | missingdf = df[df.index.isin(missing_texts)] 173 | missingdf.parallel_apply(lambda x: load_missing_captions(x, textfolderpath, imagefolderpath, i, itotal, totallength), axis=1) 174 | textfiles = [] 175 | for subtextfolder in textfolders: 176 | textfiles += os.listdir(Path(TEXTFOLDER + '/' + subtextfolder)) 177 | textfilestotal = len(textfiles) 178 | print('Successfully generated {:,} missing text 
file(s).'.format(len(missingdf))) 179 | print('Downloading texts into {} and images into {}.'.format(textfolderpath, imagefolderpath)) 180 | filteredsplitdf = splitdf[~splitdf.index.isin(textfilenumbers)] 181 | # filteredsplitdf = filteredsplitdf[~filteredsplitdf.index.isin(skipfilenumbers)] 182 | dflength = len(filteredsplitdf) 183 | downloadedimages = return_total_downloaded_images() 184 | print('Total length: {:,}'.format(totallength)) 185 | print('Downloaded images: {:,}'.format(downloadedimages)) 186 | print('Remaining images: {:,}\n'.format(totallength - downloadedimages)) 187 | print(dflength) 188 | if dflength > 0: 189 | while return_total_downloaded_images() < totallength: 190 | print('##############') 191 | print('Remaining {:,} images to get downloaded.'.format(totallength - return_total_downloaded_images())) 192 | print('##############') 193 | filteredsplitdf.parallel_apply(lambda x: write_files(x, textfolderpath, imagefolderpath, i, itotal, totallength), axis=1) 194 | 195 | textfolders = os.listdir(Path(TEXTFOLDER)) 196 | textfiles = [] 197 | for subtextfolder in textfolders: 198 | textfiles += os.listdir(Path(TEXTFOLDER + '/' + subtextfolder)) 199 | textfilestotal = len(textfiles) 200 | 201 | imagefolders = os.listdir(Path(IMAGEFOLDER)) 202 | imagefiles = [] 203 | for subimagefolder in imagefolders: 204 | imagefiles += os.listdir(Path(IMAGEFOLDER + '/' + subimagefolder)) 205 | imagefilestotal = len(imagefiles) 206 | 207 | print('Finished downloading {:,} images and {:,} texts.'.format(imagefilestotal, textfilestotal)) 208 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'dalle-datasets', 5 | packages = find_packages(), 6 | include_package_data = True, 7 | version = '0.1.0', 8 | license='MIT', 9 | description = 'DALL-E - Datasets', 10 | author = 'Robert van Volt', 11 | author_email = 'robvanvolt@gmail.com', 12 | url = 'https://github.com/robvanvolt/dalle-datasets', 13 | keywords = [ 14 | 'artificial intelligence', 15 | 'big data', 16 | 'datasets', 17 | ], 18 | install_requires=[ 19 | 'pillow', 20 | 'regex', 21 | 'torch', 22 | 'torchvision', 23 | 'WebDataset', 24 | 'tfrecord', 25 | 'tensorflow', 26 | 'youtokentome', 27 | 'pandas', 28 | 'pandarallel', 29 | 'svglib' 30 | ], 31 | classifiers=[ 32 | 'Development Status :: 4 - Beta', 33 | 'Intended Audience :: Developers', 34 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 35 | 'License :: OSI Approved :: MIT License', 36 | 'Programming Language :: Python :: 3.8.8', 37 | ], 38 | ) 39 | -------------------------------------------------------------------------------- /utilities/clip_wit.py: -------------------------------------------------------------------------------- 1 | # pip install git+https://github.com/openai/CLIP.git 2 | import matplotlib.pyplot as plt 3 | import matplotlib.image as mpimg 4 | from pathlib import Path 5 | import clip 6 | import torch 7 | import torch.nn as nn 8 | from PIL import Image 9 | 10 | import os 11 | 12 | os.environ['KMP_DUPLICATE_LIB_OK']='True' 13 | 14 | 15 | ### TODO: clip image / text 16 | ### TODO: WDS support output 17 | ### TODO: Multiprocessing 18 | ### TODO: GPU support 19 | 20 | device = "cuda" if torch.cuda.is_available() else "cpu" 21 | model, preprocess = clip.load("ViT-B/32", device=device) 22 | 23 | DATAFOLDER = 'data' 24 | 25 | path = Path(DATAFOLDER) 26 | 27 | text_files = 
[*path.glob('**/*.txt')] 28 | text_files = {text_file.stem: text_file for text_file in text_files} # str(text_file.parents[0]) + 29 | text_total = len(text_files) 30 | 31 | image_files = [ 32 | *path.glob('**/*.png'), *path.glob('**/*.jpg'), 33 | *path.glob('**/*.jpeg'), *path.glob('**/*.bmp'), 34 | *path.glob('**/*.PNG'), *path.glob('**/*.JPG'), 35 | *path.glob('**/*.JPEG'), *path.glob('**/*.BMP') 36 | ] 37 | image_files = {image_file.stem: image_file for image_file in image_files} # str(image_file.parents[0]) + 38 | image_total = len(image_files) 39 | 40 | print('Found {:,} textfiles and {:,} images.'.format(text_total, image_total)) 41 | 42 | keys = (image_files.keys() & text_files.keys()) 43 | 44 | for key in keys: 45 | print(text_files[key]) 46 | cap = '' 47 | with open(text_files[key], 'r') as t: 48 | for line in t: 49 | cap += line + ' ' 50 | 51 | image = preprocess(Image.open(image_files[key])).unsqueeze(0).to(device) 52 | text = clip.tokenize(cap.split(' ')).to(device) 53 | 54 | with torch.no_grad(): 55 | image_features = model.encode_image(image) 56 | text_features = model.encode_text(text) # self.model.encode_text(text_tokens.to(device)).float() 57 | 58 | logits_per_image, logits_per_text = model(image, text) 59 | 60 | probs = logits_per_image.softmax(dim=-1).cpu().numpy() 61 | 62 | # image_features /= image_features.norm(dim=-1, keepdim=True) 63 | # text_features /= text_features.norm(dim=-1, keepdim=True) 64 | # similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1) 65 | # print(similarity) 66 | 67 | # cosine_similarity = nn.CosineSimilarity(dim=1, eps=1e-6) 68 | # image_features = model.encode_image(image.to(device)).float() 69 | # image_features /= image_features.norm(dim=-1, keepdim=True) 70 | # text_features = model.encode_text(text.to(device)).float() 71 | # similarity = cosine_similarity(image_features, text_features).tolist() 72 | 73 | # print(float(cosine_similarity(text_features , image_features))) 74 | # print(float(cosine_similarity(torch.reshape(text_features, (1, 512)) , image_features))) 75 | # print(max(similarity)) 76 | 77 | print("Label probs:", probs) 78 | print('Cossimilarity > 0.3?') 79 | print(max(probs[0])) 80 | 81 | input() 82 | img = mpimg.imread(image_files[key]) 83 | imgplot = plt.imshow(img) 84 | plt.title(cap) 85 | plt.show() 86 | 87 | -------------------------------------------------------------------------------- /utilities/dataset_sanitycheck.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from pathlib import Path 3 | import pandas as pd 4 | import argparse 5 | import os 6 | import collections 7 | 8 | parser = argparse.ArgumentParser(description='A script to sanitycheck your image-text pairs for DALLE-pytorch training.') 9 | parser.add_argument("-D", "--dataset_folder", help="Add the folder containing image-text pairs for DALLE-training.", required=True) 10 | parser.add_argument("-DEL", "--delete_incomplete_files", action='store_true', help="Decide if the incomplete/corrupt files shall be removed.", required=False) 11 | parser.add_argument("-O", "--output_file", help="Incomplete files get saved in a textfile.", default='incomplete_files.csv') 12 | parser.add_argument("-M", "--min_characters", help="Text files with less than the specified character length get deleted.", default=5) 13 | args = parser.parse_args() 14 | 15 | DATASETFOLDER = args.dataset_folder 16 | OUTPUTFILE = args.output_file 17 | MINCHARACTERS = args.min_characters 18 | FILEFORMATS = ['jpg', 
'jpeg', 'png', 'bmp'] 19 | DELETE = False 20 | if args.delete_incomplete_files:  # store_true flags default to False, not None 21 | DELETE = True 22 | 23 | filenames_and_folders = os.listdir(DATASETFOLDER) 24 | folders = [x for x in filenames_and_folders if os.path.isdir(Path(DATASETFOLDER + '/' + x)) == True] 25 | files = [x for x in filenames_and_folders if x not in folders] 26 | folders = [''] + folders 27 | faulty_data = {} 28 | 29 | def return_incomplete_and_paired_ids(files): 30 | ids = [x[:-4] for x in files] 31 | d = collections.defaultdict(int) 32 | for x in ids: d[x] += 1 33 | 34 | incomplete_ids = [x for x in ids if d[x] == 1] 35 | paired_ids = list(set(ids) - set(incomplete_ids)) 36 | incomplete_files = [x for x in files if x[:-4] in incomplete_ids] 37 | 38 | return {'paired_ids': paired_ids, 'incomplete_files': incomplete_files} 39 | 40 | def true_if_image_corrupt_and_fileformat(parent, id): 41 | for fileformat in FILEFORMATS: 42 | filepath = Path(parent + '/' + id + '.' + fileformat) 43 | if os.path.isfile(filepath): 44 | try: 45 | img = Image.open(Path(filepath)) 46 | img.verify() 47 | img.close() 48 | except (IOError, SyntaxError) as _: 49 | return True, fileformat 50 | else: 51 | return False, fileformat 52 | 53 | def return_empty_text_and_corrupt_images(parent, paired_ids): 54 | empty_texts = [] 55 | corrupt_images = [] 56 | delete = [] 57 | 58 | short_text = False 59 | 60 | for id in paired_ids: 61 | with open(Path(parent + '/' + id + '.txt'), 'r') as f: 62 | # reset per pair; cast because argparse passes --min_characters as a string 63 | short_text = len(f.read()) < int(MINCHARACTERS) 64 | 65 | corrupt, fileformat = true_if_image_corrupt_and_fileformat(parent, id) 66 | 67 | if corrupt: 68 | corrupt_images.append(id + '.' + fileformat) 69 | 70 | if short_text: 71 | empty_texts.append(id + '.txt') 72 | 73 | if corrupt or short_text: 74 | delete.append(parent + '/' + id + '.txt') 75 | delete.append(parent + '/' + id + '.' + fileformat) 76 | 77 | return {'empty': empty_texts, 'corrupt': corrupt_images, 'delete': delete} 78 | 79 | for folder in folders: 80 | sep = '/' if folder != '' else '' 81 | parent = DATASETFOLDER + sep + folder 82 | files = os.listdir(parent) 83 | 84 | if folder == '': 85 | files = [x for x in files if '.' 
in x] 86 | 87 | incomplete_and_paired_files_in_parent_folder = return_incomplete_and_paired_ids(files) 88 | empty_and_corrupt_files_in_parent_folder = \ 89 | return_empty_text_and_corrupt_images( 90 | parent, 91 | incomplete_and_paired_files_in_parent_folder['paired_ids']) 92 | 93 | if len(files) > 0: 94 | faulty_data[parent] = { 95 | 'incomplete': incomplete_and_paired_files_in_parent_folder['incomplete_files'], 96 | 'corrupt': empty_and_corrupt_files_in_parent_folder['corrupt'], 97 | 'empty': empty_and_corrupt_files_in_parent_folder['empty'], 98 | 'delete': empty_and_corrupt_files_in_parent_folder['delete'] \ 99 | + [parent + '/' + x for x in incomplete_and_paired_files_in_parent_folder['incomplete_files']] 100 | } 101 | 102 | df = pd.DataFrame.from_dict(faulty_data).T 103 | df.index.name = 'folderpath' 104 | df.to_csv(OUTPUTFILE, sep='|') 105 | print('Table of incomplete files was saved to {}'.format(OUTPUTFILE)) 106 | 107 | if DELETE: 108 | print('Deleting incomplete and corrupt files...') 109 | delete_lists = list(df['delete']) 110 | count_deleted_files = 0 111 | for delete_list in delete_lists: 112 | for delete_file in delete_list: 113 | if os.path.isfile(Path(delete_file)): 114 | os.remove(Path(delete_file)) 115 | count_deleted_files += 1  # count only files that were actually removed 116 | 117 | print('Finished deleting {:,} files.'.format(count_deleted_files)) -------------------------------------------------------------------------------- /utilities/tokenizer_from_wds_or_text.py: -------------------------------------------------------------------------------- 1 | import youtokentome as yttm 2 | import webdataset as wds 3 | from pathlib import Path 4 | import argparse 5 | import shutil 6 | import html 7 | import os 8 | 9 | parser = argparse.ArgumentParser("""Generate a custom tokenizer for your WebDataset files.""") 10 | 11 | parser.add_argument( 12 | "--source", 13 | type=str, 14 | default="./shards", 15 | help="Specify the source: a plain text file, or a directory containing WebDataset (.tar/.tar.gz) or text files." 16 | ) 17 | parser.add_argument( 18 | "--text_key", 19 | type=str, 20 | default="label", 21 | help="Specify the text column in your WebDataset file(s)." 22 | ) 23 | parser.add_argument( 24 | "--output_model_name", 25 | type=str, 26 | default="custom_tokenizer.bpe", 27 | help="Specify the output file of your tokenizer." 28 | ) 29 | parser.add_argument( 30 | "--coverage", 31 | type=float, 32 | default=0.9999, 33 | help="Specify the coverage for your custom tokenizer." 34 | ) 35 | parser.add_argument( 36 | "--output_folder", 37 | type=str, 38 | default='./output', 39 | help="Specify the output folder for the generated files." 40 | ) 41 | parser.add_argument( 42 | "--vocab_size", 43 | type=int, 44 | default=4096, 45 | help="Specify the vocab size for your custom tokenizer." 46 | ) 47 | args = parser.parse_args() 48 | 49 | os.makedirs(args.output_folder, exist_ok=True) 50 | 51 | path_to_textfile = args.output_folder + "/text_for_tokenizer.txt" 52 | 53 | if args.source[-4:].lower() == '.txt': 54 | print('--------------------------------------------------------') 55 | print('----> Creating custom tokenizer from provided text file.') 56 | print('--------------------------------------------------------') 57 | path_to_textfile = args.source 58 | else: 59 | assert os.path.isdir(args.source), 'The source path has to be a directory containing text files.' 
60 | print('------------------------------------------------------') 61 | print('----> Generating a single text file from dataset first.') 62 | print('------------------------------------------------------') 63 | path = Path(args.source) 64 | 65 | wds_files = [ 66 | *path.glob('**/*.tar'), *path.glob('**/*.tar.gz') 67 | ] 68 | 69 | if len(wds_files) > 0: 70 | print('Found {:,} WebDataset files (.tar/.tar.gz) in {}'.format(len(wds_files), args.source)) 71 | wds_files = [str(x) for x in wds_files] 72 | dataset = wds.WebDataset(wds_files) 73 | c = 0 74 | with open(path_to_textfile, "w") as f: 75 | for item in dataset: 76 | f.write(html.unescape(item[args.text_key].decode('utf-8')) + '\n')  # newline keeps captions separated 77 | if c % 10000 == 0: 78 | print(' {:,}'.format(c), end='\r') 79 | c += 1 80 | else: 81 | print('No WebDataset files (.tar/.tar.gz) found in {}'.format(args.source)) 82 | print('Trying to find text files next (classic format with image-text-pairs).') 83 | txt_files = [*path.glob('**/*.txt')] 84 | assert len(txt_files) > 0, 'No txt files found in source directory {}'.format(args.source) 85 | with open(path_to_textfile, 'wb') as wfd: 86 | for f in txt_files: 87 | with open(f,'rb') as fd: 88 | shutil.copyfileobj(fd, wfd) 89 | 90 | yttm.BPE.train( 91 | data=path_to_textfile, 92 | model=args.output_folder + '/' + args.output_model_name, 93 | vocab_size=args.vocab_size, 94 | coverage=args.coverage, # or 1.0 95 | n_threads=-1, 96 | ) -------------------------------------------------------------------------------- /utilities/wds_create_legacy.py: -------------------------------------------------------------------------------- 1 | import webdataset as wds 2 | import os 3 | from pathlib import Path 4 | from collections import Counter 5 | from PIL import Image 6 | 7 | DATASETPATH = 'data' 8 | OUTPUTFILENAME = 'dataset' 9 | 10 | alldirs = os.walk(Path(DATASETPATH)) 11 | all_basepaths = [] 12 | 13 | ############################################################################ 14 | ########### Legacy wds_creator - better use wds_create_shards.py ########### 15 | ############################################################################ 16 | 17 | ### (1) Find image-text pairs in all (sub)folders of the basepath 18 | 19 | for dir in alldirs: 20 | fns = dir[2] 21 | if next((True for x in fns if '.txt' in x), False): 22 | basenames = [x.split('.')[0] for x in fns] 23 | basepaths = [dir[0] + '/' + k for k, v in dict(Counter(basenames)).items() if v == 2] 24 | all_basepaths.extend(basepaths) 25 | 26 | all_basepaths = sorted(all_basepaths, key=lambda x: x.split('/')[-1]) 27 | curated_basepaths = [] 28 | 29 | 30 | ### (2) Verify images exist and can be opened 31 | 32 | for basepath in all_basepaths: 33 | try: 34 | img = Image.open(Path(basepath + '.jpg'))  # opening can fail too, so it belongs inside the try block 35 | img.verify() 36 | except Exception: 37 | print('Invalid image on path {}'.format(basepath)) 38 | else: 39 | curated_basepaths.append(basepath) 40 | 41 | 42 | ### (3) Create compressed Webdataset tar file 43 | 44 | sink = wds.TarWriter(OUTPUTFILENAME + '.tar.gz', encoder=False) 45 | 46 | 47 | 48 | for basepath in curated_basepaths: 49 | with open(Path(basepath + '.jpg'), "rb") as imgstream: 50 | image = imgstream.read() 51 | with open(Path(basepath + '.txt'), "rb") as txtstream: 52 | text = txtstream.read() 53 | sample = { 54 | "__key__": basepath.split('/')[-1], 55 | "img": image, 56 | "cap": text 57 | } 58 | sink.write(sample) 59 | sink.close() -------------------------------------------------------------------------------- /utilities/wds_create_shards.py: 
-------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import os.path 4 | import random 5 | import argparse 6 | import json 7 | 8 | from pathlib import Path 9 | from PIL import Image 10 | 11 | import webdataset as wds 12 | 13 | parser = argparse.ArgumentParser("""Generate sharded dataset from image-text-datasets.""") 14 | parser.add_argument("--maxsize", type=float, default=1e9) 15 | parser.add_argument("--maxcount", type=float, default=100000) 16 | parser.add_argument( 17 | "--compression", 18 | dest="compression", 19 | action="store_true", 20 | help="Creates compressed .tar.gz files instead of uncompressed .tar files." 21 | ) 22 | parser.add_argument( 23 | "--json", 24 | dest="json", 25 | action="store_true", 26 | help="Reads json files and adds them to the .tar files." 27 | ) 28 | parser.add_argument( 29 | "--image_text_keys", 30 | type=str, 31 | default="img,cap", 32 | help="Comma separated WebDataset dictionary keys for images (first argument) and texts (second argument). \ 33 | The exact argument has to be provided to train_dalle.py, e.g. python train_dalle.py --wds img,cp --image_text_folder ../shards" 34 | ) 35 | parser.add_argument( 36 | "--shards", 37 | default="./shards", 38 | help="directory where shards are written" 39 | ) 40 | parser.add_argument( 41 | "--shard_prefix", 42 | default="ds_", 43 | help="prefix of shards' filenames created in the shards-folder" 44 | ) 45 | parser.add_argument( 46 | "--data", 47 | default="./data", 48 | help="directory path containing data suitable for DALLE-pytorch training", 49 | ) 50 | args = parser.parse_args() 51 | 52 | assert len(args.image_text_keys.split(',')) == 2, 'Too many arguments provided' 53 | assert args.maxsize > 10000000 54 | assert args.maxcount < 1000000 55 | 56 | image_key, caption_key = tuple(args.image_text_keys.split(',')) 57 | 58 | if not os.path.isdir(os.path.join(args.data)): 59 | print(f"{args.data}: should be directory containing image-text pairs", file=sys.stderr) 60 | print(f"or subfolders containing image-text-pairs", file=sys.stderr) 61 | sys.exit(1) 62 | 63 | os.makedirs(Path(args.shards), exist_ok=True) 64 | 65 | def readfile(fname): 66 | "Read a binary file from disk." 
67 | with open(fname, "rb") as stream: 68 | return stream.read() 69 | 70 | path = Path(args.data) 71 | text_files = [*path.glob('**/*.txt')] 72 | text_files = {text_file.stem: text_file for text_file in text_files} # str(text_file.parents[0]) + 73 | text_total = len(text_files) 74 | 75 | if args.json: 76 | json_files = [*path.glob('**/*.json')] 77 | json_files = {json_file.stem: json_file for json_file in json_files} 78 | json_dicts = {} 79 | # json_files_old = json_files.copy() 80 | 81 | for key in json_files: 82 | try: 83 | with open(json_files[key], "r") as f: 84 | json_dicts[key] = json.dumps(json.load(f)) 85 | except: 86 | pass 87 | # del json_files["key"] 88 | print("Found {} corrupt json file(s).".format(len(json_files.keys()) - len(json_dicts.keys()))) 89 | json_keys = json_files.keys() 90 | 91 | image_files = [ 92 | *path.glob('**/*.png'), *path.glob('**/*.jpg'), 93 | *path.glob('**/*.jpeg'), *path.glob('**/*.bmp') 94 | ] 95 | image_files = {image_file.stem: image_file for image_file in image_files} # str(image_file.parents[0]) + 96 | image_total = len(image_files) 97 | 98 | print('Found {:,} textfiles and {:,} images.'.format(text_total, image_total)) 99 | 100 | keys = (image_files.keys() & text_files.keys()) 101 | 102 | text_files = {k: v for k, v in text_files.items() if k in keys} 103 | image_files = {k: v for k, v in image_files.items() if k in keys} 104 | 105 | for key in image_files: 106 | img = Image.open(image_files[key]) 107 | try: 108 | img.verify() 109 | except Exception: 110 | print('Invalid image on path {}'.format(key)) 111 | keys.remove(key) 112 | 113 | print("Remaining keys after image sanity check: {:,}".format(len(keys))) 114 | 115 | total_pairs = len(keys) 116 | keys = list(keys) 117 | 118 | indexes = list(range(total_pairs)) 119 | random.shuffle(indexes) 120 | 121 | # This is the output pattern under which we write shards. 122 | pattern = os.path.join(args.shards, args.shard_prefix + f"%06d.tar" + (".gz" if args.compression else '')) 123 | 124 | with wds.ShardWriter(pattern, maxsize=int(args.maxsize), maxcount=int(args.maxcount)) as sink: 125 | for i in indexes: 126 | with open(image_files[keys[i]], "rb") as imgstream: 127 | image = imgstream.read() 128 | with open(text_files[keys[i]], "rb") as txtstream: 129 | text = txtstream.read() 130 | 131 | ds_key = "%09d" % i 132 | 133 | sample = { 134 | "__key__": ds_key, 135 | image_key: image, 136 | caption_key: text 137 | } 138 | if args.json and keys[i] in json_keys: 139 | sample["json"] = json_dicts[keys[i]] 140 | sink.write(sample) 141 | -------------------------------------------------------------------------------- /utilities/wds_from_tfrecords.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import webdataset as wds 3 | from pathlib import Path 4 | import argparse 5 | import os 6 | import timeit 7 | import hashlib 8 | from io import BytesIO 9 | from PIL import Image 10 | 11 | parser = argparse.ArgumentParser("""Generate sharded dataset from tfrecord-files.""") 12 | parser.add_argument("--maxsize", type=float, default=1e9) 13 | parser.add_argument("--maxcount", type=float, default=100000) 14 | parser.add_argument( 15 | "--compression", 16 | dest="compression", 17 | action="store_true", 18 | help="Creates compressed .tar.gz files instead of uncompressed .tar files." 
19 | ) 20 | parser.add_argument( 21 | "--use_encoder", 22 | dest="use_encoder", 23 | action="store_true", 24 | help="Uses encoder on unknown filetimes (the suffix in the keep_keys argument)." 25 | ) 26 | parser.add_argument( 27 | "--keep_keys", 28 | type=str, 29 | default="image.pyd,label.cls", 30 | help="Only keep the columns from the comma separated keys from that argument. The dot separated suffix is the filetype." 31 | ) 32 | parser.add_argument( 33 | "--remove_duplicates", 34 | dest="remove_duplicates", 35 | default="image", 36 | help="Remove duplicates from given column name. (e.g. --remove_duplicates image)" 37 | ) 38 | parser.add_argument( 39 | "--min_max_size", 40 | dest="min_max_size", 41 | default="192,320", 42 | help="Discards smaller and resizes larger images. (e.g. --min_max_size 256,320)" 43 | ) 44 | parser.add_argument( 45 | "--report_every", 46 | type=int, 47 | default="1000", 48 | help="Report every n iterations." 49 | ) 50 | parser.add_argument( 51 | "--shards", 52 | default="./shards", 53 | help="directory where shards are written" 54 | ) 55 | parser.add_argument( 56 | "--shard_prefix", 57 | default="wds_", 58 | help="prefix of shards' filenames created in the shards-folder" 59 | ) 60 | parser.add_argument( 61 | "--data", 62 | default="./tfr", 63 | help="directory path containing tfrecord files", 64 | ) 65 | args = parser.parse_args() 66 | 67 | KEEP_KEYS = [] 68 | if args.keep_keys != '': 69 | KEEP_KEYS = {x.split('.')[0]: x.split('.')[1] for x in args.keep_keys.split(',')} 70 | 71 | SIZE = {} 72 | if args.min_max_size != '': 73 | SIZE = { 74 | 'min': int(args.min_max_size.split(',')[0]), 75 | 'max': int(args.min_max_size.split(',')[1]) 76 | } 77 | 78 | assert args.maxsize > 10000000 79 | assert args.maxcount < 1000000 80 | assert os.path.isdir(os.path.join(args.data)), '{} does not exist.'.format(args.data) 81 | 82 | os.makedirs(Path(args.shards), exist_ok=True) 83 | 84 | tfrecord_files = [args.data + '/' + x for x in os.listdir(args.data) if x.split('.')[-1] == 'tfrecord'] 85 | total_files = len(tfrecord_files) 86 | 87 | ###### Example of a feature description to a tfrecord dataset 88 | FEATURE_DESCRIPTION = { 89 | ###### Please provide your tfrecord feature description 90 | } 91 | FEATURE_DESCRIPTION = { 92 | 'sampleID': tf.io.FixedLenFeature([], tf.string), 93 | 'image': tf.io.FixedLenFeature([], tf.string), 94 | 'format': tf.io.FixedLenFeature([], tf.string), 95 | 'label': tf.io.FixedLenFeature([], tf.string), 96 | 'height': tf.io.FixedLenFeature([], tf.int64), 97 | 'width': tf.io.FixedLenFeature([], tf.int64), 98 | } 99 | 100 | assert len(FEATURE_DESCRIPTION) > 0, 'Please provide the feature description to your tfrecord dataset.' 
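# Hedged helper, an illustration only (not used below): if you are unsure which keys and
# dtypes belong in FEATURE_DESCRIPTION, you can inspect one raw record first, the same way
# wds_from_tfrecords_alternative.py does it. Relies on the tensorflow import at the top.
def print_tfrecord_feature_kinds(tfrecord_path):
    """Print every feature key of the first record together with its kind (e.g. bytes_list, int64_list)."""
    raw_dataset = tf.data.TFRecordDataset(tfrecord_path)
    for raw_record in raw_dataset.take(1):
        example = tf.train.Example()
        example.ParseFromString(raw_record.numpy())
        for key, value in example.features.feature.items():
            print(key, value.WhichOneof('kind'))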
101 | 102 | def wrapper(gen): 103 | while True: 104 | try: 105 | yield next(gen) 106 | except StopIteration: 107 | break 108 | except Exception as e: 109 | print(e) 110 | 111 | def _parse_example(example_proto): 112 | example = tf.io.parse_single_example(example_proto, FEATURE_DESCRIPTION) 113 | return example 114 | 115 | pattern = os.path.join(args.shards, args.shard_prefix + f"%06d.tar" + (".gz" if args.compression else '')) 116 | count = 0 117 | 118 | # Arguments for removing duplicates 119 | duplicate_count = 0 120 | duplicate_md5 = set() 121 | skip_duplicate = False 122 | 123 | # Arguments for resizing / discarding images 124 | discard_count = 0 125 | resize_count = 0 126 | skip_sizemismatch_or_corrupt = False 127 | 128 | start = timeit.default_timer() 129 | with wds.ShardWriter(pattern, maxsize=int(args.maxsize), maxcount=int(args.maxcount), encoder=args.use_encoder) as sink: 130 | for tfrecord_file in tfrecord_files: 131 | raw_dataset = tf.data.TFRecordDataset(tfrecord_file) 132 | dataset = raw_dataset.map(_parse_example) 133 | for item in wrapper(dataset.as_numpy_iterator()): 134 | ds_key = "%09d" % count 135 | sample = { 136 | "__key__": ds_key, 137 | } 138 | if args.remove_duplicates != '': 139 | valuehash = hashlib.md5(item[args.remove_duplicates]).hexdigest() 140 | if valuehash in duplicate_md5: 141 | duplicate_count += 1 142 | skip_duplicate = True 143 | else: 144 | duplicate_md5.add(valuehash) 145 | 146 | if skip_duplicate == False: 147 | 148 | ### Resize, discard or keep block 149 | if args.min_max_size != '': 150 | if item['width'] < SIZE['min'] and item['height'] < SIZE['min']: 151 | discard_count += 1 152 | skip_sizemismatch_or_corrupt = True 153 | elif item['width'] > SIZE['max'] or item['height'] > SIZE['max']: 154 | # Try opening and resizing image 155 | try: 156 | foo = Image.open(BytesIO(item['image'])) 157 | if foo.mode != 'RGB': 158 | foo = foo.convert('RGB') 159 | a = max(SIZE['max']/foo.size[0], SIZE['max']/foo.size[1]) 160 | foo = foo.resize((int(foo.size[0] * a), int(foo.size[1] * a)), Image.ANTIALIAS) 161 | # Image to bytes 162 | img_byte_arr = BytesIO() 163 | foo.save(img_byte_arr, format='jpeg', optimize=True, quality=85) 164 | item['image'] = img_byte_arr.getvalue() 165 | except Exception as e: 166 | print(e) 167 | discard_count += 1 168 | skip_sizemismatch_or_corrupt = True 169 | else: 170 | resize_count += 1 171 | if skip_sizemismatch_or_corrupt == False: 172 | #### Writing row to WebDataset file 173 | for key in KEEP_KEYS: 174 | sample[key + '.' 
+ KEEP_KEYS[key] if args.use_encoder else key] = item[key] 175 | sink.write(sample) 176 | #### End writing row to WebDataset file 177 | else: 178 | skip_sizemismatch_or_corrupt = False 179 | 180 | else: 181 | skip_duplicate = False 182 | 183 | if count % args.report_every == 0: 184 | print(' {:.2f}'.format(count), end='\r') 185 | count += 1 186 | 187 | stop = timeit.default_timer() 188 | 189 | print('#################################################################################') 190 | print('# Finished processing {:,} samples from tfrecord files.'.format(count)) 191 | print('# Process took {:.2f} seconds to finish.'.format(stop - start)) 192 | if (args.remove_duplicates != ''): 193 | print('# Skipped {:,} duplicates from a total of {:,} items.'.format(duplicate_count, count)) 194 | if (args.min_max_size != ''): 195 | print('# Discarded {:,} and resized {:,} images from remaining {:,} non-duplicates.'.format(discard_count, resize_count, count - duplicate_count)) 196 | print('# {:,} images remain in the Dataset.'.format(count - (duplicate_count + discard_count))) 197 | print('# The WebDataset files can be found in {}.'.format(args.shards)) 198 | print('#################################################################################') -------------------------------------------------------------------------------- /utilities/wds_from_tfrecords_alternative.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tfrecord.torch.dataset import MultiTFRecordDataset 3 | from tfrecord.tools.tfrecord2idx import create_index 4 | import tensorflow as tf 5 | import webdataset as wds 6 | from pathlib import Path 7 | import argparse 8 | import timeit 9 | import os 10 | 11 | parser = argparse.ArgumentParser("""Generate sharded dataset from tfrecord-files.""") 12 | parser.add_argument("--maxsize", type=float, default=1e9) 13 | parser.add_argument("--maxcount", type=float, default=100000) 14 | parser.add_argument( 15 | "--compression", 16 | dest="compression", 17 | action="store_true", 18 | help="Creates compressed .tar.gz files instead of uncompressed .tar files." 19 | ) 20 | parser.add_argument( 21 | "--keep_keys", 22 | type=str, 23 | default="", 24 | help="Only keep the columns from the comma separated keys from that argument." 25 | ) 26 | parser.add_argument( 27 | "--report_every", 28 | type=int, 29 | default="1000", 30 | help="Report every n iterations." 
31 | ) 32 | parser.add_argument( 33 | "--shards", 34 | default="./shards", 35 | help="directory where shards are written" 36 | ) 37 | parser.add_argument( 38 | "--shard_prefix", 39 | default="ds_", 40 | help="prefix of shards' filenames created in the shards-folder" 41 | ) 42 | parser.add_argument( 43 | "--data", 44 | default="./tfr", 45 | help="directory path containing tfrecord files", 46 | ) 47 | args = parser.parse_args() 48 | 49 | KEEP_KEYS = [] 50 | if args.keep_keys != '': 51 | KEEP_KEYS = args.keep_keys.split(',') 52 | 53 | assert args.maxsize > 10000000 54 | assert args.maxcount < 1000000 55 | assert os.path.isdir(os.path.join(args.data)), '{} does not exist.'.format(args.data) 56 | 57 | os.makedirs(Path(args.shards), exist_ok=True) 58 | 59 | index_path = args.data 60 | tfrecord_pattern = args.data + '/{}.tfrecord' 61 | index_pattern = index_path + '/{}.index' 62 | 63 | os.makedirs(index_path, exist_ok=True) 64 | 65 | tfrecord_files = [x[:-9] for x in os.listdir(args.data) if x.split('.')[-1] == 'tfrecord'] 66 | total_files = len(tfrecord_files) 67 | splits = {k: 1/total_files for k in tfrecord_files} 68 | 69 | tfrecord_index_files = [x[:-6] for x in os.listdir(index_path) if x.split('.')[-1] == 'index'] 70 | total_index_files = len(tfrecord_index_files) 71 | 72 | TFR_MATCH_INDEX = True if len([x for x in tfrecord_files if x not in tfrecord_index_files]) == 0 else False 73 | 74 | if not TFR_MATCH_INDEX: 75 | print('Index files must be provided when using multiple workers, otherwise the loader may return duplicate records.') 76 | print('Generating index files in {}...'.format(index_path)) 77 | for tfrecord_file in tfrecord_files: 78 | create_index(args.data + '/' + tfrecord_file + '.tfrecord', index_path + '/' + tfrecord_file + '.index') 79 | print('Finished generating index files!') 80 | else: 81 | print('Found matching number of index and tfrecord files.') 82 | 83 | 84 | raw_dataset = tf.data.TFRecordDataset(args.data + '/' + [x for x in os.listdir(args.data) if x.split('.')[-1] == 'tfrecord'][0]) 85 | keys = {} 86 | for raw_record in raw_dataset.take(1): 87 | example = tf.train.Example() 88 | example.ParseFromString(raw_record.numpy()) 89 | for key, value in example.features.feature.items(): 90 | keys[key] = True if value.WhichOneof('kind') == 'bytes_list' else False 91 | 92 | if len(KEEP_KEYS) > 0: 93 | keys = {k: v for k, v in keys.items() if k in KEEP_KEYS} 94 | assert len(keys.items()) > 0, 'No keys left to convert to WebDataset.' 95 | 96 | 97 | def _parse_example(example_proto): 98 | """Return the example_proto as a tuple of the image and its label.""" 99 | return {key: example_proto[key].tobytes() for key in keys} 100 | # return {key: example_proto[key].tobytes() if keys[key] else example_proto[key] for key in keys} 101 | 102 | def _collate_fn(batch): 103 | return batch[0] 104 | 105 | dataset = MultiTFRecordDataset(tfrecord_pattern, index_pattern, splits, transform=_parse_example, infinite=False) 106 | loader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=_collate_fn, drop_last=False) 107 | 108 | # This is the output pattern under which we write shards. 
109 | pattern = os.path.join(args.shards, args.shard_prefix + f"%06d.tar" + (".gz" if args.compression else '')) 110 | count = 0 111 | 112 | start = timeit.default_timer() 113 | with wds.ShardWriter(pattern, maxsize=int(args.maxsize), maxcount=int(args.maxcount)) as sink: 114 | for i, item in enumerate(iter(loader)): 115 | count = i 116 | ds_key = "%09d" % i 117 | sample = { 118 | "__key__": ds_key, 119 | } 120 | for key in keys: 121 | sample[key] = item[key] 122 | sink.write(sample) 123 | if count % args.report_every == 0: 124 | print(' {:,}'.format(count), end='\r') 125 | stop = timeit.default_timer() 126 | 127 | print('###################################################################') 128 | print('Finished converting {:,} samples from tfrecord files to webdataset.'.format(count)) 129 | print('Process took {:.2f} seconds to finish.'.format(stop - start)) 130 | print('###################################################################') -------------------------------------------------------------------------------- /utilities/wds_pytorchread.py: -------------------------------------------------------------------------------- 1 | import webdataset as wds 2 | import torch 3 | 4 | # dataset = wds.WebDataset('dataset.tar.gz').shuffle(8).decode().to_tuple("cap", "img") 5 | dataset = wds.WebDataset('shards/ds_000000.tar').shuffle(8).decode() 6 | 7 | # for d in dataset: 8 | # input(d) 9 | 10 | # for val in dataset: 11 | # for item in val: 12 | # print(item) 13 | # input() 14 | 15 | # dataloader = torch.utils.data.DataLoader(dataset) 16 | 17 | # for val in dataloader: 18 | # input(val) 19 | 20 | image_files = {d['__key__']: d['img'] for d in dataset} 21 | text_files = {d['__key__']: d['cap'] for d in dataset} 22 | 23 | keys = list(image_files.keys() & text_files.keys()) 24 | 25 | print(keys) 26 | 27 | for key in keys: 28 | input(text_files[key]) 29 | 30 | # for inp in dataloader: 31 | # input(inp) 32 | # for key in inp: 33 | # print(inp[key]) 34 | # input() -------------------------------------------------------------------------------- /utilities/wds_read.py: -------------------------------------------------------------------------------- 1 | import webdataset as wds 2 | from PIL import Image 3 | import io 4 | import matplotlib.pyplot as plt 5 | 6 | # dataset = wds.WebDataset("dataset.tar.gz") 7 | dataset = wds.WebDataset("./shards/ds_000000.tar") 8 | 9 | for i, d in enumerate(dataset): 10 | # print(i) 11 | # print(d) 12 | # input() 13 | 14 | print(d.keys()) 15 | 16 | # print(d['__key__']) 17 | # print(d['cap']) 18 | # if 'json' in d.keys(): 19 | # print(d['json']) 20 | 21 | # print(d['label'].decode('utf-8')) 22 | # image = Image.open(io.BytesIO(d['image'])) 23 | # plt.show() 24 | # plt.imshow(image) 25 | # plt.show() 26 | input() --------------------------------------------------------------------------------