├── .gitignore ├── COCO ├── README.md └── download.sh ├── CVPR ├── README.md └── download.sh ├── ConceptualCaptions ├── README.md └── download.sh ├── Flickr30k └── download.sh ├── Food-101 ├── README.md └── download.sh ├── OpenImages ├── README.md ├── download.sh ├── ids_to_download.txt.gz └── prepare.py ├── README.md ├── STL-10 ├── README.md └── download.sh ├── SVHN ├── README.md └── download.sh ├── VisualGenome ├── README.md ├── download.sh └── get_captions.py ├── WIT ├── README.md ├── download.sh ├── download_all.sh └── download_wit_500k.sh ├── dalle_pytorch_datasets.ipynb ├── download-template.sh └── generate_captions.py /.gitignore: -------------------------------------------------------------------------------- 1 | output/ 2 | -------------------------------------------------------------------------------- /COCO/README.md: -------------------------------------------------------------------------------- 1 | # COCO 2017 (~200,000 images) 2 | Training set downscaled to 256px. Captions are included in the folder. 3 | -------------------------------------------------------------------------------- /COCO/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Enable strict mode. 4 | set -euo pipefail 5 | IFS=$'\n\t' 6 | 7 | wget https://www.dropbox.com/s/dtjjz9cpenmpowr/train2017.zip; 8 | unzip -q train2017.zip; 9 | -------------------------------------------------------------------------------- /CVPR/README.md: -------------------------------------------------------------------------------- 1 | # CVPR Indoor Scene Recognition 2 | 3 | 4 | ## Details 5 | - 2.59 GiB 6 | - 15,620 images 7 | - Captions need to be generated. 8 | 9 | -------------------------------------------------------------------------------- /CVPR/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Enable strict mode. 4 | set -euo pipefail 5 | IFS=$'\n\t' 6 | 7 | aria2c https://academictorrents.com/download/59aa0ad684e5d849f68bad9a6d43a9000a927164.torrent; # indoor CVPR 8 | 9 | 10 | mkdir output; 11 | 12 | 13 | # extract indoorCVPR 14 | tar -xf indoorCVPR_09.tar --directory=output; 15 | rm indoorCVPR_09.tar; 16 | mv output/Images/ indoorCVPR/; # rename to the correct dataset folder name 17 | -------------------------------------------------------------------------------- /ConceptualCaptions/README.md: -------------------------------------------------------------------------------- 1 | # Conceptual Captions (2,902,649 images) 2 | 3 | The images are JPEG-compressed at quality 75 and then gzip-compressed at level 9, bringing the storage size down from 400 GiB to around 50 GiB. 4 | 5 | I made the mistake of downscaling some of these more than necessary because I didn't have the right ImageMagick parameters set. Images were downsized so that the larger dimension is 320px, which means images with a non-square aspect ratio dip below 256px on the shorter side (e.g. 320x240). Sorry about that. 6 | 7 | Captions: not preprocessed yet. 8 | -------------------------------------------------------------------------------- /ConceptualCaptions/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Enable strict mode.
4 | set -euo pipefail 5 | IFS=$'\n\t' 6 | 7 | wget https://www.dropbox.com/s/yxibearxqgkh58k/conceptual_captions_train_256.zip; 8 | unzip -q conceptual_captions_train_256.zip; 9 | -------------------------------------------------------------------------------- /Flickr30k/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Enable strict mode. 4 | set -euo pipefail 5 | IFS=$'\n\t' 6 | -------------------------------------------------------------------------------- /Food-101/README.md: -------------------------------------------------------------------------------- 1 | # Food-101 2 | 3 | 4 | -------------------------------------------------------------------------------- /Food-101/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Enable strict mode. 4 | set -euo pipefail 5 | IFS=$'\n\t' 6 | 7 | if [[ ! -v OUTPUT_PATH ]]; then 8 | echo "Using default output path." 9 | OUTPUT_PATH=../output 10 | else 11 | echo "Using provided OUTPUT_PATH: $OUTPUT_PATH" 12 | fi 13 | 14 | 15 | 16 | aria2c https://academictorrents.com/download/470791483f8441764d3b01dbc4d22b3aa58ef46f.torrent; # food-101 17 | 18 | tar -xf food-101.tgz --directory=$OUTPUT_PATH; 19 | rm food-101.tgz; 20 | -------------------------------------------------------------------------------- /OpenImages/README.md: -------------------------------------------------------------------------------- 1 | # OpenImagesV6 with "Localized Narratives" (504,416 images) 2 | 3 | This is a pretty clean dataset and works decently in my experiments. We need more, but you'll be able to see it generalize to a (very limited) degree after a full epoch. 4 | 5 | Some of the captions will exceed the token length of 256. **You will need to trim these or throw them out in your DataLoader.** 6 | -------------------------------------------------------------------------------- /OpenImages/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Enable strict mode.
4 | set -euo pipefail 5 | IFS=$'\n\t' 6 | 7 | mkdir output; 8 | cd output; 9 | 10 | # Images 11 | wget https://www.dropbox.com/s/dpqqewh7kum36vp/open_images_256.tar.gz; 12 | tar xf open_images_256.tar.gz; 13 | rm -rf open_images_256.tar.gz; 14 | 15 | # Captions 16 | wget https://www.dropbox.com/s/o33hxj3azn185sw/open_images_v6_captions_1.tar.gz; 17 | tar xf open_images_v6_captions_1.tar.gz; 18 | rm -rf open_images_v6_captions_1.tar.gz; 19 | 20 | 21 | cd ../; 22 | -------------------------------------------------------------------------------- /OpenImages/ids_to_download.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/afiaka87/dalle-pytorch-datasets/3ab59ce8517ed19582844e277819edd1dee1ee01/OpenImages/ids_to_download.txt.gz -------------------------------------------------------------------------------- /OpenImages/prepare.py: -------------------------------------------------------------------------------- 1 | # python file to get the captions 2 | 3 | import io 4 | import os 5 | import json 6 | import random 7 | import pandas as pd 8 | from tqdm import trange 9 | from ast import literal_eval 10 | from argparse import ArgumentParser 11 | 12 | # declare URLS 13 | CLASSES = "https://storage.googleapis.com/openimages/2018_04/class-descriptions-boxable.csv" 14 | RELATION_LABELS = "https://storage.googleapis.com/openimages/2019_01/challenge-2018-relationships-description.csv" 15 | ATTRIBUTES = "https://storage.googleapis.com/openimages/2019_01/challenge-2018-attributes-description.csv" 16 | 17 | IMAGE_LABELS = { 18 | "train": "https://storage.googleapis.com/openimages/2018_04/train/train-annotations-human-imagelabels-boxable.csv", 19 | "val": "https://storage.googleapis.com/openimages/2018_04/validation/validation-annotations-human-imagelabels-boxable.csv", 20 | "test": "https://storage.googleapis.com/openimages/2018_04/test/test-annotations-human-imagelabels-boxable.csv" 21 | } 22 | NARRATIVES = { 23 | "train": "https://storage.googleapis.com/localized-narratives/annotations/open_images_train_v6_captions.jsonl", 24 | "val": "https://storage.googleapis.com/localized-narratives/annotations/open_images_validation_captions.jsonl", 25 | "test": "https://storage.googleapis.com/localized-narratives/annotations/open_images_test_captions.jsonl" 26 | } 27 | 28 | # https://storage.googleapis.com/openimages/web/download_v4.html 29 | # NOTE: we will provide visual relationships annotations on the test and validation sets soon - stay tuned! 
30 | RELATIONS_TRAIN = "https://storage.googleapis.com/openimages/2019_01/train/challenge-2018-train-vrd.csv" 31 | 32 | 33 | # ---- functions 34 | def fetch(url): 35 | # https://github.com/geohot/tinygrad/blob/master/extra/utils.py 36 | import requests, os, hashlib, tempfile 37 | fp = os.path.join(tempfile.gettempdir(), hashlib.md5(url.encode('utf-8')).hexdigest()) 38 | if os.path.isfile(fp) and os.stat(fp).st_size > 0: 39 | with open(fp, "rb") as f: 40 | dat = f.read() 41 | else: 42 | print("fetching %s" % url) 43 | dat = requests.get(url).content 44 | with open(fp+".tmp", "wb") as f: 45 | f.write(dat) 46 | os.rename(fp+".tmp", fp) 47 | return dat 48 | 49 | 50 | def get_data_narratives(merged_labels, split = "train"): 51 | # takes the merged_labels and split and returns a dictionary object 52 | if split not in ["test", "train", "val"]: 53 | raise ValueError("split should be one of `test`, `train`, `val`") 54 | 55 | print("-"*70) 56 | print("Starting Process for", split.upper()) 57 | 58 | # this process creates temporary fetch files because we don't really need to store those 59 | # labels 60 | image_labels = pd.read_csv(io.BytesIO(fetch(IMAGE_LABELS[split]))) 61 | image_labels.LabelName = [merged_labels[x] for x in image_labels.LabelName.values] 62 | 63 | # narratives 64 | narratives = [literal_eval(x) for x in fetch(NARRATIVES[split]).decode("utf-8").split("\n")[:-1]] 65 | narratives_train_by_id = {} 66 | for x in narratives: 67 | narratives_train_by_id[x["image_id"]] = x["caption"] 68 | 69 | # now we need to match the image ids with the ones in our labels 70 | img_ids_train = set(image_labels.ImageID.unique().tolist()) 71 | print(f"[{split.upper()}] Total images in labels:", len(img_ids_train)) 72 | img_ids_train_narratives = set(narratives_train_by_id.keys()) 73 | print(f"[{split.upper()}] Total images in narratives:", len(img_ids_train_narratives)) 74 | common = img_ids_train_narratives.intersection(img_ids_train) 75 | print(f"[{split.upper()}] Total images in common:", len(common)) 76 | 77 | # convert to target format 78 | all_data = [] 79 | common = list(common) 80 | for _, _id in zip(trange(len(common)), common): 81 | df_sub = image_labels[image_labels.ImageID == _id] 82 | _d = { 83 | "labels": df_sub.LabelName.values.tolist(), 84 | "score": df_sub.Confidence.values.tolist(), 85 | "caption": narratives_train_by_id[_id], 86 | "source_split": split, 87 | "original_language": "en", # all narratives are in english so we can hardcode 88 | "dataset": "open_images_v4" 89 | } 90 | all_data.append(_d) 91 | 92 | # print two samples for user 93 | print("-"*70) 94 | print("Two Samples:\n") 95 | for _ in range(2): 96 | print(random.choice(all_data)) 97 | print() 98 | 99 | return all_data 100 | 101 | 102 | if __name__ == "__main__": 103 | args = ArgumentParser(description="""Get captions for OpenImagesv4.""") 104 | args.add_argument( 105 | "--only-narratives", 106 | action = "store_true", 107 | default = False, 108 | help = "If passed file only loads the captions from Google's Localized Narratives " 109 | "(google.github.io/localized-narratives/)" 110 | ) 111 | args.add_argument( 112 | "--target-path", 113 | default = "./", 114 | help = "Pass the path where you want to store (def: `./`)" 115 | ) 116 | args = args.parse_args() 117 | 118 | # while the captions are not ready we are only using narratives 119 | if not args.only_narratives: 120 | raise ValueError("Can only load narratives right now") 121 | 122 | # define things that will be used for each split 123 | classes = 
pd.read_csv(io.BytesIO(fetch(CLASSES)), names = ["id", "name"]) 124 |     class_labels = {} 125 |     for x in json.loads(classes.to_json(orient = "records")): 126 |         class_labels.update({x["id"]: x["name"]}) 127 | 128 |     attributes = pd.read_csv(io.BytesIO(fetch(ATTRIBUTES)), names = ["id", "name"]) 129 |     attributes_labels = {} 130 |     for x in json.loads(attributes.to_json(orient = "records")): 131 |         attributes_labels.update({x["id"]: x["name"]}) 132 | 133 |     # In the relations the LabelName2 can have either classes or attributes 134 |     # thus create a common merged dict 135 |     merged_labels = {} 136 |     merged_labels.update(class_labels) 137 |     merged_labels.update(attributes_labels) 138 | 139 |     # now go over different splits and get the output 140 |     for split in IMAGE_LABELS: 141 |         data = get_data_narratives(merged_labels, split) 142 |         path = os.path.join(args.target_path, f"open_images_{split}.json") 143 |         print("Writing file:", path) 144 |         with open(path, "w") as f: 145 |             f.write(json.dumps(data)) 146 | 147 |     print("Completed Process") 148 |     print("-" * 70) 149 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # (Archived) 2 | 3 | ## @robvanvolt has created a much more fleshed out version here: https://github.com/robvanvolt/DALLE-datasets 4 | 5 | None of this code works yet. If you'd like to contribute, create a pull request! We need all the datasets we can get. Otherwise come back in a few weeks to check on progress. 6 | 7 | This repository includes metadata and instructions for downloading many captioned datasets, plus captions generated from labels. 8 | 9 | Thanks to @yashbonde, we eventually intend to include generated captions for a variety of datasets that don't include captions. 10 | 11 | ## Data Format 12 | 13 | Since this collection spans many different datasets, we use a common format for each sample: 14 | ``` 15 | { 16 |   "image_id": { 17 |     "labels": ["car", "chair", "something else"], 18 |     "score": [0, 1, 1], 19 |     "caption": "caption goes here", 20 |     "dataset": "open_images_v4", 21 |     "source_split": "train", 22 |     "original_language": "eng" 23 |   } 24 | } 25 | ``` 26 | 27 | * `image_id`: this will be expanded to the complete filepath when training 28 | * `labels`: if the given image has labels, add them here; **default is `None`** 29 | * `score`: a confidence score for the labels if available (e.g. in 
OpenImages); **default is `None`** 30 | * `caption`: the given or generated caption goes here 31 | * `source_split`: which split of the source dataset this sample came from 32 | * `dataset`: key identifying the source dataset 33 | * `original_language`: for multilingual datasets, use the [ISO-639-2 code](https://en.wikipedia.org/wiki/List_of_ISO_639-2_codes) 34 | 35 | 36 | ## Datasets 37 | 38 | |name|size|image count|link|used for VAE|captions given|captions generated| 39 | |-|-|-|-|-|-|-| 40 | |Downscaled OpenImages v4|16GB|1.9M|[torrent](https://academictorrents.com/details/9208d33aceb2ca3eb2beb70a192600c9c41efba1)|✅| | | 41 | |Stanford STL-10|2.64GB|113K|[torrent](https://academictorrents.com/details/a799a2845ac29a66c07cf74e2a2838b6c5698a6a)|✅| | | 42 | |CVPR Indoor Scene Recognition|2.59GB|15620|[torrent](https://academictorrents.com/details/59aa0ad684e5d849f68bad9a6d43a9000a927164)|✅| | | 43 | |The Visual Genome Dataset v1.0 + v1.2 Images|15.20GB|108K|[torrent](https://academictorrents.com/details/1bfe6871046860a2ff8c0cc1414318beb35dc916)|✅|✅| | 44 | |Food-101|5.69GB|101K|[torrent](https://academictorrents.com/details/470791483f8441764d3b01dbc4d22b3aa58ef46f)|✅| | | 45 | |The Street View House Numbers (SVHN) Dataset|2.64GB|600K|[torrent](https://academictorrents.com/details/6f4caf3c24803d114c3cae3ab9cb946cd23c7213)|✅| | | 46 | |Downsampled ImageNet 64x64|12.59GB|1.28M|[torrent](https://academictorrents.com/details/96816a530ee002254d29bf7a61c0c158d3dedc3b)|✅| | | 47 | |COCO 2017|52.44GB|287K|[torrent](https://academictorrents.com/details/74dec1dd21ae4994dfd9069f9cb0443eb960c962) [website](https://cocodataset.org/#download)| |✅| | 48 | |Flickr 30k Captions (bad data, downloads duplicates)|8GB|31K|[kaggle](https://www.kaggle.com/hsankesara/flickr-image-dataset)| |✅| | 49 | 50 | ## Other Projects 51 | 52 | This is a big community-led effort; find more projects here: 53 | * [`DALLE-pytorch`](https://github.com/lucidrains/DALLE-pytorch/) 54 | * [`dall-e-baby`](https://github.com/yashbonde/dall-e-baby) 55 | 56 | ## Connect with us 57 | 58 | You can join the [discord](https://discord.gg/hBtKR6JF) for direct communication. 59 | -------------------------------------------------------------------------------- /STL-10/README.md: -------------------------------------------------------------------------------- 1 | # STL-10 2 | 3 | 4 | -------------------------------------------------------------------------------- /STL-10/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Enable strict mode. 4 | set -euo pipefail 5 | IFS=$'\n\t' 6 | 7 | mkdir output; 8 | 9 | aria2c https://academictorrents.com/download/a799a2845ac29a66c07cf74e2a2838b6c5698a6a.torrent; # STL-10 10 | 11 | 12 | # extract STL-10 13 | tar -xf stl10_binary.tar.gz --directory=output; 14 | -------------------------------------------------------------------------------- /SVHN/README.md: -------------------------------------------------------------------------------- 1 | # The Street View House Numbers (SVHN) Dataset 2 | 3 | 4 | Size: 2.64 GiB 5 | 6 | Images: 600K 7 | 8 | Captions: Need to be generated. 9 | -------------------------------------------------------------------------------- /SVHN/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Enable strict mode. 4 | set -euo pipefail 5 | IFS=$'\n\t' 6 | 7 | if [[ ! -v OUTPUT_PATH ]]; then 8 | echo "Using default output path."
9 | OUTPUT_PATH=../output 10 | else 11 | echo "Using provided OUTPUT_PATH: $OUTPUT_PATH" 12 | fi 13 | 14 | 15 | aria2c https://academictorrents.com/download/6f4caf3c24803d114c3cae3ab9cb946cd23c7213.torrent; # SVHN 16 | 17 | # extracting SVHN 18 | for i in extra.tar.gz test.tar.gz train.tar.gz 19 | do 20 | echo "Untarring: $i"; 21 | tar -xf $i --directory=$OUTPUT_PATH; 22 | done 23 | 24 | -------------------------------------------------------------------------------- /VisualGenome/README.md: -------------------------------------------------------------------------------- 1 | # The Visual Genome Dataset v1.0 + v1.2 Images 2 | 3 | Size: 15.20 GiB 4 | Images: 108K 5 | -------------------------------------------------------------------------------- /VisualGenome/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Enable strict mode. 4 | set -euo pipefail; 5 | IFS=$'\n\t'; 6 | 7 | if [[ ! -v OUTPUT_PATH ]]; then 8 | echo "Using default output path." 9 | OUTPUT_PATH=../output; 10 | else 11 | echo "Using provided OUTPUT_PATH: $OUTPUT_PATH" 12 | fi 13 | 14 | mkdir -p $OUTPUT_PATH 15 | cd $OUTPUT_PATH 16 | 17 | ZIP_DIR="VG_100K_2"; 18 | 19 | aria2c --split=4 --lowest-speed-limit=10K --parameterized-uri https://academictorrents.com/download/1bfe6871046860a2ff8c0cc1414318beb35dc916.torrent; 20 | 21 | mv $ZIP_DIR/ $OUTPUT_PATH 22 | 23 | unzip -q images.zip; 24 | rm -rf images.zip; 25 | 26 | unzip -q images2.zip; 27 | rm -rf images2.zip; 28 | 29 | 30 | wget -q "http://visualgenome.org/static/data/dataset/region_descriptions.json.zip"; 31 | unzip -q region_descriptions.json.zip; 32 | mv region_descriptions.json $OUTPUT_PATH; 33 | 34 | 35 | -------------------------------------------------------------------------------- /VisualGenome/get_captions.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | # visual Genome captions 5 | def get_genome_captions(root_folder="../output/VG_100K_2"): 6 | reg_des_path = os.path.join(root_folder, "region_descriptions.json") 7 | with open(reg_des_path) as f: 8 | regdes = json.load(f) 9 | 10 | captions = {} 11 | dropped = [] 12 | for item in regdes: 13 | id = item["id"] 14 | path = os.path.join(root_folder, f"VG_100K/{id}.jpg") 15 | if not os.path.exists(path): 16 | path = os.path.join(root_folder, f"VG_100K_2/{id}.jpg") 17 | if not os.path.exists(path): 18 | dropped.append(id) 19 | continue 20 | captions["genome_"+str(item["id"])] = { 21 | "caption":" ".join([x["phrase"] for x in item["regions"]]), 22 | "path": path 23 | } 24 | 25 | return captions, dropped 26 | 27 | -------------------------------------------------------------------------------- /WIT/README.md: -------------------------------------------------------------------------------- 1 | # (en) Wikipedia Image-Text by Google (5,411,978 image urls) 2 | This will download **only** the list of urls needed to download the images themselves. 3 | 4 | 5 | # (en) 'WIT 94K' Compressed (~94k images/~5GiB) 6 | 94,000 WIT images resized to 256 px 7 | -------------------------------------------------------------------------------- /WIT/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Enable strict mode. 
4 | set -euo pipefail 5 | IFS=$'\n\t' 6 | 7 | wget https://www.dropbox.com/s/ge29pmic7blwkkv/urls_en.tar.gz; 8 | 9 | tar -xf urls_en.tar.gz; 10 | cd url; 11 | cat *.txt > all_urls.txt; 12 | mkdir OUTPUT; 13 | aria2c --auto-file-renaming false --conditional-get --no-file-allocation-limit=1K -P true --deferred-input true --optimize-concurrent-downloads -i ./all_urls.txt -d ./OUTPUT; 14 | -------------------------------------------------------------------------------- /WIT/download_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Enable strict mode. 4 | set -euo pipefail 5 | IFS=$'\n\t' 6 | 7 | wget https://www.dropbox.com/s/ge29pmic7blwkkv/urls_en.tar.gz; 8 | 9 | tar -xf urls_en.tar.gz; 10 | cd url; 11 | cat *.txt > all_urls.txt; 12 | mkdir OUTPUT; 13 | aria2c --auto-file-renaming false --conditional-get --no-file-allocation-limit=1K -P true --deferred-input true --optimize-concurrent-downloads -i ./all_urls.txt -d ./OUTPUT; 14 | -------------------------------------------------------------------------------- /WIT/download_wit_500k.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Enable strict mode. 4 | set -euo pipefail 5 | IFS=$'\n\t' 6 | 7 | URL="https://www.dropbox.com/s/6ptycnlypb1psp7/wit_en_94k_256px.tar.gz" 8 | 9 | wget $URL 10 | 11 | -------------------------------------------------------------------------------- /dalle_pytorch_datasets.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "dalle-pytorch-datasets.ipynb", 7 | "private_outputs": true, 8 | "provenance": [], 9 | "collapsed_sections": [] 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "accelerator": "GPU" 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "code", 20 | "metadata": { 21 | "id": "g8H8qQkDfR_3" 22 | }, 23 | "source": [ 24 | "#@title Setup\n", 25 | "\n", 26 | "%pip install dalle-pytorch\n", 27 | "%pip install wandb\n", 28 | "\n", 29 | "%pip install gdown\n", 30 | "!wget https://github.com/aria2/aria2/releases/download/release-1.35.0/aria2-1.35.0.tar.gz\n", 31 | "!tar -xf aria2-1.35.0.tar.gz\n", 32 | "%cd \"aria2-1.35.0\"\n", 33 | "!./configure && make && make install\n", 34 | "%cd /content\n", 35 | "\n", 36 | "# from google.colab import drive\n", 37 | "# drive.mount('/content/drive')" 38 | ], 39 | "execution_count": null, 40 | "outputs": [] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "metadata": { 45 | "id": "I1LnVLuJmpXV" 46 | }, 47 | "source": [ 48 | "#@title Download COCO2017, VisualGenome, and WIT_380k\n", 49 | "!wget https://www.dropbox.com/s/qrufasneaduttmc/wit_en_small_380k.tar.gz\n", 50 | "!wget https://www.dropbox.com/s/dtjjz9cpenmpowr/train2017.zip\n", 51 | "!wget https://www.dropbox.com/s/30sdd7f9m2nuii2/virtual_genome_captions.tar.gz" 52 | ], 53 | "execution_count": null, 54 | "outputs": [] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "metadata": { 59 | "id": "zZLfLHjeiZQn" 60 | }, 61 | "source": [ 62 | "# from google.colab import drive\n", 63 | "# drive.mount('/content/drive')" 64 | ], 65 | "execution_count": null, 66 | "outputs": [] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "metadata": { 71 | "id": "G5UZxD8Dh9-N" 72 | }, 73 | "source": [ 74 | "# !git clone \"https://github.com/afiaka87/dalle-pytorch-datasets.git\"" 75 | ], 76 | "execution_count": null, 77 | "outputs": [] 78 | }, 79 | { 80 | 
"cell_type": "code", 81 | "metadata": { 82 | "id": "rovjjP-Sh9pv" 83 | }, 84 | "source": [ 85 | "image_text_folder = \"/content/output\" #@param\n", 86 | "import wandb\n", 87 | "!wandb login" 88 | ], 89 | "execution_count": null, 90 | "outputs": [] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "metadata": { 95 | "id": "FagaS4d-5yJk" 96 | }, 97 | "source": [ 98 | "" 99 | ], 100 | "execution_count": null, 101 | "outputs": [] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "metadata": { 106 | "collapsed": true, 107 | "id": "rIp1ESmFg1qV" 108 | }, 109 | "source": [ 110 | "import argparse\n", 111 | "from random import choice, sample\n", 112 | "from pathlib import Path\n", 113 | "\n", 114 | "# torch\n", 115 | "\n", 116 | "import torch\n", 117 | "from torch.optim import Adam\n", 118 | "from torch.nn.utils import clip_grad_norm_\n", 119 | "\n", 120 | "# vision imports\n", 121 | "\n", 122 | "from PIL import Image\n", 123 | "from torchvision import transforms as T\n", 124 | "from torch.utils.data import DataLoader, Dataset\n", 125 | "from torchvision.datasets import ImageFolder\n", 126 | "from torchvision.utils import make_grid, save_image\n", 127 | "\n", 128 | "# dalle related classes and utils\n", 129 | "\n", 130 | "from dalle_pytorch import OpenAIDiscreteVAE, VQGanVAE1024, DiscreteVAE, DALLE\n", 131 | "from dalle_pytorch.simple_tokenizer import tokenize, tokenizer, VOCAB_SIZE\n", 132 | "\n", 133 | "# argument parsing\n", 134 | "\n", 135 | "image_text_folder = \"/content/output\" #@param\n", 136 | "taming = True #@param \n", 137 | "\n", 138 | "# helpers\n", 139 | "\n", 140 | "def exists(val):\n", 141 | " return val is not None\n", 142 | "\n", 143 | "# constants\n", 144 | "\n", 145 | "VAE_PATH = None\n", 146 | "DALLE_PATH = None # 'dalle.pt'\n", 147 | "RESUME = exists(DALLE_PATH)\n", 148 | "\n", 149 | "EPOCHS = 20\n", 150 | "BATCH_SIZE = 8\n", 151 | "LEARNING_RATE = 3e-4\n", 152 | "GRAD_CLIP_NORM = 0.8\n", 153 | "\n", 154 | "MODEL_DIM = 512\n", 155 | "TEXT_SEQ_LEN = 256\n", 156 | "DEPTH = 6\n", 157 | "HEADS = 12\n", 158 | "DIM_HEAD = 64\n", 159 | "REVERSIBLE = False\n", 160 | "\n", 161 | "# reconstitute vae\n", 162 | "\n", 163 | "if RESUME:\n", 164 | " dalle_path = Path(DALLE_PATH)\n", 165 | " assert dalle_path.exists(), 'DALL-E model file does not exist'\n", 166 | "\n", 167 | " loaded_obj = torch.load(str(dalle_path), map_location='cpu')\n", 168 | "\n", 169 | " dalle_params, vae_params, weights = loaded_obj['hparams'], loaded_obj['vae_params'], loaded_obj['weights']\n", 170 | "\n", 171 | " if vae_params is not None:\n", 172 | " vae = DiscreteVAE(**vae_params)\n", 173 | " else:\n", 174 | " vae_klass = OpenAIDiscreteVAE if not taming else VQGanVAE1024\n", 175 | " vae = vae_klass()\n", 176 | " \n", 177 | " dalle_params = dict( \n", 178 | " **dalle_params\n", 179 | " )\n", 180 | " IMAGE_SIZE = vae.image_size\n", 181 | "else:\n", 182 | " if exists(VAE_PATH):\n", 183 | " vae_path = Path(VAE_PATH)\n", 184 | " assert vae_path.exists(), 'VAE model file does not exist'\n", 185 | "\n", 186 | " loaded_obj = torch.load(str(vae_path))\n", 187 | "\n", 188 | " vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']\n", 189 | "\n", 190 | " vae = DiscreteVAE(**vae_params)\n", 191 | " vae.load_state_dict(weights)\n", 192 | " else:\n", 193 | " print('using pretrained VAE for encoding images to tokens')\n", 194 | " vae_params = None\n", 195 | "\n", 196 | " vae_klass = OpenAIDiscreteVAE if not taming else VQGanVAE1024\n", 197 | " vae = vae_klass()\n", 198 | "\n", 199 | " IMAGE_SIZE = vae.image_size\n", 200 
| "\n", 201 | " dalle_params = dict(\n", 202 | " num_text_tokens = VOCAB_SIZE,\n", 203 | " text_seq_len = TEXT_SEQ_LEN,\n", 204 | " dim = MODEL_DIM,\n", 205 | " depth = DEPTH,\n", 206 | " heads = HEADS,\n", 207 | " dim_head = DIM_HEAD,\n", 208 | " reversible = REVERSIBLE,\n", 209 | " attn_types = ('axial_row', 'axial_col', 'conv_like')\n", 210 | " )\n", 211 | "\n", 212 | "# helpers\n", 213 | "\n", 214 | "def save_model(path):\n", 215 | " save_obj = {\n", 216 | " 'hparams': dalle_params,\n", 217 | " 'vae_params': vae_params,\n", 218 | " 'weights': dalle.state_dict()\n", 219 | " }\n", 220 | "\n", 221 | " torch.save(save_obj, path)\n", 222 | "\n", 223 | "# dataset loading\n", 224 | "\n", 225 | "class TextImageDataset(Dataset):\n", 226 | " def __init__(self, folder, text_len = 256, image_size = 128):\n", 227 | " super().__init__()\n", 228 | " path = Path(folder)\n", 229 | "\n", 230 | " text_files = [*path.glob('**/*.txt')]\n", 231 | "\n", 232 | " image_files = [\n", 233 | " *path.glob('**/*.png'),\n", 234 | " *path.glob('**/*.jpg'),\n", 235 | " *path.glob('**/*.jpeg')\n", 236 | " ]\n", 237 | "\n", 238 | " text_files = {t.stem: t for t in text_files}\n", 239 | " image_files = {i.stem: i for i in image_files}\n", 240 | "\n", 241 | " keys = (image_files.keys() & text_files.keys())\n", 242 | "\n", 243 | " self.keys = list(keys)\n", 244 | " self.text_files = {k: v for k, v in text_files.items() if k in keys}\n", 245 | " self.image_files = {k: v for k, v in image_files.items() if k in keys}\n", 246 | "\n", 247 | " self.image_tranform = T.Compose([\n", 248 | " T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),\n", 249 | " T.RandomResizedCrop(image_size, scale = (0.9, 1.), ratio = (1., 1.)),\n", 250 | " T.ToTensor()\n", 251 | " ])\n", 252 | "\n", 253 | " def __len__(self):\n", 254 | " return len(self.keys)\n", 255 | "\n", 256 | " def __getitem__(self, ind):\n", 257 | " key = self.keys[ind]\n", 258 | " text_file = self.text_files[key]\n", 259 | " image_file = self.image_files[key]\n", 260 | "\n", 261 | " image = Image.open(image_file)\n", 262 | " descriptions = text_file.read_text().split('\\n')\n", 263 | " descriptions = list(filter(lambda t: len(t) > 0, descriptions))\n", 264 | " description = choice(descriptions)\n", 265 | " description = description.replace(\" \", \" \")\n", 266 | "\n", 267 | " if len(tokenizer.encode(text)) >= 256:\n", 268 | " print(f\"Caption at idx {ind} too long. Selecting a hundred words.\")\n", 269 | " description = \" \".join(description.replace(\" \", \" \").split(\" \")[:100])\n", 270 | "\n", 271 | "\n", 272 | " try:\n", 273 | " tokenized_text = tokenize(description).squeeze(0)\n", 274 | " except Exception:\n", 275 | " print(f\"Tokenized text failed at index {ind}. 
Returning {ind+1} instead.\")\n", 276 | " if ind < self.__len__() - 1:\n", 277 | " return self.__getitem__(ind+1)\n", 278 | " else:\n", 279 | " return self.__getitem__(ind-1)\n", 280 | "\n", 281 | " mask = tokenized_text != 0\n", 282 | " image_tensor = self.image_tranform(image)\n", 283 | " return tokenized_text, image_tensor, mask\n", 284 | "\n", 285 | "# create dataset and dataloader\n", 286 | "\n", 287 | "ds = TextImageDataset(\n", 288 | " image_text_folder,\n", 289 | " text_len = TEXT_SEQ_LEN,\n", 290 | " image_size = IMAGE_SIZE\n", 291 | ")\n", 292 | "\n", 293 | "assert len(ds) > 0, 'dataset is empty'\n", 294 | "print(f'{len(ds)} image-text pairs found for training')\n", 295 | "\n", 296 | "dl = DataLoader(ds, batch_size = BATCH_SIZE, shuffle = True, drop_last = True)\n", 297 | "\n", 298 | "# initialize DALL-E\n", 299 | "\n", 300 | "dalle = DALLE(vae = vae, **dalle_params).cuda()\n", 301 | "\n", 302 | "if RESUME:\n", 303 | " dalle.load_state_dict(weights)\n", 304 | "\n", 305 | "# optimizer\n", 306 | "\n", 307 | "opt = Adam(dalle.parameters(), lr = LEARNING_RATE)\n", 308 | "\n", 309 | "# experiment tracker\n", 310 | "\n", 311 | "import wandb\n", 312 | "\n", 313 | "model_config = dict(\n", 314 | " depth = DEPTH,\n", 315 | " heads = HEADS,\n", 316 | " dim_head = DIM_HEAD\n", 317 | ")\n", 318 | "\n", 319 | "run = wandb.init(project = 'dalle-pytorch-datasets', resume = RESUME, config = model_config)\n", 320 | "\n", 321 | "# training\n", 322 | "\n", 323 | "for epoch in range(EPOCHS):\n", 324 | " for i, (text, images, mask) in enumerate(dl):\n", 325 | " text, images, mask = map(lambda t: t.cuda(), (text, images, mask))\n", 326 | "\n", 327 | " loss = dalle(text, images, mask = mask, return_loss = True)\n", 328 | "\n", 329 | " loss.backward()\n", 330 | " clip_grad_norm_(dalle.parameters(), GRAD_CLIP_NORM)\n", 331 | "\n", 332 | " opt.step()\n", 333 | " opt.zero_grad()\n", 334 | "\n", 335 | " log = {}\n", 336 | "\n", 337 | " if i % 10 == 0:\n", 338 | " print(epoch, i, f'loss - {loss.item()}')\n", 339 | "\n", 340 | " log = {\n", 341 | " **log,\n", 342 | " 'epoch': epoch,\n", 343 | " 'iter': i,\n", 344 | " 'loss': loss.item()\n", 345 | " }\n", 346 | "\n", 347 | " if i % 100 == 0:\n", 348 | " sample_text = text[:1]\n", 349 | " token_list = sample_text.masked_select(sample_text != 0).tolist()\n", 350 | " decoded_text = tokenizer.decode(token_list)\n", 351 | "\n", 352 | " image = dalle.generate_images(\n", 353 | " text[:1],\n", 354 | " mask = mask[:1],\n", 355 | " filter_thres = 0.9 # topk sampling at 0.9\n", 356 | " )\n", 357 | "\n", 358 | " save_model(f'./dalle.pt')\n", 359 | " wandb.save(f'./dalle.pt')\n", 360 | "\n", 361 | " log = {\n", 362 | " **log,\n", 363 | " 'image': wandb.Image(image, caption = decoded_text)\n", 364 | " }\n", 365 | "\n", 366 | " wandb.log(log)\n", 367 | "\n", 368 | " # save trained model to wandb as an artifact every epoch's end\n", 369 | "\n", 370 | " model_artifact = wandb.Artifact('trained-dalle', type = 'model', metadata = dict(model_config))\n", 371 | " model_artifact.add_file('dalle.pt')\n", 372 | " run.log_artifact(model_artifact)\n", 373 | "\n", 374 | "save_model(f'./dalle-final.pt')\n", 375 | "wandb.save('./dalle-final.pt')\n", 376 | "model_artifact = wandb.Artifact('trained-dalle', type = 'model', metadata = dict(model_config))\n", 377 | "model_artifact.add_file('dalle-final.pt')\n", 378 | "run.log_artifact(model_artifact)\n", 379 | "\n", 380 | "wandb.finish()" 381 | ], 382 | "execution_count": null, 383 | "outputs": [] 384 | } 385 | ] 386 | } 
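The OpenImages README above warns that some captions exceed the 256-token text length, and the notebook's `TextImageDataset.__getitem__` only shortens over-long captions at load time. Below is a minimal pre-filtering sketch, not part of the original notebook: it assumes the one-`.txt`-caption-file-per-image layout that `TextImageDataset` expects, reuses the `tokenizer` the notebook already imports from `dalle_pytorch`, and points at a hypothetical `output/` folder, so adjust the path to your own staging directory.

```python
# Sketch: shorten (or drop) captions that would exceed the 256-token budget
# before training, instead of handling them inside the DataLoader.
from pathlib import Path

from dalle_pytorch.simple_tokenizer import tokenizer  # same tokenizer the notebook imports

TEXT_SEQ_LEN = 256
caption_dir = Path("output")  # hypothetical image/caption folder

for txt_file in caption_dir.glob("**/*.txt"):
    kept = []
    for caption in filter(None, txt_file.read_text().split("\n")):
        if len(tokenizer.encode(caption)) < TEXT_SEQ_LEN:
            kept.append(caption)
        else:
            # crude fallback: keep roughly the first 100 words, mirroring the notebook's heuristic
            kept.append(" ".join(caption.split()[:100]))
    txt_file.write_text("\n".join(kept))
```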
-------------------------------------------------------------------------------- /download-template.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Enable strict mode. 4 | set -euo pipefail 5 | IFS=$'\n\t' 6 | -------------------------------------------------------------------------------- /generate_captions.py: -------------------------------------------------------------------------------- 1 | # Script to prepare the captions.json object that will be used for training transformer 2 | 3 | import os 4 | import re 5 | import h5py 6 | import json 7 | import pandas as pd 8 | from discrete_vae import * 9 | from tabulate import tabulate 10 | 11 | 12 | # ---- Datasets where captions are already given 13 | 14 | # coco captions 15 | def get_coco_captions(captions_path): 16 | with open(captions_path, "r") as f: 17 | cap2017 = json.load(f) 18 | captions = {} 19 | dropped = [] 20 | for x in cap2017["annotations"]: 21 | id_ = str(x["image_id"]) 22 | id_ = "0"*(12-len(id_))+id_ 23 | path = "../fast-ai-coco/train2017/"+id_+".jpg" 24 | if not os.path.exists(path): 25 | dropped.append(path) 26 | continue 27 | 28 | key = "coco_"+str(x["image_id"]) 29 | captions.setdefault(key, { 30 | "caption": "", 31 | "path": path 32 | }) 33 | captions[key]["caption"] += " " + x["caption"] 34 | return captions, dropped 35 | 36 | 37 | # visual Genome captions 38 | def get_genome_captions(root_folder = "../VG_100K_2"): 39 | with open(f"{root_folder}/region_descriptions.json", "r") as f: 40 | regdes = json.load(f) 41 | 42 | captions = {} 43 | dropped = [] 44 | for item in regdes: 45 | id = item["id"] 46 | path = f"../VG_100K_2/VG_100K/{id}.jpg" 47 | if not os.path.exists(path): 48 | path = f"../VG_100K_2/VG_100K_2/{id}.jpg" 49 | if not os.path.exists(path): 50 | dropped.append(id) 51 | continue 52 | captions["genome_"+str(item["id"])] = { 53 | "caption":" ".join([x["phrase"] for x in item["regions"]]), 54 | "path": path 55 | } 56 | 57 | return captions, dropped 58 | 59 | # Flickr30k captions 60 | def get_flickr30k_captions(rf="../flickr30k_images"): 61 | data = pd.read_csv(f"{rf}/flickr30k_images/results.csv", sep="|") 62 | captions = {} 63 | dropped = [] 64 | for idx, (img_id, df_sub) in enumerate(data.groupby("image_name")): 65 | path = f"../flickr30k_images/flickr30k_images/{img_id}" 66 | if not os.path.exists(path): 67 | dropped.append(path) 68 | continue 69 | captions[f"flickr_{idx}"] = { 70 | "caption": " ".join([str(x) for x in df_sub[" comment"].values.tolist()]), 71 | "path": path 72 | } 73 | return captions, dropped 74 | 75 | 76 | # ---- Datasets where only labels are given so we have to generate captions for this 77 | 78 | def get_open_images_label_names(): 79 | with open("../downsampled-open-images-v4/class-descriptions-boxable.csv", "r") as f: 80 | open_image_labels = {x.split(",")[0]: x.split(",")[1] for x in f.read().split("\n") if len(x)} 81 | return open_image_labels 82 | 83 | 84 | def get_open_images_labels(annotations_path): 85 | open_image_labels = get_open_images_label_names() 86 | df = pd.read_csv(annotations_path) 87 | image_to_labels = {} 88 | dropped = [] 89 | pbar = trange(len(df.ImageID.unique())) 90 | path_f = "../downsampled-open-images-v4/256px/" 91 | if "validation" in annotations_path: 92 | path_f += "validation/" 93 | elif "train" in annotations_path: 94 | path_f += "train-256/" 95 | elif "test" in annotations_path: 96 | path_f += "test/" 97 | for _, (img_id, df_sub) in zip(pbar, df.groupby("ImageID")): 98 | path = 
f"{path_f}{img_id}.jpg" 99 | pbar.set_description(f"Loading {path[::-1][:40][::-1]}") 100 | high_conf = df_sub[df_sub.Confidence == 1].LabelName.values.tolist() 101 | low_conf = df_sub[df_sub.Confidence != 1].LabelName.values.tolist() 102 | if not high_conf or not os.path.exists(path): 103 | dropped.append(img_id) 104 | image_to_labels["open_images_" + img_id] = { 105 | "label": [ 106 | [open_image_labels[x] for x in high_conf], 107 | [open_image_labels[x] for x in low_conf] 108 | ], 109 | "path": path 110 | } 111 | return image_to_labels, dropped 112 | 113 | def get_indoor_cvpr(rf= "../indoor/"): 114 | indoor = get_images_in_folder(rf) 115 | img2label = {idx: { 116 | "label": [x.split("/")[2].replace("_", " ").title()], 117 | "path": x 118 | } for idx, x in enumerate(indoor)} 119 | return img2label 120 | 121 | 122 | def get_food(rf="../food-101/"): 123 | food = get_images_in_folder(rf) 124 | img2label = {idx: { 125 | "label": [x.split("/")[3].replace("_", " ").title()], 126 | "path": x 127 | } for idx, x in enumerate(food)} 128 | return img2label 129 | 130 | 131 | def get_stl10(bin_file = "../stl10/stl10_binary/train_y.bin"): 132 | classes = ["airplane", "bird", "car", "cat", "deer", "dog", "horse", "monkey", "ship", "truck"] 133 | with open(bin_file, 'rb') as fobj: 134 | # read whole file in uint8 chunks 135 | everything = np.fromfile(fobj, dtype=np.uint8) 136 | labels = [[classes[x - 1]] for x in everything] 137 | 138 | # sort the images in the STL10 that are already parsed 139 | stl10 = [x for x in get_images_in_folder("../stl10/stl10_binary/") if "train" in x] 140 | imgs = {int(x.split("_")[-1].split(".")[0]): x for x in stl10} 141 | img2label = { 142 | f"stl_{k}":{ 143 | "path": imgs[k], 144 | "label": l 145 | } for k,l in zip(imgs, labels) 146 | } 147 | 148 | return img2label 149 | 150 | 151 | def get_svhn_data(matfile = '../housenumbers/train/digitStruct.mat'): 152 | def readInt(intArray, dsFile): 153 | intRef = intArray[0] 154 | isReference = isinstance(intRef, h5py.Reference) 155 | intVal = 0 156 | if isReference: 157 | intObj = dsFile[intRef] 158 | intVal = int(intObj[0]) 159 | else: # Assuming value type 160 | intVal = int(intRef) 161 | return intVal 162 | 163 | digitmat = h5py.File(matfile, 'r') 164 | 165 | print("Loading labels:") 166 | digit_struct = digitmat["digitStruct"] 167 | labels_to_return = [] 168 | for _, box in zip(trange(digit_struct["bbox"].shape[0]), digit_struct["bbox"]): 169 | bbox = digitmat[box[0]] # load bbox using reference 170 | labels = bbox["label"] # bbox object has the data for bounding box and labels 171 | lbl = "".join([ 172 | str(readInt(l, digitmat)) 173 | for l in labels 174 | ]) # create the label string by iterating over all the bboxes 175 | labels_to_return.append(lbl) 176 | 177 | # now replicate for names 178 | print("Loading filenames:") 179 | names = [] 180 | for _, name in zip(trange(digit_struct["name"].shape[0]), digit_struct["name"]): 181 | name = ''.join(chr(i) for i in digitmat[name[0]]) 182 | names.append(name) 183 | 184 | # create final mapping 185 | dropped =[] 186 | img2label = {} 187 | for name, label in zip(names, labels_to_return): 188 | path = f"../housenumbers/train/{name}" 189 | if not os.path.exists(path): 190 | dropped.append(path) 191 | continue 192 | img2label[f"housenumber_{name.split('.')[0]}"] = { 193 | "path": path, 194 | "label": [label] 195 | } 196 | 197 | return img2label, dropped 198 | 199 | # ---- Captions are generated using CaptionsGenerator 200 | 201 | class CaptionGenerator(): 202 | templates_labels = [ 203 
| "a picture of {}", 204 | "a photo that has {}", 205 | "photo consisting of {}", 206 | "a low resolution photo of {}", 207 | "small photo of {}", 208 | "high resolution picture of {}", 209 | "low resolution picture of {}", 210 | "high res photo that has {}", 211 | "low res photo of {}", 212 | "{} in a photo", 213 | "{} in a picture", 214 | "rendered picture of {}", 215 | "jpeg photo of {}", 216 | "a cool photo of {}", 217 | "{} rendered in a picture", 218 | ] 219 | 220 | templates_maybe = [ 221 | *[x + " and maybe containing {}" for x in templates_labels], 222 | *[x + " and possibly containing {}" for x in templates_labels], 223 | *[x + " and {} but not sure" for x in templates_labels], 224 | *[x + " also roughly {}" for x in templates_labels], 225 | ] 226 | 227 | templates_indoor = [ 228 | "indoor picture of {}", 229 | "picture inside of {}", 230 | "picture of {} from inside", 231 | ] 232 | 233 | templates_food = [ 234 | "picture of {}, a food item", 235 | "photo of food {}", 236 | "nice photo of food {}", 237 | "picture of food item {}", 238 | "picture of dish {}", 239 | "picture of {}, a food dish", 240 | "gourmet food {}", 241 | ] 242 | 243 | templates_svhn = [ 244 | "a picture of house number '{}'", 245 | "number '{}' written in front of a house", 246 | "street house number '{}' written on a door", 247 | "a photo with number '{}' written in it", 248 | "number '{}' written on a door", 249 | "photograph of number '{}'" 250 | ] 251 | 252 | captions_templates = { 253 | "open_images": [templates_labels, templates_maybe], 254 | "indoor": [templates_labels, templates_indoor], 255 | "food": [templates_labels, templates_food], 256 | "svhn": [templates_svhn], 257 | "stl": [templates_labels] 258 | } 259 | 260 | def __init__(self): 261 | self.ds_names = list(self.captions_templates.keys()) 262 | 263 | def generate_open_images_caption(self, ds): 264 | temps_high, temps_low = self.captions_templates["open_images"] 265 | captions = {} 266 | for i,k in enumerate(ds): 267 | high_conf = ", ".join(ds[k]["label"][0]) 268 | if np.random.random() > 0.5: 269 | low_conf = ", ".join(ds[k]["label"][1]) 270 | temp = np.random.choice(temps_low, size=1)[0] 271 | cap = temp.format(high_conf, low_conf) 272 | else: 273 | temp = np.random.choice(temps_high, size = 1)[0] 274 | cap = temp.format(high_conf) 275 | cap = re.sub(r"\s+", " ", cap).strip().lower() 276 | captions["open_images_" + str(k)] = { 277 | "path": ds[k]["path"], 278 | "caption": cap 279 | } 280 | return captions 281 | 282 | def generate_captions(self, ds, ds_name): 283 | print("Generating captions for", ds_name) 284 | if ds_name not in self.ds_names: 285 | raise ValueError(f"{ds_name} not in {self.ds_names}") 286 | 287 | if ds_name == "open_images": 288 | return self.generate_open_images_caption(ds) 289 | 290 | temps = [] 291 | for temp in self.captions_templates[ds_name]: 292 | temps.extend(temp) 293 | 294 | # each ds: {: {"path": , "label": []}} 295 | captions = {} 296 | temps_ordered = np.random.randint(low = 0, high = len(temps), size = (len(ds))) 297 | for i,k in enumerate(ds): 298 | lbs_string = ", ".join(ds[k]["label"]) 299 | cap = temps[temps_ordered[i]].format(lbs_string) 300 | cap = re.sub(r"\s+", " ", cap).strip().lower() 301 | captions[ds_name + "_" + str(k)] = { 302 | "path": ds[k]["path"], 303 | "caption": cap 304 | } 305 | return captions 306 | 307 | 308 | # ---- Script 309 | if __name__ == "__main__": 310 | print("-"*70 + "\n:: Loading COCO dataset") 311 | coco_train, coco_droppped_train = 
get_coco_captions("../fast-ai-coco/annotations/captions_train2017.json") 312 | coco_val, coco_droppped_val = get_coco_captions("../fast-ai-coco/annotations/captions_val2017.json") 313 | 314 | print("-"*70 + "\n:: Loading Visual Genome dataset") 315 | genome_captions, dropped_genome = get_genome_captions() 316 | 317 | print("-"*70 + "\n:: Loading Flickr30k dataset") 318 | captions_flickr, dropped_flickr = get_flickr30k_captions() 319 | 320 | print("-"*70 + "\n:: Loading OpenImages Dataset") 321 | open_images_img2lab_val, oi_dropped_val = get_open_images_labels( 322 | "../downsampled-open-images-v4/validation-annotations-human-imagelabels-boxable.csv" 323 | ) 324 | open_images_img2lab_train, oi_dropped_train = get_open_images_labels( 325 | "../downsampled-open-images-v4/train-annotations-human-imagelabels-boxable.csv" 326 | ) 327 | open_images_img2lab_test, oi_dropped_test = get_open_images_labels( 328 | "../downsampled-open-images-v4/test-annotations-human-imagelabels-boxable.csv" 329 | ) 330 | 331 | print("-"*70 + "\n:: Loading Indoor CVPR Dataset") 332 | img2label_indoor = get_indoor_cvpr() 333 | 334 | print("-"*70 + "\n:: Loading Food-101k Dataset") 335 | img2label_food = get_food() 336 | 337 | print("-"*70 + "\n:: Loading STL-10 Dataset") 338 | img2label_stl = get_stl10() 339 | 340 | print("-"*70 + "\n:: Loading SVHN Dataset") 341 | img2label_svhn, dropped_svhn = get_svhn_data() 342 | 343 | # define table for tabulate 344 | headers = ["name", "num_samples", "dropped"] 345 | table = [ 346 | ["coco_train", len(coco_train), len(coco_droppped_train)], 347 | ["coco_val", len(coco_val), len(coco_droppped_val)], 348 | ["visual genome", len(genome_captions), len(dropped_genome)], 349 | ["open images (train)", len(open_images_img2lab_train), len(oi_dropped_train)], 350 | ["open images (val)", len(open_images_img2lab_val), len(oi_dropped_val)], 351 | ["open images (test)", len(open_images_img2lab_test), len(oi_dropped_test)], 352 | ["indoor cvpr", len(img2label_indoor), 0], 353 | ["food-101k", len(img2label_food), 0], 354 | ["STL-10", len(img2label_stl), 0], 355 | ["SVHN", len(img2label_svhn), len(dropped_svhn)], 356 | ] 357 | table_arr = np.asarray(table) 358 | total_samples = sum([ 359 | len(coco_train), 360 | len(coco_val), 361 | len(genome_captions), 362 | len(open_images_img2lab_train), 363 | len(open_images_img2lab_val), 364 | len(open_images_img2lab_test), 365 | len(img2label_indoor), 366 | len(img2label_food), 367 | len(img2label_stl), 368 | len(img2label_svhn) 369 | ]) 370 | total_dropped = sum([ 371 | len(coco_droppped_train), 372 | len(coco_droppped_val), 373 | len(dropped_genome), 374 | len(oi_dropped_train), 375 | len(oi_dropped_val), 376 | len(oi_dropped_test), 377 | len(dropped_svhn) 378 | ]) 379 | table.append(["total", total_samples, total_dropped]) 380 | print("\n", "-"*70, "\n") 381 | print(tabulate(table, headers, tablefmt="psql")) 382 | 383 | print("\n:: Generating captions for labels") 384 | 385 | capgen = CaptionGenerator() 386 | capgen_oi_train = capgen.generate_captions(open_images_img2lab_train, "open_images") 387 | capgen_oi_val = capgen.generate_captions(open_images_img2lab_val, "open_images") 388 | capgen_oi_test = capgen.generate_captions(open_images_img2lab_test, "open_images") 389 | capgen_indoor = capgen.generate_captions(img2label_indoor, "indoor") 390 | capgen_food = capgen.generate_captions(img2label_food, "food") 391 | capgen_stl = capgen.generate_captions(img2label_stl, "stl") 392 | capgen_svhn = capgen.generate_captions(img2label_svhn, "svhn") 393 | 394 | # 
make the master captions list 395 | common_captions = {} 396 | common_captions.update(capgen_oi_train) 397 | common_captions.update(capgen_oi_val) 398 | common_captions.update(capgen_oi_test) 399 | common_captions.update(capgen_indoor) 400 | common_captions.update(capgen_food) 401 | common_captions.update(capgen_stl) 402 | common_captions.update(capgen_svhn) 403 | common_captions.update(coco_train) 404 | common_captions.update(coco_val) 405 | common_captions.update(genome_captions) 406 | common_captions.update(captions_flickr) 407 | 408 | print(len(common_captions), table[-1][1]) 409 | with open("../captions_train.json", "w") as f: 410 | f.write(json.dumps(common_captions)) 411 | 412 | --------------------------------------------------------------------------------
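For reference, here is a minimal sketch of how the `captions_train.json` written above could be consumed. It assumes the `{image_id: {"path": ..., "caption": ...}}` structure that `generate_captions.py` produces; as one possible layout, it writes a caption `.txt` next to each image so that the notebook's `TextImageDataset` (which pairs files by filename stem) can pick the pairs up. Adjust the paths to wherever you stage the images.

```python
# Sketch: load the captions file produced by generate_captions.py and
# materialize one caption .txt per image for stem-based pairing.
import json
from pathlib import Path

captions_file = Path("../captions_train.json")  # path used at the end of generate_captions.py

with open(captions_file) as f:
    captions = json.load(f)

written = 0
for image_id, entry in captions.items():
    image_path = Path(entry["path"])
    if not image_path.exists():
        continue  # skip entries whose image file is missing
    image_path.with_suffix(".txt").write_text(entry["caption"])
    written += 1

print(f"Wrote {written} caption files out of {len(captions)} entries")
```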