├── requirements.txt ├── README.md ├── make_dataset.py └── minimal_edit_dataset.py /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.13.1 2 | datasets==2.9.0 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This repository provides utilities to create a minimal dataset for [InstructPix2Pix](https://arxiv.org/abs/2211.09800)-like training for Diffusion models. 2 | 3 | ## Steps 4 | 5 | 1. Download the original dataset as discussed [here](https://github.com/timothybrooks/instruct-pix2pix#generated-dataset). I used this version: `clip-filtered-dataset`. Note that the download can take as long as 24 hours depending on the internet bandwidth. The dataset also requires at least 600 GB of storage. 6 | 2. Then run: 7 | 8 | ```bash 9 | python make_dataset.py --data_root clip-filtered-dataset --num_samples_to_use 1000 10 | ``` 11 | 3. The `make_dataset.py` script was specifically designed to produce a [🤗 dataset](https://huggingface.co/docs/datasets/). So, it's most useful when you push the minimal dataset to the 🤗 Hub. You can do so by passing `--push_to_hub` while running `make_dataset.py`. 
12 | 13 | ## Example dataset 14 | 15 | https://huggingface.co/datasets/sayakpaul/instructpix2pix-1000-samples 16 | 17 | image 18 | 19 | The full version of the CLIP filtered dataset used for InstructPix2Pix training can be found here: https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered 20 | 21 | With the dataset being on the 🤗 Hub, one can do load the dataset with two lines of code: 22 | 23 | ```python 24 | from datasets import load_dataset 25 | 26 | dataset = load_dataset("timbrooks/instructpix2pix-clip-filtered", split="train") 27 | ``` 28 | 29 | And voila 🤗 30 | 31 | ## Acknowledgements 32 | 33 | The structure of `make_dataset.py` is inspired by Nate Raw's [notebook](https://gist.github.com/nateraw/c91fb548c3a749cfbe6436d555a547b0). 34 | -------------------------------------------------------------------------------- /make_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from datasets import Dataset, Features 4 | from datasets import Image as ImageFeature 5 | from datasets import Value 6 | 7 | from minimal_edit_dataset import EditDataset 8 | 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser( 12 | description="Prepare a mini dataset fro InstructPix2Pix style training." 
13 | ) 14 | parser.add_argument("--data_root", type=str) 15 | parser.add_argument("--num_samples_to_use", type=int, default=None) 16 | parser.add_argument("--push_to_hub", action="store_true") 17 | args = parser.parse_args() 18 | return args 19 | 20 | 21 | def gen_examples(dataset): 22 | def fn(): 23 | for sample in dataset: 24 | yield { 25 | "original_prompt": sample["original_prompt"], 26 | "original_image": {"path": str(sample["original_image"])}, 27 | "edit_prompt": sample["edit_prompt"], 28 | "edited_prompt": sample["edited_prompt"], 29 | "edited_image": {"path": str(sample["edited_image"])}, 30 | } 31 | 32 | return fn 33 | 34 | 35 | def main(args): 36 | mini_edit_dataset = EditDataset(args.data_root, args.num_samples_to_use) 37 | generator_fn = gen_examples(mini_edit_dataset) 38 | 39 | print("Creating dataset...") 40 | mini_ds = Dataset.from_generator( 41 | generator_fn, 42 | features=Features( 43 | original_prompt=Value("string"), 44 | original_image=ImageFeature(), 45 | edit_prompt=Value("string"), 46 | edited_prompt=Value("string"), 47 | edited_image=ImageFeature(), 48 | ), 49 | ) 50 | 51 | if args.push_to_hub: 52 | print("Pushing to the Hub...") 53 | ds_name = f"instructpix2pix-clip-filtered" 54 | if args.num_samples_to_use is not None: 55 | num_samples = args.num_samples_to_use 56 | ds_name += f"{num_samples}-samples" 57 | mini_ds.push_to_hub(ds_name) 58 | 59 | 60 | if __name__ == "__main__": 61 | args = parse_args() 62 | main(args) 63 | -------------------------------------------------------------------------------- /minimal_edit_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from https://github.com/timothybrooks/instruct-pix2pix/blob/main/edit_dataset.py 3 | """ 4 | 5 | import random 6 | 7 | random.seed(0) 8 | 9 | import json 10 | from pathlib import Path 11 | from random import shuffle 12 | 13 | import torch 14 | from PIL import Image 15 | from torch.utils.data import Dataset 16 | 17 | 
18 | class EditDataset(Dataset): 19 | def __init__(self, path: str, num_samples_to_use: int, return_paths: bool = True): 20 | self.path = path 21 | 22 | with open(Path(self.path, "seeds.json")) as f: 23 | seeds = json.load(f) 24 | shuffle(seeds) 25 | if num_samples_to_use is not None: 26 | self.seeds = seeds[:num_samples_to_use] 27 | else: 28 | self.seeds = seeds 29 | 30 | self.return_paths = return_paths 31 | 32 | def __len__(self) -> int: 33 | return len(self.seeds) 34 | 35 | def __getitem__(self, i: int) -> dict: 36 | name, seeds = self.seeds[i] 37 | prompt_dir = Path(self.path, name) 38 | seed = seeds[torch.randint(0, len(seeds), ()).item()] 39 | with open(prompt_dir.joinpath("prompt.json")) as fp: 40 | json_contents = dict(json.load(fp)) 41 | edit_prompt = json_contents["edit"] 42 | original_prompt = json_contents.get("input", "") 43 | edited_prompt = json_contents.get("output", "") 44 | url = json_contents.get("url", "") 45 | 46 | image_0_path = prompt_dir.joinpath(f"{seed}_0.jpg") 47 | image_1_path = prompt_dir.joinpath(f"{seed}_1.jpg") 48 | 49 | if self.return_paths: 50 | return dict( 51 | image_url=url, 52 | original_image=image_0_path, 53 | original_prompt=original_prompt, 54 | edit_prompt=edit_prompt, 55 | edited_prompt=edited_prompt, 56 | edited_image=image_1_path, 57 | ) 58 | 59 | image_0 = Image.open(image_0_path).convert("RGB") 60 | image_1 = Image.open(image_1_path).convert("RGB") 61 | 62 | return dict( 63 | image_url=url, 64 | original_image=image_0, 65 | original_prompt=original_prompt, 66 | edit_prompt=edit_prompt, 67 | edited_prompt=edited_prompt, 68 | edited_image=image_1, 69 | ) 70 | --------------------------------------------------------------------------------