├── .gitignore
├── README.md
├── assets
│   ├── dreambooth.jpg
│   └── results.jpg
├── data
│   └── dogs
│       └── instance
│           ├── alvan-nee-9M0tSjb-cpA-unsplash.jpeg
│           ├── alvan-nee-Id1DBHv4fbg-unsplash.jpeg
│           ├── alvan-nee-bQaAJCbNq3g-unsplash.jpeg
│           ├── alvan-nee-brFsZ7qszSY-unsplash.jpeg
│           └── alvan-nee-eoqnr8ikwFE-unsplash.jpeg
├── dataset.py
├── environment.yaml
├── generate_identifier.py
├── inference.py
├── sample.py
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
logs/*
.cache/*
data/*
scripts/*
outputs/*

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Stable DreamBooth

This is an implementation of [DreamBooth](https://dreambooth.github.io/) based on [Stable Diffusion](https://github.com/CompVis/stable-diffusion).

## Update
- *This repository has been migrated into `diffusers`! Less than 14 GB of memory is required!*
  See more at https://github.com/huggingface/diffusers/tree/main/examples/dreambooth


## Results
DreamBooth results from the original paper:
![Results](assets/dreambooth.jpg)


Results reproduced with this implementation:
![Results](assets/results.jpg)

## Requirements
### Hardware
- A GPU with at least 30 GB of memory.
- Training takes about 10 minutes on an A100 80 GB GPU with `batch_size` set to 4.

### Environment Setup
Create the conda environment (PyTorch >= 1.11):
```bash
conda env create -f environment.yaml
conda activate stable-diffusion
```

## Quick Start
```bash
python sample.py # Generate class samples.
python train.py  # Finetune the Stable Diffusion model.
```
Generated samples and checkpoints are written to `logs/dog_finetune`.

## Finetune with your own data

### 1. Data Preparation
1. Collect 3 to 5 images of your subject and save them to the `data/mydata/instance` folder.
2. Sample images of the same class as the subject using `sample.py`.
   1. Change the corresponding variables in `sample.py`: the `prompt` should be of the form "a photo of {class}", and `save_dir` should point to `data/mydata/class` (see the sketch at the end of this section).
   2. Run the sampling script.
   ```bash
   python sample.py
   ```

### 2. Finetuning
1. Edit the `TrainingConfig` in `train.py` (see the sketch at the end of this section).
2. Start training.
```bash
python train.py
```

### 3. Inference
```bash
python inference.py --prompt "photo of a [V] dog in a dog house" --checkpoint_dir logs/dog_finetune
```
Generated images are saved to `outputs` by default.
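
### Example: adapting `sample.py`
A minimal sketch of the variables at the top of `sample.py` for a hypothetical cat subject stored under `data/mydata`; the subject class and paths below are illustrative, not repository defaults:
```python
# sample.py -- illustrative values for a custom subject of class "cat"
sample_nums = 1000                # number of class images to generate
batch_size = 16                   # lower this if you run out of GPU memory
prompt = "a photo of cat"         # "a photo of {class}" for your subject's class
save_dir = "data/mydata/class"    # class images are read from <data_path>/class during training
```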
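
### Example: adapting `TrainingConfig`
A sketch of the lines inside the `TrainingConfig` dataclass (in `train.py`) that typically change for a new subject. The cat subject, `data/mydata` path, and `logs/mydata_finetune` output directory are illustrative assumptions; the remaining fields can usually keep their defaults:
```python
# train.py -- TrainingConfig fields usually edited for a custom subject (illustrative values)
instance_prompt: str = "photo of a [V] cat"    # "[V]" is replaced by `identifier` in __post_init__
class_prompt: str = "photo of a cat"           # should match the prompt used in sample.py
evaluate_prompt = ["photo of a [V] cat"] * 16  # 16 prompts fill the 4x4 sample grid
data_path: str = "./data/mydata"               # folder containing instance/ and class/ subfolders
identifier: str = "sks"                        # or a rare token printed by generate_identifier.py
output_dir: str = "logs/mydata_finetune"
```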

## Acknowledgement

- Stable Diffusion by CompVis https://github.com/CompVis/stable-diffusion
- DreamBooth https://dreambooth.github.io/
- Diffusers https://github.com/huggingface/diffusers

--------------------------------------------------------------------------------
/assets/dreambooth.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Victarry/stable-dreambooth/f43bc64a090325dfd0b2e28fb50b6ca71915c0d1/assets/dreambooth.jpg

--------------------------------------------------------------------------------
/assets/results.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Victarry/stable-dreambooth/f43bc64a090325dfd0b2e28fb50b6ca71915c0d1/assets/results.jpg

--------------------------------------------------------------------------------
/data/dogs/instance/alvan-nee-9M0tSjb-cpA-unsplash.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Victarry/stable-dreambooth/f43bc64a090325dfd0b2e28fb50b6ca71915c0d1/data/dogs/instance/alvan-nee-9M0tSjb-cpA-unsplash.jpeg

--------------------------------------------------------------------------------
/data/dogs/instance/alvan-nee-Id1DBHv4fbg-unsplash.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Victarry/stable-dreambooth/f43bc64a090325dfd0b2e28fb50b6ca71915c0d1/data/dogs/instance/alvan-nee-Id1DBHv4fbg-unsplash.jpeg

--------------------------------------------------------------------------------
/data/dogs/instance/alvan-nee-bQaAJCbNq3g-unsplash.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Victarry/stable-dreambooth/f43bc64a090325dfd0b2e28fb50b6ca71915c0d1/data/dogs/instance/alvan-nee-bQaAJCbNq3g-unsplash.jpeg

--------------------------------------------------------------------------------
/data/dogs/instance/alvan-nee-brFsZ7qszSY-unsplash.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Victarry/stable-dreambooth/f43bc64a090325dfd0b2e28fb50b6ca71915c0d1/data/dogs/instance/alvan-nee-brFsZ7qszSY-unsplash.jpeg

--------------------------------------------------------------------------------
/data/dogs/instance/alvan-nee-eoqnr8ikwFE-unsplash.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Victarry/stable-dreambooth/f43bc64a090325dfd0b2e28fb50b6ca71915c0d1/data/dogs/instance/alvan-nee-eoqnr8ikwFE-unsplash.jpeg

--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
from typing import List
from PIL import Image
from torch.utils import data
from pathlib import Path
from torchvision import transforms

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM',
    '.bmp', '.BMP', '.tif', '.TIF',
    '.tiff', '.TIFF',
]


def is_image_file(file: Path):
    return file.suffix in IMG_EXTENSIONS


def make_dataset(dir, max_dataset_size=float("inf")) -> List[Path]:
    images = []
    root = Path(dir)
    assert root.is_dir(), '%s is not a valid directory' % dir

    for file in root.rglob('*'):
        if is_image_file(file):
            images.append(file)
    return images[:min(max_dataset_size, len(images))]


def default_loader(path):
    return Image.open(path).convert('RGB')


class ImageFolder(data.Dataset):
    def __init__(self,
                 root,
                 transform=None,
                 return_paths=False,
                 return_dict=False,
                 sort=False,
                 loader=default_loader):
        imgs = make_dataset(root)
        if sort:
            imgs = sorted(imgs)
        if len(imgs) == 0:
            raise RuntimeError("Found 0 images in: " + str(root) + "\n"
                               "Supported image extensions are: " +
                               ",".join(IMG_EXTENSIONS))

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.return_dict = return_dict
        self.loader = loader

    def __getitem__(self, index):
        # Wrap the index so a small dataset can be cycled past its length.
        index = index % len(self)
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.return_paths:
            return img, str(path)
        if self.return_dict:
            return {'images': img}
        return img

    def __len__(self):
        return len(self.imgs)


class MergeDataset(data.Dataset):
    def __init__(self, *datasets):
        """Merge multiple datasets into one; each item is a combination of items from all sub-datasets."""
        self.datasets = datasets
        self.sizes = [len(dataset) for dataset in datasets]
        print('dataset size', self.sizes)

    def __getitem__(self, indexs: List[int]):
        return tuple(dataset[idx] for idx, dataset in zip(indexs, self.datasets))

    def __len__(self):
        return max(self.sizes)


class TrainDataset(data.Dataset):
    def __init__(self, data_path, instance_prompt, class_prompt, image_size):
        self.instance_prompt = instance_prompt
        self.class_prompt = class_prompt
        self.transform = transforms.Compose(
            [
                transforms.Resize((image_size, image_size)),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )
        self.data1 = ImageFolder(Path(data_path) / 'instance', self.transform)  # instance dataset
        self.data2 = ImageFolder(Path(data_path) / 'class', self.transform)     # class dataset

        self.sizes = [len(self.data1), len(self.data2)]

    def __getitem__(self, index):
        # ImageFolder wraps indices, so the smaller dataset is cycled.
        img1 = self.data1[index]
        img2 = self.data2[index]
        return img1, self.instance_prompt, img2, self.class_prompt

    def __len__(self):
        return max(self.sizes)

--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
name: stable-diffusion
channels:
  - pytorch
  - defaults
dependencies:
  - python=3.8.5
  - pip=20.3
  - cudatoolkit=11.3
  - pytorch=1.11.0
  - torchvision=0.12.0
  - numpy=1.19.2
  - pip:
    - albumentations==0.4.3
    - diffusers
    - opencv-python==4.1.2.30
    - pudb==2019.2
    - imageio==2.9.0
    - imageio-ffmpeg==0.4.2
    - einops==0.3.0
    - transformers==4.19.2
    - torchmetrics==0.6.0

--------------------------------------------------------------------------------
/generate_identifier.py:
--------------------------------------------------------------------------------
from diffusers.pipelines import StableDiffusionPipeline
import torch
import random

== "__main__": 6 | model_id = "CompVis/stable-diffusion-v1-4" 7 | device = "cuda" 8 | model = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True, cache_dir="./.cache").to(device) 9 | tokenizer = model.tokenizer 10 | 11 | rare_tokens = [] 12 | for k, v in tokenizer.encoder.items(): 13 | if len(k) <= 3 and 40000 > v > 35000: 14 | rare_tokens.append(k) 15 | 16 | 17 | identifiers = [] 18 | for _ in range(3): 19 | idx = random.randint(0, len(rare_tokens)) 20 | identifiers.append(rare_tokens[idx]) 21 | 22 | print(" ".join(identifiers)) -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from diffusers.pipelines import StableDiffusionPipeline 3 | import torch 4 | from argparse import ArgumentParser 5 | import json 6 | 7 | def parse_args(): 8 | parser = ArgumentParser() 9 | parser.add_argument("--prompt", required=True) 10 | parser.add_argument("--checkpoint_dir", required=True) 11 | parser.add_argument("--save_dir", default="outputs") 12 | parser.add_argument("--sample_nums", default=16) 13 | parser.add_argument("-gs", "--guidance_scale", type=float, default=7.5) 14 | return parser.parse_args() 15 | 16 | if __name__ == "__main__": 17 | args = parse_args() 18 | with open(Path(args.checkpoint_dir) / 'config.json') as f: 19 | config = json.loads(f.read()) 20 | args.prompt = args.prompt.replace("[V]", config["identifier"]) 21 | device = "cuda" 22 | model = StableDiffusionPipeline.from_pretrained(args.checkpoint_dir).to(device) 23 | 24 | with torch.no_grad(): 25 | with torch.autocast("cuda"): 26 | images = model([args.prompt] * args.sample_nums, height=512, width=512, guidance_scale=args.guidance_scale, num_inference_steps=50)["sample"] 27 | 28 | save_dir = Path(args.save_dir) 29 | save_dir.mkdir(parents=True, exist_ok=True) 30 | 31 | for i, image in enumerate(images): 32 | image.save(save_dir / f'{i}.jpg') 33 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | from diffusers.pipelines import StableDiffusionPipeline 2 | import torch 3 | 4 | sample_nums = 1000 5 | batch_size = 16 6 | prompt = "a photo of dog" 7 | save_dir = "data/dogs/class" 8 | 9 | 10 | if __name__ == "__main__": 11 | model_id = "CompVis/stable-diffusion-v1-4" 12 | device = "cuda" 13 | model = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True, cache_dir="./.cache").to(device) 14 | 15 | datasets = [prompt] * sample_nums 16 | datasets = [datasets[x:x+batch_size] for x in range(0, sample_nums, batch_size)] 17 | id = 0 18 | 19 | for text in datasets: 20 | with torch.no_grad(): 21 | images = model(text, height=512, width=512, num_inference_steps=50)["sample"] 22 | 23 | for image in images: 24 | image.save(f"{save_dir}/{id}.png") 25 | id += 1 -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import json 3 | import os 4 | from dataclasses import dataclass 5 | from typing import List 6 | from dataset import TrainDataset 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from accelerate import Accelerator 11 | from diffusers.schedulers import DDPMScheduler, LMSDiscreteScheduler 12 | from diffusers.pipelines import StableDiffusionPipeline 13 | from PIL import Image 14 
from tqdm.auto import tqdm
from torchvision import transforms
from datasets import load_dataset
from pathlib import Path


@dataclass
class TrainingConfig:
    # Task-specific parameters
    instance_prompt: str = "photo of a [V] dog"
    class_prompt: str = "photo of a dog"
    evaluate_prompt = ["photo of a [V] dog"] * 4 + ["photo of a [V] dog in a doghouse"] * 4 + ["photo of a [V] dog in a bucket"] * 4 + ["photo of a sleeping [V] dog"] * 4
    data_path: str = "./data/dogs"
    identifier: str = "sks"

    # Basic training parameters
    num_epochs: int = 1
    train_batch_size: int = 4
    learning_rate: float = 1e-5
    image_size: int = 512  # the generated image resolution
    gradient_accumulation_steps: int = 1

    # Hyperparameters for diffusion models
    num_train_timesteps: int = 1000
    train_guidance_scale: float = 1  # guidance scale at training
    sample_guidance_scale: float = 7.5  # guidance scale at inference

    # Practical training settings
    mixed_precision: str = 'fp16'  # `no` for float32, `fp16` for automatic mixed precision
    save_image_epochs: int = 1
    save_model_epochs: int = 1
    output_dir: str = 'logs/dog_finetune'
    overwrite_output_dir: bool = True  # overwrite the old model when re-running the script
    seed: int = 42

    def __post_init__(self):
        self.instance_prompt = self.instance_prompt.replace("[V]", self.identifier)
        self.evaluate_prompt = [s.replace("[V]", self.identifier) for s in self.evaluate_prompt]


def pred(model, noisy_latent, time_steps, prompt, guidance_scale):
    batch_size = noisy_latent.shape[0]
    text_input = model.tokenizer(
        prompt,
        padding="max_length",
        max_length=model.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
    # Here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier-free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0
    # Get unconditional embeddings for classifier-free guidance.
    if do_classifier_free_guidance:
        max_length = text_input.input_ids.shape[-1]
        uncond_input = model.tokenizer(
            [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
        )
        uncond_embeddings = model.text_encoder(uncond_input.input_ids.to(model.device))[0]

        # For classifier-free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes.
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    latent_model_input = torch.cat([noisy_latent] * 2) if do_classifier_free_guidance else noisy_latent
    time_steps = torch.cat([time_steps] * 2) if do_classifier_free_guidance else time_steps
    noise_pred = model.unet(latent_model_input, time_steps, encoder_hidden_states=text_embeddings)["sample"]
    # Perform guidance.
    if do_classifier_free_guidance:
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    return noise_pred


def train_loop(config: TrainingConfig, model: StableDiffusionPipeline, noise_scheduler, optimizer, train_dataloader):
    # Initialize accelerator and tensorboard logging.
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
    )
    if accelerator.is_main_process:
        accelerator.init_trackers("train_example")

    # Prepare everything.
    # There is no specific order to remember; you just need to unpack the
    # objects in the same order you gave them to the prepare method.
    model, optimizer, train_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader
    )

    global_step = 0

    # Now you train the model.
    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")

        for step, batch in enumerate(train_dataloader):
            # Each batch pairs instance images/prompts with class images/prompts
            # (the latter provide the class-prior signal).
            instance_imgs, instance_prompt, class_imgs, class_prompt = batch
            imgs = torch.cat((instance_imgs, class_imgs), dim=0)
            prompt = instance_prompt + class_prompt

            # Total batch size (instance + class images).
            bs = imgs.shape[0]

            # Sample a random timestep for each image.
            timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bs,), device=accelerator.device).long()

            # Add noise to the clean images according to the noise magnitude at each timestep
            # (this is the forward diffusion process).
            with torch.no_grad():
                latents = model.vae.encode(imgs).latent_dist.sample() * 0.18215
                noise = torch.randn(latents.shape, device=accelerator.device)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps.cpu().numpy())

            with accelerator.accumulate(model):
                # Predict the noise residual.
                noise_pred = pred(model, noisy_latents, timesteps, prompt, guidance_scale=config.train_guidance_scale)
                loss = F.mse_loss(noise_pred, noise)
                accelerator.backward(loss)

                accelerator.clip_grad_norm_(model.unet.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()

            progress_bar.update(1)
            logs = {"loss": loss.detach().item(), "step": global_step}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1

        # After each epoch, optionally sample some demo images with evaluate() and save the model.
        if accelerator.is_main_process:
            if epoch % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
                evaluate(config, epoch, model)

            if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
                model.save_pretrained(config.output_dir)


def make_grid(images, rows, cols):
    w, h = images[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, image in enumerate(images):
        grid.paste(image, box=(i % cols * w, i // cols * h))
    return grid


def evaluate(config: TrainingConfig, epoch, pipeline: StableDiffusionPipeline):
    # Sample some images from random noise (this is the backward diffusion process).
    # The default pipeline output type is `List[PIL.Image]`.
    with torch.no_grad():
        with torch.autocast("cuda"):
            images = pipeline(config.evaluate_prompt, num_inference_steps=50, width=config.image_size, height=config.image_size, guidance_scale=config.sample_guidance_scale)["sample"]

    # Make a grid out of the images.
    image_grid = make_grid(images, rows=4, cols=4)

    # Save the images.
    test_dir = os.path.join(config.output_dir, "samples")
    os.makedirs(test_dir, exist_ok=True)
    image_grid.save(f"{test_dir}/{epoch:04d}.jpg")


def get_dataloader(config: TrainingConfig):
    dataset = TrainDataset(config.data_path, config.instance_prompt, config.class_prompt, config.image_size)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True, drop_last=True, pin_memory=True)
    return dataloader


if __name__ == "__main__":
    config = TrainingConfig()
    output_dir = Path(config.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Save the training config (including the resolved identifier) next to the checkpoints,
    # so inference.py can substitute [V] in prompts later.
    with open(output_dir / "config.json", "w") as f:
        json.dump(dataclasses.asdict(config), f)

    model_id = "CompVis/stable-diffusion-v1-4"
    device = "cuda"

    try:
        model = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True, cache_dir="./.cache").to(device)
    except Exception as e:
        print(e)
        print("Run 'huggingface-cli login' to store the auth token.")
        exit(1)

    train_dataloader = get_dataloader(config)
    optimizer = torch.optim.AdamW(model.unet.parameters(), lr=config.learning_rate)
    noise_scheduler = DDPMScheduler(num_train_timesteps=config.num_train_timesteps, beta_start=0.00085, beta_end=0.0120)

    train_loop(config, model, noise_scheduler, optimizer, train_dataloader)
--------------------------------------------------------------------------------