├── src ├── __init__.py ├── constants.py ├── utils.py ├── dreambooth_trainer.py └── datasets.py ├── .gitignore ├── requirements.txt ├── notebooks └── generate_class_priors.ipynb ├── scripts └── generate_experimental_images.py ├── train_dreambooth.py ├── LICENSE └── README.md /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | **.h5 3 | wandb/ 4 | **.egg-info/ -------------------------------------------------------------------------------- /src/constants.py: -------------------------------------------------------------------------------- 1 | PADDING_TOKEN = 49407 2 | MAX_PROMPT_LENGTH = 77 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | keras_cv==0.4.0 2 | tensorflow>=2.10.0 3 | tensorflow_datasets>=4.8.1 4 | pillow==9.4.0 5 | wandb>=0.13.9 6 | imutils 7 | opencv-python 8 | -------------------------------------------------------------------------------- /notebooks/generate_class_priors.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "machine_shape": "hm" 8 | }, 9 | "kernelspec": { 10 | "name": "python3", 11 | "display_name": "Python 3" 12 | }, 13 | "language_info": { 14 | "name": "python" 15 | }, 16 | "accelerator": "GPU", 17 | "gpuClass": "premium" 18 | }, 19 | "cells": [ 20 | { 21 | "cell_type": "code", 22 | "execution_count": null, 23 | "metadata": { 24 | "id": "lS37M6R9h7X6" 25 | }, 26 | "outputs": [], 27 | "source": [ 28 | "!pip install -q keras_cv" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "source": [ 34 | "import tensorflow as tf \n", 35 | "\n", 36 | "tf.keras.mixed_precision.set_global_policy(\"mixed_float16\")" 37 | ], 38 | "metadata": { 39 | "id": "ygttkaWeiuE2" 40 | }, 41 | "execution_count": null, 42 | "outputs": [] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "source": [ 47 | "import keras_cv\n", 48 | "\n", 49 | "model = keras_cv.models.StableDiffusion(img_width=512, img_height=512, jit_compile=True)" 50 | ], 51 | "metadata": { 52 | "id": "mgHFUd8pipkQ" 53 | }, 54 | "execution_count": null, 55 | "outputs": [] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "source": [ 60 | "from tqdm import tqdm\n", 61 | "import numpy as np \n", 62 | "import hashlib\n", 63 | "import PIL \n", 64 | "import os\n", 65 | "\n", 66 | "class_images_dir = \"class-images\"\n", 67 | "os.makedirs(class_images_dir, exist_ok=True)\n", 68 | "\n", 69 | "\n", 70 | "class_prompt = \"a photo of dog\"\n", 71 | "num_imgs_to_generate = 200 \n", 72 | "for i in tqdm(range(num_imgs_to_generate)):\n", 73 | " images = model.text_to_image(\n", 74 | " class_prompt,\n", 75 | " batch_size=3,\n", 76 | " )\n", 77 | " idx = np.random.choice(len(images))\n", 78 | " selected_image = PIL.Image.fromarray(images[idx])\n", 79 | " \n", 80 | " hash_image = hashlib.sha1(selected_image.tobytes()).hexdigest()\n", 81 | " image_filename = os.path.join(class_images_dir, f\"{hash_image}.jpg\")\n", 82 | " selected_image.save(image_filename)" 83 | ], 84 | "metadata": { 85 | "id": "czaYTOIOismu" 86 | }, 87 | "execution_count": null, 88 | "outputs": [] 89 | } 90 | ] 91 | } 
-------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | 3 | import keras_cv 4 | from typing import List 5 | 6 | import PIL 7 | 8 | import tensorflow as tf 9 | 10 | import wandb 11 | from wandb.keras import WandbModelCheckpoint 12 | 13 | 14 | class QualitativeValidationCallback(tf.keras.callbacks.Callback): 15 | def __init__( 16 | self, 17 | img_heigth: int, 18 | img_width: int, 19 | prompts: List[str], 20 | num_imgs_to_gen: int = 5, 21 | *args, 22 | **kwargs, 23 | ): 24 | super().__init__(*args, **kwargs) 25 | self.img_heigth = img_heigth 26 | self.img_width = img_width 27 | self.prompts = prompts 28 | self.num_imgs_to_gen = num_imgs_to_gen 29 | self.sd_model = keras_cv.models.StableDiffusion( 30 | img_height=self.img_heigth, img_width=self.img_width 31 | ) 32 | self.wandb_table = wandb.Table(columns=["epoch", "prompt", "images"]) 33 | 34 | def on_epoch_end(self, epoch, logs=None): 35 | print(f"Performing inference for logging generated images for epoch {epoch}...") 36 | print(f"Number of images to generate: {self.num_imgs_to_gen}") 37 | 38 | # load weights to stable diffusion model 39 | self.sd_model.diffusion_model.set_weights( 40 | self.model.diffusion_model.get_weights() 41 | ) 42 | if hasattr(self.model, "text_encoder"): 43 | self.sd_model.text_encoder.set_weights( 44 | self.model.text_encoder.get_weights() 45 | ) 46 | 47 | for prompt in self.prompts: 48 | images_dreamboothed = self.sd_model.text_to_image( 49 | prompt, batch_size=self.num_imgs_to_gen 50 | ) 51 | images_dreamboothed = [ 52 | wandb.Image(PIL.Image.fromarray(image), caption=f"{i}: {prompt}") 53 | for i, image in enumerate(images_dreamboothed) 54 | ] 55 | self.wandb_table.add_data(epoch, prompt, images_dreamboothed) 56 | 57 | def on_train_end(self, logs=None): 58 | wandb.log({"validation-table": self.wandb_table}) 59 | print("Performing inference on train end for logging generated images...") 60 | print(f"Number of images to generate: {self.num_imgs_to_gen}") 61 | for prompt in self.prompts: 62 | images_dreamboothed = self.sd_model.text_to_image(prompt, batch_size=self.num_imgs_to_gen) 63 | wandb.log( 64 | { 65 | f"validation/Prompt: {prompt}": [ 66 | wandb.Image(PIL.Image.fromarray(image), caption=f"{i}: {prompt}") 67 | for i, image in enumerate(images_dreamboothed) 68 | ] 69 | } 70 | ) 71 | 72 | 73 | class DreamBoothCheckpointCallback(WandbModelCheckpoint): 74 | def __init__( 75 | self, filepath, save_weights_only: bool = False, *args, **kwargs 76 | ) -> None: 77 | super(DreamBoothCheckpointCallback.__bases__[0], self).__init__( 78 | filepath, save_weights_only=save_weights_only, *args, **kwargs 79 | ) 80 | self.save_weights_only = save_weights_only 81 | # User-friendly warning when trying to save the best model. 
82 | if self.save_best_only: 83 | self._check_filepath() 84 | self._is_old_tf_keras_version = None 85 | 86 | def _log_ckpt_as_artifact(self, filepath: str, aliases) -> None: 87 | if wandb.run is not None: 88 | model_artifact = wandb.Artifact(f"run_{wandb.run.id}_model", type="model") 89 | for file in glob(f"{filepath}*.h5"): 90 | model_artifact.add_file(file) 91 | wandb.log_artifact(model_artifact, aliases=aliases or []) 92 | -------------------------------------------------------------------------------- /scripts/generate_experimental_images.py: -------------------------------------------------------------------------------- 1 | """ 2 | Use this script to generate and report a batch of images to W&B 3 | with a single prompt across multiple sets of Stable Diffusion weights. 4 | This is particularly useful when you have multiple fine-tuned weight 5 | files. It works for both cases: (a) diffusion model only and 6 | (b) text encoder + diffusion model. 7 | 8 | Usage: 9 | 10 | # Find weight files (.h5) under the "." location. 11 | # Generate 4 images with a nested loop over 12 | # num_steps x ugs combinations. 13 | $ python generate_experimental_images.py \ 14 | --base_root_dir "." \ 15 | --caption "A photo of sks dog in a bucket" \ 16 | --num_image_gen 4 \ 17 | --num_steps 75 100 150 \ 18 | --ugs 15 30 \ 19 | --wandb_project_id "my-wandb-project" 20 | 21 | Depending on the unique identifier and class you used, you'd need 22 | to change the caption accordingly. In the above case, the unique identifier 23 | is "sks" and the class is "dog". 24 | 25 | If the fine-tuned weights are stored as artifacts in WandB, then you can 26 | use this script: https://gist.github.com/sayakpaul/0d83d7fd7c3939ce2ddc2292b6d4f173 27 | """ 28 | 29 | import tensorflow as tf 30 | 31 | tf.keras.mixed_precision.set_global_policy("mixed_float16") 32 | 33 | import argparse 34 | import glob 35 | 36 | import keras_cv 37 | import PIL 38 | import wandb 39 | 40 | 41 | def generate_report( 42 | sd_model, 43 | weights_dict, 44 | caption, 45 | num_image_gen, 46 | num_steps, 47 | unconditional_guidance_scales, 48 | wandb_project, 49 | ): 50 | "Generates images and reports the results to WandB." 
51 | for key in list(weights_dict.keys()): 52 | print(f"Generating images for model({key}).") 53 | wandb.init(project=wandb_project, name=key) 54 | 55 | unet_params_path = weights_dict[key]["unet"] 56 | sd_model.diffusion_model.load_weights(unet_params_path) 57 | 58 | if "text_encoder" in weights_dict[key]: 59 | text_encoder_params_path = weights_dict[key]["text_encoder"] 60 | sd_model.text_encoder.load_weights(text_encoder_params_path) 61 | 62 | for steps in num_steps: 63 | for ugs in unconditional_guidance_scales: 64 | images = sd_model.text_to_image( 65 | caption, 66 | batch_size=num_image_gen, 67 | num_steps=steps, 68 | unconditional_guidance_scale=ugs, 69 | ) 70 | 71 | wandb.log( 72 | { 73 | f"num_steps@{steps}-ugs@{ugs}": [ 74 | wandb.Image( 75 | PIL.Image.fromarray(image), caption=f"{i}: {caption}" 76 | ) 77 | for i, image in enumerate(images) 78 | ] 79 | } 80 | ) 81 | 82 | wandb.finish() 83 | 84 | 85 | def find_weights(base_root_dir): 86 | """Finds weights per model name.""" 87 | weights_dict = {} 88 | 89 | for file in glob.glob(f"{base_root_dir}/*.h5"): 90 | if "@True" in file: 91 | rindex = file.rindex("@True") 92 | key = file[: rindex + len("@True")] 93 | else: 94 | rindex = file.rindex("@False") 95 | key = file[: rindex + len("@False")] 96 | 97 | if key not in weights_dict: 98 | weights_dict[key] = {} 99 | 100 | if "text_encoder" in file[rindex:]: 101 | weights_dict[key]["text_encoder"] = file 102 | else: 103 | weights_dict[key]["unet"] = file 104 | 105 | return weights_dict 106 | 107 | 108 | def run(args): 109 | """Finds weights, generate images based on them""" 110 | # Initialize the SD model. 111 | img_height = img_width = 512 112 | sd_model = keras_cv.models.StableDiffusion( 113 | img_width=img_width, img_height=img_height, jit_compile=True 114 | ) 115 | 116 | # Find weights per model. 117 | weights_dict = find_weights(args.base_root_dir) 118 | 119 | # Run image generations. 120 | generate_report( 121 | sd_model, 122 | weights_dict, 123 | args.caption, 124 | args.num_image_gen, 125 | args.num_steps, 126 | args.ugs, 127 | args.wandb_project_id, 128 | ) 129 | 130 | 131 | def parse_args(): 132 | parser = argparse.ArgumentParser( 133 | description="Script to perform image generating experimentations." 
134 | ) 135 | 136 | parser.add_argument( 137 | "--base_root_dir", 138 | type=str, 139 | default=".", 140 | help="base directory to search for weight files", 141 | ) 142 | parser.add_argument( 143 | "--caption", 144 | type=str, 145 | default="A photo of sks person without mustache, handsome, ultra realistic, 4k, 8k", 146 | help="prompt to use to generate images", 147 | ) 148 | parser.add_argument( 149 | "--num_image_gen", 150 | type=int, 151 | default=16, 152 | help="number of images to generate per model", 153 | ) 154 | parser.add_argument( 155 | "--num_steps", 156 | nargs="+", 157 | type=int, 158 | default=[75, 100, 150], 159 | help="list of num_steps", 160 | ) 161 | parser.add_argument( 162 | "--ugs", 163 | nargs="+", 164 | type=int, 165 | default=[15, 30], 166 | help="list of unconditional guidance scale", 167 | ) 168 | 169 | parser.add_argument( 170 | "--wandb_project_id", 171 | type=str, 172 | default="dreambooth-generate-dog", 173 | help="W&B project id to log", 174 | ) 175 | 176 | return parser.parse_args() 177 | 178 | 179 | if __name__ == "__main__": 180 | args = parse_args() 181 | run(args) 182 | -------------------------------------------------------------------------------- /src/dreambooth_trainer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.experimental.numpy as tnp 3 | from keras_cv.models.stable_diffusion.text_encoder import TextEncoder 4 | 5 | from src.constants import MAX_PROMPT_LENGTH 6 | 7 | 8 | class DreamBoothTrainer(tf.keras.Model): 9 | # Reference: 10 | # https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py 11 | 12 | def __init__( 13 | self, 14 | diffusion_model, 15 | vae, 16 | noise_scheduler, 17 | train_text_encoder, 18 | use_mixed_precision=False, 19 | prior_loss_weight=1.0, 20 | max_grad_norm=1.0, 21 | **kwargs 22 | ): 23 | super().__init__(**kwargs) 24 | 25 | self.diffusion_model = diffusion_model 26 | self.diffusion_model.trainable = True 27 | 28 | self.vae = vae 29 | self.vae.trainable = False 30 | 31 | self.noise_scheduler = noise_scheduler 32 | 33 | self.train_text_encoder = train_text_encoder 34 | if self.train_text_encoder: 35 | self.text_encoder = TextEncoder(MAX_PROMPT_LENGTH) 36 | self.text_encoder.trainable = True 37 | self.pos_ids = tf.convert_to_tensor( 38 | [list(range(MAX_PROMPT_LENGTH))], dtype=tf.int32 39 | ) 40 | 41 | self.prior_loss_weight = prior_loss_weight 42 | self.max_grad_norm = max_grad_norm 43 | self.use_mixed_precision = use_mixed_precision 44 | 45 | def train_step(self, inputs): 46 | instance_batch = inputs[0] 47 | class_batch = inputs[1] 48 | 49 | instance_images = instance_batch["instance_images"] 50 | instance_texts = instance_batch["instance_texts"] 51 | class_images = class_batch["class_images"] 52 | class_texts = class_batch["class_texts"] 53 | 54 | images = tf.concat([instance_images, class_images], 0) 55 | texts = tf.concat( 56 | [instance_texts, class_texts], 0 57 | ) # `texts` can either be caption tokens or embedded caption tokens. 58 | batch_size = tf.shape(images)[0] 59 | 60 | with tf.GradientTape() as tape: 61 | # If the `text_encoder` is being fine-tuned. 62 | if self.train_text_encoder: 63 | texts = self.text_encoder([texts, self.pos_ids], training=True) 64 | 65 | # Project image into the latent space and sample from it. 
66 | latents = self.sample_from_encoder_outputs(self.vae(images, training=False)) 67 | # Know more about the magic number here: 68 | # https://keras.io/examples/generative/fine_tune_via_textual_inversion/ 69 | latents = latents * 0.18215 70 | 71 | # Sample noise that we'll add to the latents. 72 | noise = tf.random.normal(tf.shape(latents)) 73 | 74 | # Sample a random timestep for each image. 75 | timesteps = tnp.random.randint( 76 | 0, self.noise_scheduler.train_timesteps, (batch_size,) 77 | ) 78 | 79 | # Add noise to the latents according to the noise magnitude at each timestep 80 | # (this is the forward diffusion process). 81 | noisy_latents = self.noise_scheduler.add_noise( 82 | tf.cast(latents, noise.dtype), noise, timesteps 83 | ) 84 | 85 | # Get the target for loss depending on the prediction type 86 | # just the sampled noise for now. 87 | target = noise # noise_schedule.predict_epsilon == True 88 | 89 | # Predict the noise residual and compute loss. 90 | timestep_embedding = tf.map_fn( 91 | lambda t: self.get_timestep_embedding(t), timesteps, dtype=tf.float32 92 | ) 93 | model_pred = self.diffusion_model( 94 | [noisy_latents, timestep_embedding, texts], training=True 95 | ) 96 | loss = self.compute_loss(target, model_pred) 97 | if self.use_mixed_precision: 98 | loss = self.optimizer.get_scaled_loss(loss) 99 | 100 | # Update parameters of the diffusion model. 101 | if self.train_text_encoder: 102 | trainable_vars = ( 103 | self.text_encoder.trainable_variables 104 | + self.diffusion_model.trainable_variables 105 | ) 106 | else: 107 | trainable_vars = self.diffusion_model.trainable_variables 108 | 109 | gradients = tape.gradient(loss, trainable_vars) 110 | if self.use_mixed_precision: 111 | gradients = self.optimizer.get_unscaled_gradients(gradients) 112 | gradients = [tf.clip_by_norm(g, self.max_grad_norm) for g in gradients] 113 | self.optimizer.apply_gradients(zip(gradients, trainable_vars)) 114 | 115 | return {m.name: m.result() for m in self.metrics} 116 | 117 | def get_timestep_embedding(self, timestep, dim=320, max_period=10000): 118 | half = dim // 2 119 | log_max_preiod = tf.math.log(tf.cast(max_period, tf.float32)) 120 | freqs = tf.math.exp(-log_max_preiod * tf.range(0, half, dtype=tf.float32) / half) 121 | args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs 122 | embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], 0) 123 | return embedding 124 | 125 | def sample_from_encoder_outputs(self, outputs): 126 | mean, logvar = tf.split(outputs, 2, axis=-1) 127 | logvar = tf.clip_by_value(logvar, -30.0, 20.0) 128 | std = tf.exp(0.5 * logvar) 129 | sample = tf.random.normal(tf.shape(mean), dtype=mean.dtype) 130 | return mean + std * sample 131 | 132 | def compute_loss(self, target, model_pred): 133 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. 134 | model_pred, model_pred_prior = tf.split(model_pred, num_or_size_splits=2, axis=0) 135 | target, target_prior = tf.split(target, num_or_size_splits=2, axis=0) 136 | 137 | # Compute instance loss. 138 | loss = self.compiled_loss(target, model_pred) 139 | 140 | # Compute prior loss. 141 | prior_loss = self.compiled_loss(target_prior, model_pred_prior) 142 | 143 | # Add the prior loss to the instance loss. 
144 | loss = loss + self.prior_loss_weight * prior_loss 145 | return loss 146 | 147 | def save_weights( 148 | self, ckpt_path_prefix, overwrite=True, save_format=None, options=None 149 | ): 150 | # Overriding this method will allow us to use the `ModelCheckpoint` 151 | # callback directly with this trainer class. In this case, it will 152 | # only checkpoint the `diffusion_model` and optionally the `text_encoder`. 153 | diffusion_model_path = ckpt_path_prefix + "-unet.h5" 154 | self.diffusion_model.save_weights( 155 | filepath=diffusion_model_path, 156 | overwrite=overwrite, 157 | save_format=save_format, 158 | options=options, 159 | ) 160 | self.diffusion_model_path = diffusion_model_path 161 | if self.train_text_encoder: 162 | text_encoder_model_path = ckpt_path_prefix + "-text_encoder.h5" 163 | self.text_encoder.save_weights( 164 | filepath=text_encoder_model_path, 165 | overwrite=overwrite, 166 | save_format=save_format, 167 | options=options, 168 | ) 169 | self.text_encoder_model_path = text_encoder_model_path 170 | -------------------------------------------------------------------------------- /train_dreambooth.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.filterwarnings("ignore") 4 | 5 | import os 6 | 7 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" 8 | import argparse 9 | import math 10 | 11 | import tensorflow as tf 12 | from keras_cv.models.stable_diffusion.diffusion_model import DiffusionModel 13 | from keras_cv.models.stable_diffusion.image_encoder import ImageEncoder 14 | from keras_cv.models.stable_diffusion.noise_scheduler import NoiseScheduler 15 | 16 | import tensorflow as tf 17 | from tensorflow.keras import mixed_precision 18 | 19 | from src import utils 20 | from src.constants import MAX_PROMPT_LENGTH 21 | from src.datasets import DatasetUtils 22 | from src.dreambooth_trainer import DreamBoothTrainer 23 | from src.utils import QualitativeValidationCallback, DreamBoothCheckpointCallback 24 | 25 | import wandb 26 | from wandb.keras import WandbMetricsLogger 27 | 28 | 29 | # These hyperparameters come from this tutorial by Hugging Face: 30 | # https://github.com/huggingface/diffusers/tree/main/examples/dreambooth 31 | def get_optimizer( 32 | lr=5e-6, beta_1=0.9, beta_2=0.999, weight_decay=1e-2, epsilon=1e-08 33 | ): 34 | """Instantiates the AdamW optimizer.""" 35 | 36 | optimizer = tf.keras.optimizers.experimental.AdamW( 37 | learning_rate=lr, 38 | weight_decay=weight_decay, 39 | beta_1=beta_1, 40 | beta_2=beta_2, 41 | epsilon=epsilon, 42 | ) 43 | 44 | return optimizer 45 | 46 | 47 | def prepare_trainer( 48 | img_resolution: int, train_text_encoder: bool, use_mp: bool, **kwargs 49 | ): 50 | """Instantiates and compiles `DreamBoothTrainer` for training.""" 51 | image_encoder = ImageEncoder(img_resolution, img_resolution) 52 | 53 | dreambooth_trainer = DreamBoothTrainer( 54 | diffusion_model=DiffusionModel( 55 | img_resolution, img_resolution, MAX_PROMPT_LENGTH 56 | ), 57 | # Remove the top layer from the encoder, which cuts off 58 | # the variance and only returns the mean. 
59 | vae=tf.keras.Model( 60 | image_encoder.input, 61 | image_encoder.layers[-2].output, 62 | ), 63 | noise_scheduler=NoiseScheduler(), 64 | train_text_encoder=train_text_encoder, 65 | use_mixed_precision=use_mp, 66 | **kwargs, 67 | ) 68 | 69 | optimizer = get_optimizer() 70 | dreambooth_trainer.compile(optimizer=optimizer, loss="mse") 71 | print("DreamBooth trainer initialized and compiled.") 72 | 73 | return dreambooth_trainer 74 | 75 | 76 | def train(dreambooth_trainer, train_dataset, max_train_steps, callbacks): 77 | """Performs DreamBooth training `DreamBoothTrainer` with the given `train_dataset`.""" 78 | num_update_steps_per_epoch = train_dataset.cardinality() 79 | epochs = math.ceil(max_train_steps / num_update_steps_per_epoch) 80 | print(f"Training for {epochs} epochs.") 81 | 82 | dreambooth_trainer.fit(train_dataset, epochs=epochs, callbacks=callbacks) 83 | 84 | 85 | def parse_args(): 86 | parser = argparse.ArgumentParser( 87 | description="Script to perform DreamBooth training using Stable Diffusion." 88 | ) 89 | # Dataset related. 90 | parser.add_argument( 91 | "--instance_images_url", 92 | default="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/instance-images.tar.gz", 93 | type=str, 94 | ) 95 | parser.add_argument( 96 | "--class_images_url", 97 | default="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/class-images.tar.gz", 98 | type=str, 99 | ) 100 | parser.add_argument("--unique_id", default="sks", type=str) 101 | parser.add_argument("--class_category", default="dog", type=str) 102 | parser.add_argument("--img_resolution", default=512, type=int) 103 | # Optimization hyperparameters. 104 | parser.add_argument("--seed", default=42, type=int) 105 | parser.add_argument("--lr", default=5e-6, type=float) 106 | parser.add_argument("--wd", default=1e-2, type=float) 107 | parser.add_argument("--beta_1", default=0.9, type=float) 108 | parser.add_argument("--beta_2", default=0.999, type=float) 109 | parser.add_argument("--epsilon", default=1e-08, type=float) 110 | # Training hyperparameters. 111 | parser.add_argument("--batch_size", default=1, type=int) 112 | parser.add_argument("--max_train_steps", default=800, type=int) 113 | parser.add_argument( 114 | "--train_text_encoder", 115 | action="store_true", 116 | help="If fine-tune the text-encoder too.", 117 | ) 118 | parser.add_argument( 119 | "--mp", action="store_true", help="Whether to use mixed-precision." 120 | ) 121 | # Misc. 
122 | parser.add_argument( 123 | "--log_wandb", action="store_true", help="Whether to use Weights & Biases for experiment tracking.", 124 | ) 125 | parser.add_argument( 126 | "--validation_prompts", 127 | nargs="+", 128 | default=None, 129 | type=str, 130 | help="Prompts to generate samples for validation purposes and logging on Weights & Biases", 131 | ) 132 | parser.add_argument( 133 | "--num_images_to_generate", 134 | default=5, 135 | type=int, 136 | help="Number of validation image to generate per prompt.", 137 | ) 138 | 139 | return parser.parse_args() 140 | 141 | 142 | def run(args): 143 | # Set random seed for reproducibility 144 | tf.keras.utils.set_random_seed(args.seed) 145 | 146 | validation_prompts = [f"A photo of {args.unique_id} {args.class_category} in a bucket"] 147 | if args.validation_prompts is not None: 148 | validation_prompts = args.validation_prompts 149 | 150 | run_name = f"lr@{args.lr}-max_train_steps@{args.max_train_steps}-train_text_encoder@{args.train_text_encoder}" 151 | if args.log_wandb: 152 | wandb.init(project="dreambooth-keras", name=run_name, config=vars(args)) 153 | 154 | if args.mp: 155 | print("Enabling mixed-precision...") 156 | policy = mixed_precision.Policy("mixed_float16") 157 | mixed_precision.set_global_policy(policy) 158 | assert policy.compute_dtype == "float16" 159 | assert policy.variable_dtype == "float32" 160 | 161 | print("Initializing dataset...") 162 | data_util = DatasetUtils( 163 | instance_images_url=args.instance_images_url, 164 | class_images_url=args.class_images_url, 165 | unique_id=args.unique_id, 166 | class_category=args.class_category, 167 | train_text_encoder=args.train_text_encoder, 168 | batch_size=args.batch_size, 169 | ) 170 | train_dataset = data_util.prepare_datasets() 171 | 172 | print("Initializing trainer...") 173 | ckpt_path_prefix = run_name 174 | dreambooth_trainer = prepare_trainer( 175 | args.img_resolution, args.train_text_encoder, args.mp 176 | ) 177 | 178 | callbacks = [ 179 | # save model checkpoint and optionally log model checkpoints to 180 | # Weights & Biases as artifacts 181 | DreamBoothCheckpointCallback(ckpt_path_prefix, save_weights_only=True) 182 | ] 183 | if args.log_wandb: 184 | # log training metrics to Weights & Biases 185 | callbacks.append(WandbMetricsLogger(log_freq="batch")) 186 | # perform inference on validation prompts at the end of every epoch and 187 | # log the resuts to a Weights & Biases table 188 | callbacks.append( 189 | QualitativeValidationCallback( 190 | img_heigth=args.img_resolution, 191 | img_width=args.img_resolution, 192 | prompts=validation_prompts, 193 | num_imgs_to_gen=args.num_images_to_generate, 194 | ) 195 | ) 196 | 197 | train(dreambooth_trainer, train_dataset, args.max_train_steps, callbacks) 198 | 199 | if args.log_wandb: 200 | wandb.finish() 201 | 202 | 203 | if __name__ == "__main__": 204 | args = parse_args() 205 | run(args) 206 | -------------------------------------------------------------------------------- /src/datasets.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from typing import Callable, Dict, List, Tuple 3 | 4 | import keras_cv 5 | import numpy as np 6 | import tensorflow as tf 7 | from imutils import paths 8 | from keras_cv.models.stable_diffusion.clip_tokenizer import SimpleTokenizer 9 | from keras_cv.models.stable_diffusion.text_encoder import TextEncoder 10 | 11 | from src.constants import MAX_PROMPT_LENGTH, PADDING_TOKEN 12 | 13 | POS_IDS = 
tf.convert_to_tensor([list(range(MAX_PROMPT_LENGTH))], dtype=tf.int32) 14 | AUTO = tf.data.AUTOTUNE 15 | 16 | 17 | class DatasetUtils: 18 | """ 19 | DatasetUtils prepares a `tf.data.Dataset` object for DreamBooth training. 20 | It works in the following steps. First, it downloads images for instance 21 | and class (assuming they are compressed). Second, it optionally embeds the 22 | captions associated with the images with `TextEncoder`. Third, it builds 23 | `tf.data.Dataset` object of a pair of image and embeded text for instance 24 | and class separately. Finally, it zips the two `tf.data.Dataset` objects. 25 | """ 26 | 27 | def __init__( 28 | self, 29 | instance_images_url: str, 30 | class_images_url: str, 31 | unique_id: str, 32 | class_category: str, 33 | train_text_encoder: bool, 34 | img_height: int = 512, 35 | img_width: int = 512, 36 | batch_size: int = 1, 37 | ): 38 | """ 39 | Args: 40 | instance_images_url: URL of a compressed file which contains 41 | a set of instance images. 42 | class_images_url: URL of a compressed file which contains a 43 | set of class images. 44 | unique_id: unique identifier to represent a new concept/instance. 45 | for instance, the typically used unique_id is "sks" in DreamBooth. 46 | class_category: a class of concept which the unique_id belongs 47 | to. For instance, if unique_id represents a specific dog, 48 | class_category should be "dog". 49 | train_text_encoder: Boolean flag to denote if the text encoder 50 | is fine-tuned. If set to True, only tokenized text batches 51 | are passed to the trainer. 52 | """ 53 | self.instance_images_url = instance_images_url 54 | self.class_images_url = class_images_url 55 | self.unique_id = unique_id 56 | self.class_category = class_category 57 | self.img_height = img_height 58 | self.img_width = img_width 59 | self.batch_size = batch_size 60 | 61 | self.tokenizer = SimpleTokenizer() 62 | self.train_text_encoder = train_text_encoder 63 | if not self.train_text_encoder: 64 | self.text_encoder = TextEncoder(MAX_PROMPT_LENGTH) 65 | 66 | self.augmenter = keras_cv.layers.Augmenter( 67 | layers=[ 68 | keras_cv.layers.CenterCrop(self.img_height, self.img_width), 69 | keras_cv.layers.RandomFlip(), 70 | tf.keras.layers.Rescaling(scale=1.0 / 127.5, offset=-1), 71 | ] 72 | ) 73 | 74 | def _get_captions( 75 | self, num_instance_images: int, num_class_images: int 76 | ) -> Tuple[List, List]: 77 | """Prepares captions for instance and class images.""" 78 | instance_caption = f"a photo of {self.unique_id} {self.class_category}" 79 | instance_captions = [instance_caption] * num_instance_images 80 | 81 | class_caption = f"a photo of {self.class_category}" 82 | class_captions = [class_caption] * num_class_images 83 | 84 | return instance_captions, class_captions 85 | 86 | def _tokenize_text(self, caption: str) -> np.ndarray: 87 | """Tokenizes a given caption.""" 88 | tokens = self.tokenizer.encode(caption) 89 | tokens = tokens + [PADDING_TOKEN] * (MAX_PROMPT_LENGTH - len(tokens)) 90 | return np.array(tokens) 91 | 92 | def _tokenize_captions( 93 | self, instance_captions: List[str], class_captions: List[str] 94 | ) -> np.ndarray: 95 | """Tokenizes a batch of captions.""" 96 | tokenized_texts = np.empty( 97 | (len(instance_captions) + len(class_captions), MAX_PROMPT_LENGTH) 98 | ) 99 | for i, caption in enumerate(itertools.chain(instance_captions, class_captions)): 100 | tokenized_texts[i] = self._tokenize_text(caption) 101 | return tokenized_texts 102 | 103 | def _embed_captions(self, tokenized_texts: np.ndarray) -> np.ndarray: 
104 | """Embeds captions with `TextEncoder`. This is done to save some memory.""" 105 | # Ensure the computation takes place on a GPU. 106 | gpus = tf.config.list_logical_devices("GPU") 107 | with tf.device(gpus[0].name): 108 | embedded_text = self.text_encoder( 109 | [tf.convert_to_tensor(tokenized_texts), POS_IDS], training=False 110 | ).numpy() 111 | 112 | del self.text_encoder # To ensure the GPU memory is freed. 113 | return embedded_text 114 | 115 | def _collate_instance_image_paths( 116 | self, instance_image_paths: List[str], class_image_paths: List[str] 117 | ) -> List: 118 | """Makes `instance_image_paths`'s length equal to the length of `class_image_paths`.""" 119 | new_instance_image_paths = [] 120 | for index in range(len(class_image_paths)): 121 | instance_image = instance_image_paths[index % len(instance_image_paths)] 122 | new_instance_image_paths.append(instance_image) 123 | 124 | return new_instance_image_paths 125 | 126 | def _download_images(self) -> Tuple[List, List]: 127 | """Downloads instance and class image archives from the URLs and 128 | un-archives them.""" 129 | instance_images_root = tf.keras.utils.get_file( 130 | origin=self.instance_images_url, 131 | untar=True, 132 | ) 133 | class_images_root = tf.keras.utils.get_file( 134 | origin=self.class_images_url, 135 | untar=True, 136 | ) 137 | 138 | instance_image_paths = list(paths.list_images(instance_images_root)) 139 | class_image_paths = list(paths.list_images(class_images_root)) 140 | instance_image_paths = self._collate_instance_image_paths( 141 | instance_image_paths, class_image_paths 142 | ) 143 | 144 | return instance_image_paths, class_image_paths 145 | 146 | def _process_image( 147 | self, image_path: tf.Tensor, text: tf.Tensor 148 | ) -> Tuple[tf.Tensor, tf.Tensor]: 149 | """Reads an image file and scales it. `text` can be either just tokens 150 | or embedded tokens.""" 151 | image = tf.io.read_file(image_path) 152 | image = tf.io.decode_png(image, 3) 153 | image = tf.image.resize(image, (self.img_height, self.img_width)) 154 | return image, text 155 | 156 | def _apply_augmentation( 157 | self, image_batch: tf.Tensor, text_batch: tf.Tensor 158 | ) -> Tuple[tf.Tensor, tf.Tensor]: 159 | """Applies data augmentation to a batch of images. `text_batch` can 160 | either be just tokens or embedded tokens.""" 161 | return self.augmenter(image_batch), text_batch 162 | 163 | def _prepare_dict(self, instance_only=True) -> Callable: 164 | """ 165 | Returns a function that returns a dictionary with an appropriate 166 | format for instance and class datasets. 167 | """ 168 | 169 | def fn(image_batch, texts) -> Dict[str, tf.Tensor]: 170 | if instance_only: 171 | batch_dict = { 172 | "instance_images": image_batch, 173 | "instance_texts": texts, 174 | } 175 | return batch_dict 176 | else: 177 | batch_dict = { 178 | "class_images": image_batch, 179 | "class_texts": texts, 180 | } 181 | return batch_dict 182 | 183 | return fn 184 | 185 | def _assemble_dataset( 186 | self, image_paths: List[str], texts: np.ndarray, instance_only=True 187 | ) -> tf.data.Dataset: 188 | """Assembles `tf.data.Dataset` object from image paths and their corresponding 189 | captions. 
`texts` can either be tokens or embedded tokens.""" 190 | dataset = tf.data.Dataset.from_tensor_slices((image_paths, texts)) 191 | dataset = dataset.map(self._process_image, num_parallel_calls=AUTO) 192 | dataset = dataset.shuffle(self.batch_size * 10, reshuffle_each_iteration=True) 193 | dataset = dataset.batch(self.batch_size) 194 | dataset = dataset.map(self._apply_augmentation, num_parallel_calls=AUTO) 195 | 196 | prepare_dict_fn = self._prepare_dict(instance_only=instance_only) 197 | dataset = dataset.map(prepare_dict_fn, num_parallel_calls=AUTO) 198 | return dataset 199 | 200 | def prepare_datasets(self) -> tf.data.Dataset: 201 | """Prepares dataset for DreamBooth training. 202 | 203 | 1. Download the instance and class images (archives) and un-archive them. 204 | 2. Prepare the instance and class image paths. 205 | 3. Prepare the instance and class captions. 206 | 4. Tokenize the captions. 207 | 5. If the text encoder is NOT fine-tuned then embed the tokenized captions. 208 | 6. Assemble the datasets. 209 | """ 210 | print("Downloading instance and class images...") 211 | instance_image_paths, class_image_paths = self._download_images() 212 | 213 | # Prepare captions. 214 | instance_captions, class_captions = self._get_captions( 215 | len(instance_image_paths), len(class_image_paths) 216 | ) 217 | # Tokenize the captions. 218 | text_batch = self._tokenize_captions(instance_captions, class_captions) 219 | 220 | # `text_batch` can either be embedded captions or tokenized captions. 221 | if not self.train_text_encoder: 222 | print("Embedding captions via TextEncoder...") 223 | text_batch = self._embed_captions(text_batch) 224 | 225 | print("Assembling instance and class datasets...") 226 | instance_dataset = self._assemble_dataset( 227 | instance_image_paths, 228 | text_batch[: len(instance_image_paths)], 229 | ) 230 | class_dataset = self._assemble_dataset( 231 | class_image_paths, 232 | text_batch[len(instance_image_paths) :], 233 | instance_only=False, 234 | ) 235 | 236 | train_dataset = tf.data.Dataset.zip((instance_dataset, class_dataset)) 237 | return train_dataset.prefetch(AUTO) 238 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Implementation of DreamBooth using KerasCV and TensorFlow 2 | 3 | This repository provides an implementation of [DreamBooth](https://arxiv.org/abs/2208.12242) using KerasCV and TensorFlow. The implementation is heavily based on Hugging Face's `diffusers` [example](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth). 4 | 5 | DreamBooth is a way of quickly teaching (fine-tuning) Stable Diffusion about new visual concepts. For more details, refer to [this document](https://dreambooth.github.io/). 6 | 7 | **The code provided in this repository is for research purposes only**. Please check out [this section](https://github.com/keras-team/keras-cv/tree/master/keras_cv/models/stable_diffusion#uses) to know more about the potential use cases and limitations. 8 | 9 | By loading this model you accept the CreativeML Open RAIL-M license at https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE. 10 | 11 |
12 | 13 |
14 | 15 | If you're just looking for the accompanying resources of this repository, here are the links: 16 | 17 | * [Inference Colab Notebook](https://colab.research.google.com/github/sayakpaul/dreambooth-keras/blob/main/notebooks/inference_dreambooth.ipynb) 18 | * [Blog post on keras.io](https://keras.io/examples/generative/dreambooth/) 19 | * [Fine-tuned model weights](https://huggingface.co/chansung/dreambooth-dog) 20 | 21 | ### Table of contents 22 | 23 | * [Performing DreamBooth training with the codebase](#steps-to-perform-dreambooth-training-using-the-codebase) 24 | * [Running inference](#inference) 25 | * [Results](#results) 26 | * [Using in Diffusers 🧨](#using-in-diffusers-) 27 | * [Notes](#notes-on-preparing-data-for-dreambooth-training-of-faces) 28 | * [Acknowledgements](#acknowledgements) 29 | 30 | **Update 15/02/2023**: Thanks to [Soumik Rakshit](https://in.linkedin.com/in/soumikrakshit); we now have better utilities to support Weights and Biases (see https://github.com/sayakpaul/dreambooth-keras/pull/22). 31 | 32 | ## Steps to perform DreamBooth training using the codebase 33 | 34 | 1. Install the pre-requisites: `pip install -r requirements.txt`. 35 | 36 | 2. You first need to choose a class to which a unique identifier is appended. This repository codebase was tested using `sks` as the unique identifier and `dog` as the class. 37 | 38 | Then two types of prompts are generated: 39 | 40 | (a) **instance prompt**: f"a photo of {self.unique_id} {self.class_category}" 41 | (b) **class prompt**: f"a photo of {self.class_category}" 42 | 43 | 3. **Instance images** 44 | 45 | Get a few images (3 - 10) that are representative of the concept the model is going to be fine-tuned with. These images would be associated with the `instance_prompt`. These images are referred to as the `instance_images` in the codebase. Archive these images and host them somewhere online such that the archive can be downloaded using the `tf.keras.utils.get_file()` function internally. 46 | 47 | 4. **Class images** 48 | 49 | DreamBooth uses prior-preservation loss to regularize training. Long story short, 50 | prior-preservation loss helps the model slowly adapt to the new concept under consideration while retaining any prior knowledge it may have had about the concept's class. To use prior-preservation loss, we need the class prompt as shown above. The class prompt is used to generate a pre-defined number of images, which are used to compute the final loss used for DreamBooth training. 51 | 52 | As per [this resource](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth), 200 - 300 images generated using the class prompt work well for most cases. 53 | 54 | So, after you have decided on the `instance_prompt` and `class_prompt`, use [this Colab Notebook](https://colab.research.google.com/github/sayakpaul/dreambooth-keras/blob/main/notebooks/generate_class_priors.ipynb) to generate some images that would be used for training with the prior-preservation loss. Then archive the generated images as a single archive and host it online such that it can be downloaded using the `tf.keras.utils.get_file()` function internally. In the codebase, we simply refer to these images as `class_images`. 55 | 56 | > It's possible to conduct DreamBooth training WITHOUT using a prior preservation loss. This repository always uses it. For people to easily test this codebase, we hosted the instance and class images [here](https://huggingface.co/datasets/sayakpaul/sample-datasets/tree/main). 57 | 58 | 5. Launch training!
There are a number of hyperparameters you can play around with. Refer to the `train_dreambooth.py` script to know more about them. Here's a command that launches training with mixed-precision and other default values: 59 | 60 | ```bash 61 | python train_dreambooth.py --mp 62 | ``` 63 | 64 | You can also fine-tune the text encoder by specifying the `--train_text_encoder` option. 65 | 66 | Additionally, the script supports integration with [Weights and Biases (`wandb`)](https://wandb.ai/). If you specify `--log_wandb`, 67 | - it will automatically log the training metrics to your `wandb` dashboard using the [`WandbMetricsLogger` callback](https://docs.wandb.ai/guides/integrations/keras#experiment-tracking-with-wandbmetricslogger). 68 | - it will also upload your model checkpoints at the end of each epoch to your `wandb` project as [artifacts](https://docs.wandb.ai/guides/artifacts) for model versioning. This is done using the `DreamBoothCheckpointCallback`, which was built using the [`WandbModelCheckpoint` callback](https://docs.wandb.ai/guides/integrations/keras#model-checkpointing-using-wandbmodelcheckpoint). 69 | - it will also perform inference with the DreamBoothed model parameters at the end of each epoch and log them into a [`wandb.Table`](https://docs.wandb.ai/guides/data-vis) in your `wandb` dashboard. This is done using the `QualitativeValidationCallback`, which also logs generated images into a media panel on your `wandb` dashboard at the end of the training. 70 | 71 | Here's a command that launches training and logs training metrics and generated images to your Weights & Biases workspace: 72 | 73 | ```bash 74 | python train_dreambooth.py \ 75 | --log_wandb \ 76 | --validation_prompts \ 77 | "a photo of sks dog with a cat" \ 78 | "a photo of sks dog riding a bicycle" \ 79 | "a photo of sks dog peeing" \ 80 | "a photo of sks dog playing cricket" \ 81 | "a photo of sks dog as an astronaut" 82 | ``` 83 | 84 | [Here's](https://wandb.ai/geekyrakshit/dreambooth-keras/runs/huou7nzr) an example `wandb` run where you can find the generated images as well as the [model checkpoints](https://wandb.ai/geekyrakshit/dreambooth-keras/artifacts/model/run_huou7nzr_model). 85 | 86 | ## Inference 87 | 88 | * [Colab Notebook](https://colab.research.google.com/github/sayakpaul/dreambooth-keras/blob/main/notebooks/inference_dreambooth.ipynb) 89 | * [Script for launching bulk experiments](https://github.com/sayakpaul/dreambooth-keras/blob/main/scripts/generate_experimental_images.py) 90 | 91 | ## Results 92 | 93 | We have tested our implementation using two different methods: (a) fine-tuning only the diffusion model (the UNet) and (b) fine-tuning the diffusion model along with the text encoder. The experiments were conducted over a wide range of hyperparameters for `learning rate` and `training steps` during training, and for `number of steps` and `unconditional guidance scale` (ugs) during inference. But only the most salient results (from our perspective) are included here. If you are curious about how different hyperparameters affect the generated image quality, find the link to the full reports in each section. 94 | 95 | __Note that our experiments were guided by [this blog post from Hugging Face](https://huggingface.co/blog/dreambooth).__ 96 | 97 | ### (a) Fine-tuning diffusion model 98 | 99 | Here are a few selected results from various experiments we conducted. Our experimental logs for this setting are available [here](https://wandb.ai/sayakpaul/dreambooth-keras).
More visualization images (generated with the checkpoints from these experiments) are available [here](https://wandb.ai/sayakpaul/experimentation_images). 100 | 101 | 102 |
| Images | Steps | UGS | Setting |
| --- | --- | --- | --- |
| (image) | 50 | 30 | LR: 1e-6, Training steps: 800 (Weights) |
| (image) | 25 | 15 | LR: 1e-6, Training steps: 1000 (Weights) |
| (image) | 75 | 15 | LR: 3e-6, Training steps: 1200 (Weights) |

Caption: "A photo of sks dog in a bucket"
131 | 132 | ### (b) Fine-tuning text encoder + diffusion model 133 | 134 |
| Images | Steps | ugs |
| --- | --- | --- |
| (image) | 75 | 15 |
| (image) | 75 | 30 |

Caption: "A photo of sks dog in a bucket"

w/ learning rate=9e-06, max train steps=200 (weights | reports)

| Images | Steps | ugs |
| --- | --- | --- |
| (image) | 150 | 15 |
| (image) | 75 | 30 |

Caption: "A photo of sks person without mustache, handsome, ultra realistic, 4k, 8k"

w/ learning rate=9e-06, max train steps=200 (datasets | reports)

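The checkpoints produced by `train_dreambooth.py` (saved as `<ckpt_prefix>-unet.h5` and, if the text encoder was fine-tuned, `<ckpt_prefix>-text_encoder.h5`) can be loaded straight back into KerasCV for inference. Below is a minimal sketch; the `my-run-*.h5` paths are placeholders for whatever checkpoint prefix you used, and the prompt, `num_steps`, and `unconditional_guidance_scale` values simply mirror the tables above:

```python
import keras_cv
import PIL.Image

# Build a vanilla Stable Diffusion model, then swap in the DreamBooth
# fine-tuned weights saved by `train_dreambooth.py`.
sd_model = keras_cv.models.StableDiffusion(
    img_width=512, img_height=512, jit_compile=True
)
sd_model.diffusion_model.load_weights("my-run-unet.h5")  # placeholder path

# Only needed if you trained with `--train_text_encoder`.
# sd_model.text_encoder.load_weights("my-run-text_encoder.h5")

images = sd_model.text_to_image(
    "A photo of sks dog in a bucket",
    batch_size=3,
    num_steps=75,
    unconditional_guidance_scale=15,
)
for i, image in enumerate(images):
    PIL.Image.fromarray(image).save(f"generated-{i}.png")
```

For bulk sweeps over `num_steps` and `unconditional_guidance_scale` across several weight files, see `scripts/generate_experimental_images.py`.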
180 | 181 | ## Using in Diffusers 🧨 182 | 183 | The [`diffusers` library](https://github.com/huggingface/diffusers/) provides state-of-the-art tooling for experimenting with 184 | different Diffusion models, including Stable Diffusion. It includes 185 | different optimization techniques that can be leveraged to perform efficient inference 186 | with `diffusers` when using large Stable Diffusion checkpoints. One particularly 187 | advantageous feature `diffusers` has is its support for [different schedulers](https://huggingface.co/docs/diffusers/using-diffusers/schedulers) that can 188 | be configured during runtime and can be integrated into any compatible Diffusion model. 189 | 190 | Once you have obtained the DreamBooth fine-tuned checkpoints using this codebase, you can actually 191 | export those into a handy [`StableDiffusionPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview) and use it from the `diffusers` library directly. 192 | 193 | Consider this repository: [chansung/dreambooth-dog](https://huggingface.co/chansung/dreambooth-dog). You can use the 194 | checkpoints of this repository in a `StableDiffusionPipeline` after a few small steps: 195 | 196 | ```py 197 | from diffusers import StableDiffusionPipeline 198 | 199 | # checkpoint of the converted Stable Diffusion from KerasCV 200 | model_ckpt = "sayakpaul/text-unet-dogs-kerascv_sd_diffusers_pipeline" 201 | pipeline = StableDiffusionPipeline.from_pretrained(model_ckpt) 202 | pipeline.to("cuda") 203 | 204 | unique_id = "sks" 205 | class_label = "dog" 206 | prompt = f"A photo of {unique_id} {class_label} in a bucket" 207 | image = pipeline(prompt, num_inference_steps=50).images[0] 208 | ``` 209 | 210 | Follow [this guide](https://huggingface.co/docs/diffusers/main/en/using-diffusers/kerascv) to know more. 211 | 212 | 213 | ### Experimental results with various scheduler settings 214 | 215 | We converted the fine-tuned checkpoint for the dog images into a Diffusers-compatible `StableDiffusionPipeline` and ran various experiments with different scheduler settings. For example, the following parameters of the `DDIMScheduler` were tested over different sets of `guidance_scale` and `num_inference_steps` values. 216 | 217 | ```python 218 | num_inference_steps_list = [25, 50, 75, 100] 219 | guidance_scale_list = [7.5, 15, 30] 220 | 221 | scheduler_configs = { 222 | "DDIMScheduler": { 223 | "beta_value": [ 224 | [0.000001, 0.02], 225 | [0.000005, 0.02], 226 | [0.00001, 0.02], 227 | [0.00005, 0.02], 228 | [0.0001, 0.02], 229 | [0.0005, 0.02] 230 | ], 231 | "beta_schedule": [ 232 | "linear", 233 | "scaled_linear", 234 | "squaredcos_cap_v2" 235 | ], 236 | "clip_sample": [True, False], 237 | "set_alpha_to_one": [True, False], 238 | "prediction_type": [ 239 | "epsilon", 240 | "sample", 241 | "v_prediction" 242 | ] 243 | } 244 | } 245 | ``` 246 | 247 | Below is a comparison between different values of the `beta_schedule` parameter while the others are fixed to their default values. Take a look at [the original report](https://docs.google.com/spreadsheets/d/1_NhWuORn5ByEnvD9T3X4sHUnz_GR8uEtbE5HbI98hOM/edit?usp=sharing), which includes the results from other schedulers such as `PNDMScheduler` and `LMSDiscreteScheduler`. 248 | 249 | It is often observed that the default settings generate good-quality images, but they do not always give the best results. For example, the default values of `guidance_scale` and `beta_schedule` are set to 7.5 and `linear`. 
However, when `guidance_scale` is set to 7.5, `scaled_linear` of the `beta_schedule` seems to work better. Or, when `beta_schedule` is set to `linear`, higher `guidance_scale` seems to work better. 250 | 251 | ![](https://i.postimg.cc/QsW-CKTcv/DDIMScheduler.png) 252 | 253 | We ran 4,800 experiments which generated 38,400 images in total. Those experiments are logged in Weights and Biases. If you are curious, do check them out [here](https://wandb.ai/chansung18/SD-Scheduler-Explore?workspace=user-chansung18) as well as the [script](https://gist.github.com/deep-diver/0a2deb2cd369ab8c1bf3ee12f47d272a) that was used to run the experiments. 254 | 255 | ## Notes on preparing data for DreamBooth training of faces 256 | 257 | In addition to the tips and tricks shared in [this blog post](https://huggingface.co/blog/dreambooth#using-prior-preservation-when-training-faces), we followed these things while preparing the instances for conducting DreamBooth training on human faces: 258 | 259 | * Instead of 3 - 5 images, use 20 - 25 images of the same person varying different angles, backgrounds, and poses. 260 | * No use of images containing multiple persons. 261 | * If the person wears glasses, don't include images only with glasses. Combine images with and without glasses. 262 | 263 | Thanks to [Abhishek Thakur](https://no.linkedin.com/in/abhi1thakur) for sharing these tips. 264 | 265 | ## Acknowledgements 266 | 267 | * Thanks to Hugging Face for providing the [original example](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth). It's very readable and easy to understand. 268 | * Thanks to the ML Developer Programs' team at Google for providing GCP credits. 269 | --------------------------------------------------------------------------------