├── LICENSE ├── README.md ├── dreambooth_depth.ipynb └── train_dreambooth.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 epitaque 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dreambooth with depth2image 2 | This repo adapts the DreamBooth training script from diffusers to fine-tune the [stabilityai/stable-diffusion-2-depth](https://huggingface.co/stabilityai/stable-diffusion-2-depth) model. 3 | It works by creating a depth map for each input image and then passing those depth maps to the UNet as additional conditioning. 4 | Check out the notebook for a demo of how to use the script. 5 | I used the conda `environment.yaml` from the Stable Diffusion GitHub repo. 6 | 7 | ## NOTICE 8 | At the time of writing, `StableDiffusionDepth2ImgPipeline` is supported only on the `main` branch of `transformers`. 
Install it with the following: 9 | ``` 10 | pip uninstall -y transformers 11 | pip install git+https://github.com/huggingface/transformers.git@main 12 | ``` 13 | -------------------------------------------------------------------------------- /dreambooth_depth.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "NVIDIA GeForce RTX 3090, 24576 MiB, 20390 MiB\n" 13 | ] 14 | } 15 | ], 16 | "source": [ 17 | "!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": null, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "%pip install jupyter_compare_view" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 2, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "# Fill in these environment variables\n", 36 | "%env MODEL_NAME=stabilityai/stable-diffusion-2-depth\n", 37 | "%env INSTANCE_DIR=/workspace/content/data/\n", 38 | "%env CLASS_DIR=/workspace/content/data/person\n", 39 | "%env OUTPUT_DIR=/workspace/content/data/output-2\n", 40 | "%load_ext autoreload" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 5, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "%autoreload 2\n", 50 | "\n", 51 | "# !python -m debugpy --listen 0.0.0.0:5678 --wait-for-client \\\n", 52 | "!accelerate launch \\\n", 53 | " train_dreambooth.py \\\n", 54 | " --mixed_precision=\"fp16\" \\\n", 55 | " --pretrained_model_name_or_path=$MODEL_NAME \\\n", 56 | " --pretrained_txt2img_model_name_or_path=\"stabilityai/stable-diffusion-2-1-base\" \\\n", 57 | " --train_text_encoder \\\n", 58 | " --instance_data_dir=$INSTANCE_DIR \\\n", 59 | " --class_data_dir=$CLASS_DIR \\\n", 60 | " --output_dir=$OUTPUT_DIR \\\n", 61 | " --with_prior_preservation --prior_loss_weight=1.0 \\\n", 62 | " --instance_prompt=\"a photo of person\" \\\n", 63 | " --class_prompt=\"a photo of person\" \\\n", 64 | " --resolution=512 \\\n", 65 | " --train_batch_size=1 \\\n", 66 | " --gradient_accumulation_steps=1 \\\n", 67 | " --learning_rate=1e-6 \\\n", 68 | " --lr_scheduler=\"constant\" \\\n", 69 | " --lr_warmup_steps=0 \\\n", 70 | " --num_class_images=200 \\\n", 71 | " --max_train_steps=300 \\\n", 72 | " --use_8bit_adam" 73 | ] 74 | }, 75 | { 76 | "attachments": {}, 77 | "cell_type": "markdown", 78 | "metadata": {}, 79 | "source": [ 80 | "# Test the model you just made" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 2, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "import PIL\n", 90 | "import torch\n", 91 | "from torchvision import transforms\n", 92 | "import diffusers\n", 93 | "import transformers\n", 94 | "from diffusers import StableDiffusionDepth2ImgPipeline\n", 95 | "import os\n", 96 | "from jupyter_compare_view import compare" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 3, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "print(f'Getting model from {os.environ.get(\"OUTPUT_DIR\")}')\n", 106 | "pipeline = StableDiffusionDepth2ImgPipeline.from_pretrained(os.environ.get('OUTPUT_DIR'))\n", 107 | "pipeline = pipeline.to(\"cuda\")" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 5, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "# 
Use an image as an input depth map\n", 117 | "image_path = \"/workspace/content/data/samples/village.jpg\" # replace with whatever you want\n", 118 | "image = PIL.Image.open(image_path)\n", 119 | "\n", 120 | "image_transform = transforms.Compose(\n", 121 | " [\n", 122 | " transforms.Resize((384, 384)),\n", 123 | " transforms.ToTensor()\n", 124 | " ]\n", 125 | ")\n", 126 | "image = image_transform(image)\n", 127 | "image = image[None,:,:,:]\n", 128 | "image = image.to(\"cuda\")\n", 129 | "depth_map = pipeline.depth_estimator(image).predicted_depth\n", 130 | "image = transforms.ToPILImage()(image[0])\n", 131 | "depth_min = torch.amin(depth_map, dim=[0, 1, 2], keepdim=True)\n", 132 | "depth_max = torch.amax(depth_map, dim=[0, 1, 2], keepdim=True)\n", 133 | "depth_map = 2.0 * (depth_map - depth_min) / (depth_max - depth_min) - 1.0\n", 134 | "depth_map = depth_map[0,:,:]\n", 135 | "depth_map = transforms.ToPILImage()(depth_map)\n", 136 | "compare(depth_map, image, cmap=\"gray\", start_mode=\"horizontal\", start_slider_pos=0.73)" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [ 145 | "result = pipeline(\"a photo of , standing, Kodachrome, Canon 5D, f2 aperture, extremely detailed, sharp focus\", image)\n", 146 | "compare(result[0][0], image, cmap=\"gray\", start_mode=\"horizontal\", start_slider_pos=0.73)" 147 | ] 148 | } 149 | ], 150 | "metadata": { 151 | "kernelspec": { 152 | "display_name": "ldm", 153 | "language": "python", 154 | "name": "python3" 155 | }, 156 | "language_info": { 157 | "codemirror_mode": { 158 | "name": "ipython", 159 | "version": 3 160 | }, 161 | "file_extension": ".py", 162 | "mimetype": "text/x-python", 163 | "name": "python", 164 | "nbconvert_exporter": "python", 165 | "pygments_lexer": "ipython3", 166 | "version": "3.8.15" 167 | }, 168 | "orig_nbformat": 4, 169 | "vscode": { 170 | "interpreter": { 171 | "hash": "819cb4a3c362ba74ab36110a3c649a29c95096d034ca532b6baac1c46e03a72e" 172 | } 173 | } 174 | }, 175 | "nbformat": 4, 176 | "nbformat_minor": 2 177 | } 178 | -------------------------------------------------------------------------------- /train_dreambooth.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import contextlib 3 | import hashlib 4 | import itertools 5 | import math 6 | import os 7 | import warnings 8 | from pathlib import Path 9 | from typing import Optional 10 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_depth2img import StableDiffusionDepth2ImgPipeline 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | import torch.utils.checkpoint 15 | from torch.utils.data import Dataset 16 | 17 | from accelerate import Accelerator 18 | from accelerate.logging import get_logger 19 | from accelerate.utils import set_seed 20 | from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel 21 | from diffusers.optimization import get_scheduler 22 | from diffusers.utils import check_min_version 23 | from diffusers.utils.import_utils import is_xformers_available 24 | from huggingface_hub import HfFolder, Repository, whoami 25 | from PIL import Image 26 | from torchvision import transforms 27 | from tqdm.auto import tqdm 28 | from transformers import AutoTokenizer, PretrainedConfig 29 | 30 | 31 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
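# The StableDiffusionDepth2ImgPipeline imported above only ships with recent diffusers releases, so fail early if an older version is installed.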
32 | check_min_version("0.10.0.dev0") 33 | 34 | logger = get_logger(__name__) 35 | 36 | 37 | def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): 38 | text_encoder_config = PretrainedConfig.from_pretrained( 39 | pretrained_model_name_or_path, 40 | subfolder="text_encoder", 41 | revision=revision, 42 | ) 43 | model_class = text_encoder_config.architectures[0] 44 | 45 | if model_class == "CLIPTextModel": 46 | from transformers import CLIPTextModel 47 | 48 | return CLIPTextModel 49 | elif model_class == "RobertaSeriesModelWithTransformation": 50 | from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation 51 | 52 | return RobertaSeriesModelWithTransformation 53 | else: 54 | raise ValueError(f"{model_class} is not supported.") 55 | 56 | 57 | def parse_args(input_args=None): 58 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 59 | parser.add_argument( 60 | "--pretrained_model_name_or_path", 61 | type=str, 62 | default=None, 63 | required=True, 64 | help="Path to pretrained model or model identifier from huggingface.co/models.", 65 | ) 66 | parser.add_argument( 67 | "--pretrained_txt2img_model_name_or_path", 68 | type=str, 69 | default=None, 70 | required=True, 71 | help="Path to pretrained model or model identifier from huggingface.co/models. This model will be used to generate images from text, without depth conditioning, for generating sample images.", 72 | ) 73 | parser.add_argument( 74 | "--revision", 75 | type=str, 76 | default=None, 77 | required=False, 78 | help="Revision of pretrained model identifier from huggingface.co/models.", 79 | ) 80 | parser.add_argument( 81 | "--tokenizer_name", 82 | type=str, 83 | default=None, 84 | help="Pretrained tokenizer name or path if not the same as model_name", 85 | ) 86 | parser.add_argument( 87 | "--instance_data_dir", 88 | type=str, 89 | default=None, 90 | required=True, 91 | help="A folder containing the training data of instance images.", 92 | ) 93 | parser.add_argument( 94 | "--class_data_dir", 95 | type=str, 96 | default=None, 97 | required=False, 98 | help="A folder containing the training data of class images.", 99 | ) 100 | parser.add_argument( 101 | "--instance_prompt", 102 | type=str, 103 | default=None, 104 | required=True, 105 | help="The prompt with identifier specifying the instance", 106 | ) 107 | parser.add_argument( 108 | "--class_prompt", 109 | type=str, 110 | default=None, 111 | help="The prompt to specify images in the same class as provided instance images.", 112 | ) 113 | parser.add_argument( 114 | "--with_prior_preservation", 115 | default=False, 116 | action="store_true", 117 | help="Flag to add prior preservation loss.", 118 | ) 119 | parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") 120 | parser.add_argument( 121 | "--num_class_images", 122 | type=int, 123 | default=100, 124 | help=( 125 | "Minimal class images for prior preservation loss. If there are not enough images already present in" 126 | " class_data_dir, additional images will be sampled with class_prompt." 
127 | ), 128 | ) 129 | parser.add_argument( 130 | "--output_dir", 131 | type=str, 132 | default="text-inversion-model", 133 | help="The output directory where the model predictions and checkpoints will be written.", 134 | ) 135 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 136 | parser.add_argument( 137 | "--resolution", 138 | type=int, 139 | default=512, 140 | help=( 141 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 142 | " resolution" 143 | ), 144 | ) 145 | parser.add_argument( 146 | "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" 147 | ) 148 | parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder") 149 | parser.add_argument( 150 | "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." 151 | ) 152 | parser.add_argument( 153 | "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." 154 | ) 155 | parser.add_argument("--num_train_epochs", type=int, default=1) 156 | parser.add_argument( 157 | "--max_train_steps", 158 | type=int, 159 | default=None, 160 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 161 | ) 162 | parser.add_argument( 163 | "--checkpointing_steps", 164 | type=int, 165 | default=500, 166 | help=( 167 | "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" 168 | " training using `--resume_from_checkpoint`." 169 | ), 170 | ) 171 | parser.add_argument( 172 | "--resume_from_checkpoint", 173 | type=str, 174 | default=None, 175 | help=( 176 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 177 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 178 | ), 179 | ) 180 | parser.add_argument( 181 | "--gradient_accumulation_steps", 182 | type=int, 183 | default=1, 184 | help="Number of updates steps to accumulate before performing a backward/update pass.", 185 | ) 186 | parser.add_argument( 187 | "--gradient_checkpointing", 188 | action="store_true", 189 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 190 | ) 191 | parser.add_argument( 192 | "--learning_rate", 193 | type=float, 194 | default=5e-6, 195 | help="Initial learning rate (after the potential warmup period) to use.", 196 | ) 197 | parser.add_argument( 198 | "--scale_lr", 199 | action="store_true", 200 | default=False, 201 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 202 | ) 203 | parser.add_argument( 204 | "--lr_scheduler", 205 | type=str, 206 | default="constant", 207 | help=( 208 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 209 | ' "constant", "constant_with_warmup"]' 210 | ), 211 | ) 212 | parser.add_argument( 213 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 214 | ) 215 | parser.add_argument( 216 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 
217 | ) 218 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 219 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 220 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 221 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 222 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 223 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 224 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 225 | parser.add_argument( 226 | "--hub_model_id", 227 | type=str, 228 | default=None, 229 | help="The name of the repository to keep in sync with the local `output_dir`.", 230 | ) 231 | parser.add_argument( 232 | "--logging_dir", 233 | type=str, 234 | default="logs", 235 | help=( 236 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 237 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 238 | ), 239 | ) 240 | parser.add_argument( 241 | "--mixed_precision", 242 | type=str, 243 | default=None, 244 | choices=["no", "fp16", "bf16"], 245 | help=( 246 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 247 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 248 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 249 | ), 250 | ) 251 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 252 | 253 | if input_args is not None: 254 | args = parser.parse_args(input_args) 255 | else: 256 | args = parser.parse_args() 257 | 258 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 259 | if env_local_rank != -1 and env_local_rank != args.local_rank: 260 | args.local_rank = env_local_rank 261 | 262 | if args.with_prior_preservation: 263 | if args.class_data_dir is None: 264 | raise ValueError("You must specify a data directory for class images.") 265 | if args.class_prompt is None: 266 | raise ValueError("You must specify prompt for class images.") 267 | else: 268 | # logger is not available yet 269 | if args.class_data_dir is not None: 270 | warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") 271 | if args.class_prompt is not None: 272 | warnings.warn("You need not use --class_prompt without --with_prior_preservation.") 273 | 274 | return args 275 | 276 | def get_depth_image_path(normal_image_path): 277 | return normal_image_path.parent / f"{normal_image_path.stem}_depth.png" 278 | 279 | class DreamBoothDataset(Dataset): 280 | """ 281 | A dataset to prepare the instance and class images with the prompts for fine-tuning the model. 282 | It pre-processes the images and the tokenizes prompts. 
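    Each image is expected to have a matching depth map stored next to it as <stem>_depth.png (see get_depth_image_path); any file whose name contains "_depth." is excluded from the instance and class image lists.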
283 | """ 284 | 285 | def __init__( 286 | self, 287 | instance_data_root, 288 | instance_prompt, 289 | tokenizer, 290 | vae_scale_factor, 291 | class_data_root=None, 292 | class_prompt=None, 293 | size=512, 294 | center_crop=False 295 | ): 296 | self.size = size 297 | self.center_crop = center_crop 298 | self.tokenizer = tokenizer 299 | self.vae_scale_factor = vae_scale_factor 300 | 301 | self.instance_data_root = Path(instance_data_root) 302 | if not self.instance_data_root.exists(): 303 | raise ValueError("Instance images root doesn't exists.") 304 | 305 | self.instance_images_path = list(filter(lambda path: str(path).find("_depth.") == -1, self.instance_data_root.iterdir())) 306 | self.num_instance_images = len(self.instance_images_path) 307 | self.instance_prompt = instance_prompt 308 | self._length = self.num_instance_images 309 | 310 | if class_data_root is not None: 311 | self.class_data_root = Path(class_data_root) 312 | self.class_data_root.mkdir(parents=True, exist_ok=True) 313 | self.class_images_path = list(filter(lambda path: str(path).find("_depth.") == -1, self.class_data_root.iterdir())) 314 | self.num_class_images = len(self.class_images_path) 315 | self._length = max(self.num_class_images, self.num_instance_images) 316 | self.class_prompt = class_prompt 317 | else: 318 | self.class_data_root = None 319 | 320 | self.image_transforms = transforms.Compose( 321 | [ 322 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), 323 | transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), 324 | transforms.ToTensor(), 325 | transforms.Normalize([0.5], [0.5]), 326 | ] 327 | ) 328 | 329 | self.depth_image_transforms = transforms.Compose( 330 | [ 331 | transforms.Resize(size // self.vae_scale_factor, interpolation=transforms.InterpolationMode.BILINEAR), 332 | transforms.ToTensor() 333 | ] 334 | ) 335 | def __len__(self): 336 | return self._length 337 | 338 | def __getitem__(self, index): 339 | example = {} 340 | instance_image_path = self.instance_images_path[index % self.num_instance_images] 341 | instance_depth_image_path = get_depth_image_path(instance_image_path) 342 | instance_image = Image.open(instance_image_path) 343 | instance_depth_image = Image.open(instance_depth_image_path) 344 | if not instance_image.mode == "RGB": 345 | instance_image = instance_image.convert("RGB") 346 | example["instance_images"] = self.image_transforms(instance_image) 347 | example["instance_depth_images"] = self.depth_image_transforms(instance_depth_image) 348 | example["instance_prompt_ids"] = self.tokenizer( 349 | self.instance_prompt, 350 | truncation=True, 351 | padding="max_length", 352 | max_length=self.tokenizer.model_max_length, 353 | return_tensors="pt", 354 | ).input_ids 355 | 356 | if self.class_data_root: 357 | class_image_path = self.class_images_path[index % self.num_class_images] 358 | class_depth_image_path = get_depth_image_path(class_image_path) 359 | class_image = Image.open(class_image_path) 360 | class_depth_image = Image.open(class_depth_image_path) 361 | if not class_image.mode == "RGB": 362 | class_image = class_image.convert("RGB") 363 | example["class_images"] = self.image_transforms(class_image) 364 | example["class_depth_images"] = self.depth_image_transforms(class_depth_image) 365 | example["class_prompt_ids"] = self.tokenizer( 366 | self.class_prompt, 367 | truncation=True, 368 | padding="max_length", 369 | max_length=self.tokenizer.model_max_length, 370 | return_tensors="pt", 371 | ).input_ids 372 | 373 | return example 374 
| 375 | 376 | def collate_fn(examples, with_prior_preservation=False): 377 | input_ids = [example["instance_prompt_ids"] for example in examples] 378 | pixel_values = [example["instance_images"] for example in examples] 379 | depth_values = [example["instance_depth_images"] for example in examples] 380 | 381 | # Concat class and instance examples for prior preservation. 382 | # We do this to avoid doing two forward passes. 383 | if with_prior_preservation: 384 | input_ids += [example["class_prompt_ids"] for example in examples] 385 | pixel_values += [example["class_images"] for example in examples] 386 | depth_values += [example["class_depth_images"] for example in examples] 387 | 388 | pixel_values = torch.stack(pixel_values) 389 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 390 | 391 | depth_values = torch.stack(depth_values) 392 | depth_values = depth_values.to(memory_format=torch.contiguous_format).float() 393 | 394 | input_ids = torch.cat(input_ids, dim=0) 395 | 396 | batch = { 397 | "input_ids": input_ids, 398 | "pixel_values": pixel_values, 399 | "depth_values": depth_values 400 | } 401 | return batch 402 | 403 | 404 | class PromptDataset(Dataset): 405 | "A simple dataset to prepare the prompts to generate class images on multiple GPUs." 406 | 407 | def __init__(self, prompt, num_samples): 408 | self.prompt = prompt 409 | self.num_samples = num_samples 410 | print(f'Creating prompt dataset with prompt={prompt} and num_samples={num_samples}') 411 | 412 | def __len__(self): 413 | return self.num_samples 414 | 415 | def __getitem__(self, index): 416 | example = {} 417 | example["prompt"] = self.prompt 418 | example["index"] = index 419 | return example 420 | 421 | 422 | def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): 423 | if token is None: 424 | token = HfFolder.get_token() 425 | if organization is None: 426 | username = whoami(token)["name"] 427 | return f"{username}/{model_id}" 428 | else: 429 | return f"{organization}/{model_id}" 430 | 431 | def create_depth_images(paths, pretrained_model_name_or_path, accelerator, unet, text_encoder): 432 | pipeline = DiffusionPipeline.from_pretrained( 433 | pretrained_model_name_or_path, 434 | unet=accelerator.unwrap_model(unet), 435 | text_encoder=accelerator.unwrap_model(text_encoder), 436 | revision=args.revision, 437 | ) 438 | pipeline.to("cuda") 439 | for path in paths: 440 | print(f"For each image in {path}, creating a depth image.") 441 | path_iterator = Path(path).iterdir() 442 | non_depth_image_files = list(filter(lambda path: str(path).find("_depth.") == -1, path_iterator)) 443 | for image_path in tqdm(non_depth_image_files): 444 | depth_path = get_depth_image_path(image_path) 445 | if depth_path.exists(): 446 | continue 447 | image_instance = Image.open(image_path) 448 | if not image_instance.mode == "RGB": 449 | image_instance = image_instance.convert("RGB") 450 | image_instance = pipeline.feature_extractor(image_instance, return_tensors="pt").pixel_values 451 | image_instance = image_instance.to("cuda") 452 | depth_map = pipeline.depth_estimator(image_instance).predicted_depth 453 | depth_min = torch.amin(depth_map, dim=[0, 1, 2], keepdim=True) 454 | depth_max = torch.amax(depth_map, dim=[0, 1, 2], keepdim=True) 455 | depth_map = 2.0 * (depth_map - depth_min) / (depth_max - depth_min) - 1.0 456 | depth_map = depth_map[0,:,:] 457 | depth_map_image = transforms.ToPILImage()(depth_map) 458 | depth_map_image.save(depth_path) 459 | return 2 ** 
(len(pipeline.vae.config.block_out_channels) - 1) 460 | 461 | def main(args): 462 | logging_dir = Path(args.output_dir, args.logging_dir) 463 | 464 | accelerator = Accelerator( 465 | gradient_accumulation_steps=args.gradient_accumulation_steps, 466 | mixed_precision=args.mixed_precision, 467 | log_with="tensorboard", 468 | logging_dir=logging_dir, 469 | ) 470 | 471 | # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate 472 | # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. 473 | # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. 474 | if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: 475 | raise ValueError( 476 | "Gradient accumulation is not supported when training the text encoder in distributed training. " 477 | "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." 478 | ) 479 | 480 | if args.seed is not None: 481 | set_seed(args.seed) 482 | 483 | if args.with_prior_preservation: 484 | class_images_dir = Path(args.class_data_dir) 485 | if not class_images_dir.exists(): 486 | class_images_dir.mkdir(parents=True) 487 | cur_class_images = len(list(class_images_dir.iterdir())) 488 | 489 | if cur_class_images < args.num_class_images: 490 | torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 491 | pipeline = DiffusionPipeline.from_pretrained( 492 | args.pretrained_txt2img_model_name_or_path, 493 | torch_dtype=torch_dtype, 494 | safety_checker=None, 495 | revision=args.revision, 496 | ) 497 | pipeline.set_progress_bar_config(disable=True) 498 | 499 | num_new_images = args.num_class_images - cur_class_images 500 | logger.info(f"Number of class images to sample: {num_new_images}.") 501 | 502 | sample_dataset = PromptDataset(args.class_prompt, num_new_images) 503 | sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) 504 | 505 | sample_dataloader = accelerator.prepare(sample_dataloader) 506 | pipeline.to(accelerator.device) 507 | 508 | for example in tqdm( 509 | sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process 510 | ): 511 | images = pipeline(example["prompt"]).images 512 | 513 | for i, image in enumerate(images): 514 | hash_image = hashlib.sha1(image.tobytes()).hexdigest() 515 | image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" 516 | image.save(image_filename) 517 | 518 | del pipeline 519 | if torch.cuda.is_available(): 520 | torch.cuda.empty_cache() 521 | 522 | # Handle the repository creation 523 | if accelerator.is_main_process: 524 | if args.push_to_hub: 525 | if args.hub_model_id is None: 526 | repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) 527 | else: 528 | repo_name = args.hub_model_id 529 | repo = Repository(args.output_dir, clone_from=repo_name) 530 | 531 | with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: 532 | if "step_*" not in gitignore: 533 | gitignore.write("step_*\n") 534 | if "epoch_*" not in gitignore: 535 | gitignore.write("epoch_*\n") 536 | elif args.output_dir is not None: 537 | os.makedirs(args.output_dir, exist_ok=True) 538 | 539 | # Load the tokenizer 540 | if args.tokenizer_name: 541 | tokenizer = AutoTokenizer.from_pretrained( 542 | args.tokenizer_name, 543 | 
revision=args.revision, 544 | use_fast=False, 545 | ) 546 | elif args.pretrained_model_name_or_path: 547 | tokenizer = AutoTokenizer.from_pretrained( 548 | args.pretrained_model_name_or_path, 549 | subfolder="tokenizer", 550 | revision=args.revision, 551 | use_fast=False, 552 | ) 553 | 554 | # import correct text encoder class 555 | text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) 556 | 557 | # Load models and create wrapper for stable diffusion 558 | text_encoder = text_encoder_cls.from_pretrained( 559 | args.pretrained_model_name_or_path, 560 | subfolder="text_encoder", 561 | revision=args.revision, 562 | ) 563 | vae = AutoencoderKL.from_pretrained( 564 | args.pretrained_model_name_or_path, 565 | subfolder="vae", 566 | revision=args.revision, 567 | ) 568 | unet = UNet2DConditionModel.from_pretrained( 569 | args.pretrained_model_name_or_path, 570 | subfolder="unet", 571 | revision=args.revision, 572 | ) 573 | 574 | if is_xformers_available(): 575 | try: 576 | unet.enable_xformers_memory_efficient_attention() 577 | except Exception as e: 578 | logger.warning( 579 | "Could not enable memory efficient attention. Make sure xformers is installed" 580 | f" correctly and a GPU is available: {e}" 581 | ) 582 | 583 | vae.requires_grad_(False) 584 | if not args.train_text_encoder: 585 | text_encoder.requires_grad_(False) 586 | 587 | if args.gradient_checkpointing: 588 | unet.enable_gradient_checkpointing() 589 | if args.train_text_encoder: 590 | text_encoder.gradient_checkpointing_enable() 591 | 592 | if args.scale_lr: 593 | args.learning_rate = ( 594 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 595 | ) 596 | 597 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs 598 | if args.use_8bit_adam: 599 | try: 600 | import bitsandbytes as bnb 601 | except ImportError: 602 | raise ImportError( 603 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 
604 | ) 605 | 606 | optimizer_class = bnb.optim.AdamW8bit 607 | else: 608 | optimizer_class = torch.optim.AdamW 609 | 610 | params_to_optimize = ( 611 | itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters() 612 | ) 613 | optimizer = optimizer_class( 614 | params_to_optimize, 615 | lr=args.learning_rate, 616 | betas=(args.adam_beta1, args.adam_beta2), 617 | weight_decay=args.adam_weight_decay, 618 | eps=args.adam_epsilon, 619 | ) 620 | 621 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") 622 | 623 | vae_scale_factor = create_depth_images([args.instance_data_dir, args.class_data_dir], args.pretrained_model_name_or_path, accelerator, unet, text_encoder) 624 | train_dataset = DreamBoothDataset( 625 | instance_data_root=args.instance_data_dir, 626 | instance_prompt=args.instance_prompt, 627 | tokenizer=tokenizer, 628 | vae_scale_factor=vae_scale_factor, 629 | class_data_root=args.class_data_dir if args.with_prior_preservation else None, 630 | class_prompt=args.class_prompt, 631 | size=args.resolution, 632 | center_crop=args.center_crop, 633 | ) 634 | 635 | train_dataloader = torch.utils.data.DataLoader( 636 | train_dataset, 637 | batch_size=args.train_batch_size, 638 | shuffle=True, 639 | collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), 640 | num_workers=1, 641 | ) 642 | 643 | # Scheduler and math around the number of training steps. 644 | overrode_max_train_steps = False 645 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 646 | if args.max_train_steps is None: 647 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 648 | overrode_max_train_steps = True 649 | 650 | lr_scheduler = get_scheduler( 651 | args.lr_scheduler, 652 | optimizer=optimizer, 653 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 654 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 655 | ) 656 | 657 | if args.train_text_encoder: 658 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 659 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler 660 | ) 661 | else: 662 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 663 | unet, optimizer, train_dataloader, lr_scheduler 664 | ) 665 | accelerator.register_for_checkpointing(lr_scheduler) 666 | 667 | weight_dtype = torch.float32 668 | if accelerator.mixed_precision == "fp16": 669 | weight_dtype = torch.float16 670 | elif accelerator.mixed_precision == "bf16": 671 | weight_dtype = torch.bfloat16 672 | 673 | # Move text_encode and vae to gpu. 674 | # For mixed precision training we cast the text_encoder and vae weights to half-precision 675 | # as these models are only used for inference, keeping weights in full precision is not required. 676 | vae.to(accelerator.device, dtype=weight_dtype) 677 | if not args.train_text_encoder: 678 | text_encoder.to(accelerator.device, dtype=weight_dtype) 679 | 680 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
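# (accelerator.prepare() shards the dataloader across processes, so its length per process can shrink in multi-GPU runs.)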
681 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 682 | if overrode_max_train_steps: 683 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 684 | # Afterwards we recalculate our number of training epochs 685 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 686 | 687 | # We need to initialize the trackers we use, and also store our configuration. 688 | # The trackers initializes automatically on the main process. 689 | if accelerator.is_main_process: 690 | accelerator.init_trackers("dreambooth", config=vars(args)) 691 | 692 | # Train! 693 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 694 | 695 | logger.info("***** Running training *****") 696 | logger.info(f" Num examples = {len(train_dataset)}") 697 | logger.info(f" Num batches each epoch = {len(train_dataloader)}") 698 | logger.info(f" Num Epochs = {args.num_train_epochs}") 699 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 700 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 701 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 702 | logger.info(f" Total optimization steps = {args.max_train_steps}") 703 | global_step = 0 704 | first_epoch = 0 705 | 706 | if args.resume_from_checkpoint: 707 | if args.resume_from_checkpoint != "latest": 708 | path = os.path.basename(args.resume_from_checkpoint) 709 | else: 710 | # Get the mos recent checkpoint 711 | dirs = os.listdir(args.output_dir) 712 | dirs = [d for d in dirs if d.startswith("checkpoint")] 713 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 714 | path = dirs[-1] 715 | accelerator.print(f"Resuming from checkpoint {path}") 716 | accelerator.load_state(os.path.join(args.output_dir, path)) 717 | global_step = int(path.split("-")[1]) 718 | 719 | resume_global_step = global_step * args.gradient_accumulation_steps 720 | first_epoch = resume_global_step // num_update_steps_per_epoch 721 | resume_step = resume_global_step % num_update_steps_per_epoch 722 | 723 | # Only show the progress bar once on each machine. 
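# The bar counts optimizer update steps (global_step), not raw dataloader batches.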
724 | progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) 725 | progress_bar.set_description("Steps") 726 | 727 | for epoch in range(first_epoch, args.num_train_epochs): 728 | unet.train() 729 | if args.train_text_encoder: 730 | text_encoder.train() 731 | for step, batch in enumerate(train_dataloader): 732 | # Skip steps until we reach the resumed step 733 | if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: 734 | if step % args.gradient_accumulation_steps == 0: 735 | progress_bar.update(1) 736 | continue 737 | 738 | with accelerator.accumulate(unet): 739 | # Convert images to latent space 740 | latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() 741 | latents = latents * 0.18215 742 | 743 | # Sample noise that we'll add to the latents 744 | noise = torch.randn_like(latents) 745 | bsz = latents.shape[0] 746 | # Sample a random timestep for each image 747 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) 748 | timesteps = timesteps.long() 749 | 750 | # Add noise to the latents according to the noise magnitude at each timestep 751 | # (this is the forward diffusion process) 752 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 753 | 754 | # Brian: do we add noise to depth or not? I think no 755 | noisy_latents = torch.cat([noisy_latents, batch["depth_values"]], dim=1) 756 | 757 | # Get the text embedding for conditioning 758 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] 759 | 760 | # Predict the noise residual 761 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 762 | 763 | # Get the target for loss depending on the prediction type 764 | if noise_scheduler.config.prediction_type == "epsilon": 765 | target = noise 766 | elif noise_scheduler.config.prediction_type == "v_prediction": 767 | target = noise_scheduler.get_velocity(latents, noise, timesteps) 768 | else: 769 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 770 | 771 | if args.with_prior_preservation: 772 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. 773 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) 774 | target, target_prior = torch.chunk(target, 2, dim=0) 775 | 776 | # Compute instance loss 777 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() 778 | 779 | # Compute prior loss 780 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") 781 | 782 | # Add the prior loss to the instance loss. 
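# Total loss = instance MSE + prior_loss_weight * class MSE. collate_fn placed the instance examples in the first half of the batch and the class (prior) examples in the second half, which is why torch.chunk(..., 2, dim=0) splits them above.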
783 | loss = loss + args.prior_loss_weight * prior_loss 784 | else: 785 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 786 | 787 | accelerator.backward(loss) 788 | if accelerator.sync_gradients: 789 | params_to_clip = ( 790 | itertools.chain(unet.parameters(), text_encoder.parameters()) 791 | if args.train_text_encoder 792 | else unet.parameters() 793 | ) 794 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) 795 | optimizer.step() 796 | lr_scheduler.step() 797 | optimizer.zero_grad() 798 | 799 | # Checks if the accelerator has performed an optimization step behind the scenes 800 | if accelerator.sync_gradients: 801 | progress_bar.update(1) 802 | global_step += 1 803 | 804 | if global_step % args.checkpointing_steps == 0: 805 | if accelerator.is_main_process: 806 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 807 | accelerator.save_state(save_path) 808 | logger.info(f"Saved state to {save_path}") 809 | 810 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 811 | progress_bar.set_postfix(**logs) 812 | accelerator.log(logs, step=global_step) 813 | 814 | if global_step >= args.max_train_steps: 815 | break 816 | 817 | accelerator.wait_for_everyone() 818 | 819 | # Create the pipeline using using the trained modules and save it. 820 | if accelerator.is_main_process: 821 | pipeline = DiffusionPipeline.from_pretrained( 822 | args.pretrained_model_name_or_path, 823 | unet=accelerator.unwrap_model(unet), 824 | text_encoder=accelerator.unwrap_model(text_encoder), 825 | revision=args.revision, 826 | ) 827 | pipeline.save_pretrained(args.output_dir) 828 | 829 | if args.push_to_hub: 830 | repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) 831 | 832 | accelerator.end_training() 833 | 834 | 835 | if __name__ == "__main__": 836 | args = parse_args() 837 | main(args) 838 | --------------------------------------------------------------------------------