├── README.md ├── comando.sh ├── train_dreambooth.py ├── train_dreambooth_parallel.py ├── train_dreambooth_unet.py └── train_dreambooth_vae.py /README.md: -------------------------------------------------------------------------------- 1 | # Deprecation warning 2 | This repository is obsolete. ShivamShrirao made a better implementation of it. 3 | Here is his repository: https://github.com/ShivamShrirao/diffusers. 4 | 5 | 6 | # Notes 7 | The original source is https://github.com/ShivamShrirao/diffusers. 8 | I used this revision of the xformers library: `pip install git+https://github.com/facebookresearch/xformers@1d31a3a#egg=xformers` 9 | Tested on an Nvidia 3060 12 GB. 10 | My machine: 11 | Debian Linux, 12 | 64 GB of RAM, 13 | an Nvidia 3060 as the main GPU for the desktop, 14 | a secondary Nvidia 1060 6 GB. 15 | 16 | 17 | # DreamBooth training example 18 | 19 | Open comando.sh and edit the folder names and the relevant parameters. 20 | 21 | DREAMBOOTH_SECONDARY is the device where you want to put the VAE. 22 | If unsure, leave "cpu". If you have a secondary GPU, use "cuda:1". To use the main GPU, use "cuda". 23 | 24 | Set EFFICIENT_TRAINER to 1 to use the most efficient setup I found. 25 | For slightly more precise training, use 0. 26 | 27 | PARALLEL_TRAINING: if 0, execution alternates between the CPU and the GPU. 28 | If set to 1, the GPU and the secondary device run in parallel. 29 | Parallel training doesn't increase VRAM usage. 30 | -------------------------------------------------------------------------------- /comando.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | #################################### 5 | ### start configuration parameters 6 | #################################### 7 | 8 | CLASS_NAME="person" 9 | INSTANCE_NAME="matteo" 10 | 11 | MODEL_NAME=/home/matteo/programmi/stable-diffusion-v1-4/stable-diffusion-v1-4/ 12 | OUTPUT_DIR=/home/matteo/programmi/stable-diffusion-v1-4/trained_${INSTANCE_NAME} 13 | CLASS_DIR=/home/matteo/Progetti/ImageAI/textual_inversion/${CLASS_NAME} 14 | INSTANCE_DIR=/home/matteo/Progetti/ImageAI/textual_inversion/me/ 15 | 16 | INSTANCE_PROMPT="photo of $INSTANCE_NAME $CLASS_NAME" 17 | CLASS_PROMPT="photo of a $CLASS_NAME" 18 | export USE_MEMORY_EFFICIENT_ATTENTION=1 19 | 20 | 21 | export DREAMBOOTH_SECONDARY=cpu 22 | EFFICIENT_TRAINER=1 23 | PARALLEL_TRAINING=0 24 | 25 | #################################### 26 | ### end of configuration parameters 27 | #################################### 28 | 29 | 30 | 31 | 32 | PYTHON_TRAIN_FILENAME=train_dreambooth.py 33 | if [[ $PARALLEL_TRAINING -gt 0 ]] 34 | then 35 | PYTHON_TRAIN_FILENAME=train_dreambooth_parallel.py 36 | fi 37 | 38 | 39 | if [[ $EFFICIENT_TRAINER -gt 0 ]] 40 | then 41 | echo using the most efficient training 42 | accelerate launch $PYTHON_TRAIN_FILENAME \ 43 | --pretrained_model_name_or_path=$MODEL_NAME --use_auth_token \ 44 | --instance_data_dir=$INSTANCE_DIR \ 45 | --class_data_dir=$CLASS_DIR \ 46 | --output_dir=$OUTPUT_DIR \ 47 | --instance_prompt="$INSTANCE_PROMPT" \ 48 | --class_prompt="$CLASS_PROMPT" \ 49 | --resolution=512 \ 50 | --use_8bit_adam \ 51 | --train_batch_size=1 \ 52 | --learning_rate=5e-6 \ 53 | --lr_scheduler="constant" \ 54 | --lr_warmup_steps=0 \ 55 | --sample_batch_size=4 \ 56 | --num_class_images=200 \ 57 | --max_train_steps=400 58 | 59 | else 60 | 61 | echo "you can also try this if you have enough memory and the correct repository, it uses the prior preservation and gradient accumulation" 62
| accelerate launch $PYTHON_TRAIN_FILENAME \ 63 | --pretrained_model_name_or_path=$MODEL_NAME --use_auth_token \ 64 | --instance_data_dir=$INSTANCE_DIR \ 65 | --class_data_dir=$CLASS_DIR \ 66 | --output_dir=$OUTPUT_DIR \ 67 | --with_prior_preservation --prior_loss_weight=1.0 \ 68 | --instance_prompt="$INSTANCE_PROMPT" \ 69 | --class_prompt="$CLASS_PROMPT" \ 70 | --resolution=512 \ 71 | --use_8bit_adam \ 72 | --train_batch_size=1 \ 73 | --gradient_accumulation_steps=2 --gradient_checkpointing \ 74 | --learning_rate=5e-6 \ 75 | --lr_scheduler="constant" \ 76 | --lr_warmup_steps=0 \ 77 | --sample_batch_size=4 \ 78 | --num_class_images=200 \ 79 | --max_train_steps=800 80 | 81 | fi 82 | 83 | -------------------------------------------------------------------------------- /train_dreambooth.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import os 4 | from contextlib import nullcontext 5 | from pathlib import Path 6 | from typing import Optional 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | import torch.utils.checkpoint 11 | from torch.utils.data import Dataset 12 | 13 | from accelerate import Accelerator 14 | from accelerate.logging import get_logger 15 | from accelerate.utils import set_seed 16 | from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel 17 | from diffusers.optimization import get_scheduler 18 | from huggingface_hub import HfFolder, Repository, whoami 19 | from PIL import Image 20 | from torchvision import transforms 21 | from tqdm.auto import tqdm 22 | from transformers import CLIPTextModel, CLIPTokenizer 23 | 24 | 25 | logger = get_logger(__name__) 26 | 27 | 28 | def parse_args(): 29 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 30 | parser.add_argument( 31 | "--pretrained_model_name_or_path", 32 | type=str, 33 | default=None, 34 | required=True, 35 | help="Path to pretrained model or model identifier from huggingface.co/models.", 36 | ) 37 | parser.add_argument( 38 | "--tokenizer_name", 39 | type=str, 40 | default=None, 41 | help="Pretrained tokenizer name or path if not the same as model_name", 42 | ) 43 | parser.add_argument( 44 | "--instance_data_dir", 45 | type=str, 46 | default=None, 47 | required=True, 48 | help="A folder containing the training data of instance images.", 49 | ) 50 | parser.add_argument( 51 | "--class_data_dir", 52 | type=str, 53 | default=None, 54 | required=False, 55 | help="A folder containing the training data of class images.", 56 | ) 57 | parser.add_argument( 58 | "--instance_prompt", 59 | type=str, 60 | default=None, 61 | help="The prompt with identifier specifing the instance", 62 | ) 63 | parser.add_argument( 64 | "--class_prompt", 65 | type=str, 66 | default=None, 67 | help="The prompt to specify images in the same class as provided intance images.", 68 | ) 69 | parser.add_argument( 70 | "--with_prior_preservation", 71 | default=False, 72 | action="store_true", 73 | help="Flag to add prior perservation loss.", 74 | ) 75 | parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") 76 | parser.add_argument( 77 | "--num_class_images", 78 | type=int, 79 | default=100, 80 | help=( 81 | "Minimal class images for prior perversation loss. If not have enough images, additional images will be" 82 | " sampled with class_prompt." 
83 | ), 84 | ) 85 | parser.add_argument( 86 | "--output_dir", 87 | type=str, 88 | default="text-inversion-model", 89 | help="The output directory where the model predictions and checkpoints will be written.", 90 | ) 91 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 92 | parser.add_argument( 93 | "--resolution", 94 | type=int, 95 | default=512, 96 | help=( 97 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 98 | " resolution" 99 | ), 100 | ) 101 | parser.add_argument( 102 | "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" 103 | ) 104 | parser.add_argument( 105 | "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." 106 | ) 107 | parser.add_argument( 108 | "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." 109 | ) 110 | parser.add_argument("--num_train_epochs", type=int, default=1) 111 | parser.add_argument( 112 | "--max_train_steps", 113 | type=int, 114 | default=None, 115 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 116 | ) 117 | parser.add_argument( 118 | "--gradient_accumulation_steps", 119 | type=int, 120 | default=1, 121 | help="Number of updates steps to accumulate before performing a backward/update pass.", 122 | ) 123 | parser.add_argument( 124 | "--gradient_checkpointing", 125 | action="store_true", 126 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 127 | ) 128 | parser.add_argument( 129 | "--learning_rate", 130 | type=float, 131 | default=5e-6, 132 | help="Initial learning rate (after the potential warmup period) to use.", 133 | ) 134 | parser.add_argument( 135 | "--scale_lr", 136 | action="store_true", 137 | default=False, 138 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 139 | ) 140 | parser.add_argument( 141 | "--lr_scheduler", 142 | type=str, 143 | default="constant", 144 | help=( 145 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 146 | ' "constant", "constant_with_warmup"]' 147 | ), 148 | ) 149 | parser.add_argument( 150 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 151 | ) 152 | parser.add_argument( 153 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 154 | ) 155 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 156 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 157 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 158 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 159 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 160 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 161 | parser.add_argument( 162 | "--use_auth_token", 163 | action="store_true", 164 | help=( 165 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script with" 166 | " private models)." 
167 | ), 168 | ) 169 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 170 | parser.add_argument( 171 | "--hub_model_id", 172 | type=str, 173 | default=None, 174 | help="The name of the repository to keep in sync with the local `output_dir`.", 175 | ) 176 | parser.add_argument( 177 | "--logging_dir", 178 | type=str, 179 | default="logs", 180 | help=( 181 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 182 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 183 | ), 184 | ) 185 | parser.add_argument( 186 | "--mixed_precision", 187 | type=str, 188 | default="no", 189 | choices=["no", "fp16", "bf16"], 190 | help=( 191 | "Whether to use mixed precision. Choose" 192 | "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." 193 | "and an Nvidia Ampere GPU." 194 | ), 195 | ) 196 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 197 | 198 | args = parser.parse_args() 199 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 200 | if env_local_rank != -1 and env_local_rank != args.local_rank: 201 | args.local_rank = env_local_rank 202 | 203 | if args.instance_data_dir is None: 204 | raise ValueError("You must specify a train data directory.") 205 | 206 | if args.with_prior_preservation: 207 | if args.class_data_dir is None: 208 | raise ValueError("You must specify a data directory for class images.") 209 | if args.class_prompt is None: 210 | raise ValueError("You must specify prompt for class images.") 211 | 212 | return args 213 | 214 | 215 | class DreamBoothDataset(Dataset): 216 | """ 217 | A dataset to prepare the instance and class images with the promots for fine-tuning the model. 218 | It pre-processes the images and the tokenizes prompts. 
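    The dataset length is max(num_instance_images, num_class_images); __getitem__ wraps each index
    with a modulo over the corresponding set, so the smaller set is simply cycled. Each example
    carries the transformed image tensor(s) plus the tokenized instance/class prompt ids.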
219 | """ 220 | 221 | def __init__( 222 | self, 223 | instance_data_root, 224 | instance_prompt, 225 | tokenizer, 226 | class_data_root=None, 227 | class_prompt=None, 228 | size=512, 229 | center_crop=False, 230 | ): 231 | self.size = size 232 | self.center_crop = center_crop 233 | self.tokenizer = tokenizer 234 | 235 | self.instance_data_root = Path(instance_data_root) 236 | if not self.instance_data_root.exists(): 237 | raise ValueError("Instance images root doesn't exists.") 238 | 239 | self.instance_images_path = list(Path(instance_data_root).iterdir()) 240 | self.num_instance_images = len(self.instance_images_path) 241 | self.instance_prompt = instance_prompt 242 | self._length = self.num_instance_images 243 | 244 | if class_data_root is not None: 245 | self.class_data_root = Path(class_data_root) 246 | self.class_data_root.mkdir(parents=True, exist_ok=True) 247 | self.class_images_path = list(Path(class_data_root).iterdir()) 248 | self.num_class_images = len(self.class_images_path) 249 | self._length = max(self.num_class_images, self.num_instance_images) 250 | self.class_prompt = class_prompt 251 | else: 252 | self.class_data_root = None 253 | 254 | self.image_transforms = transforms.Compose( 255 | [ 256 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), 257 | transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), 258 | transforms.ToTensor(), 259 | transforms.Normalize([0.5], [0.5]), 260 | ] 261 | ) 262 | 263 | def __len__(self): 264 | return self._length 265 | 266 | def __getitem__(self, index): 267 | example = {} 268 | instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) 269 | if not instance_image.mode == "RGB": 270 | instance_image = instance_image.convert("RGB") 271 | example["instance_images"] = self.image_transforms(instance_image) 272 | example["instance_prompt_ids"] = self.tokenizer( 273 | self.instance_prompt, 274 | padding="do_not_pad", 275 | truncation=True, 276 | max_length=self.tokenizer.model_max_length, 277 | ).input_ids 278 | 279 | if self.class_data_root: 280 | class_image = Image.open(self.class_images_path[index % self.num_class_images]) 281 | if not class_image.mode == "RGB": 282 | class_image = class_image.convert("RGB") 283 | example["class_images"] = self.image_transforms(class_image) 284 | example["class_prompt_ids"] = self.tokenizer( 285 | self.class_prompt, 286 | padding="do_not_pad", 287 | truncation=True, 288 | max_length=self.tokenizer.model_max_length, 289 | ).input_ids 290 | 291 | return example 292 | 293 | 294 | class PromptDataset(Dataset): 295 | "A simple dataset to prepare the prompts to generate class images on multiple GPUs." 
296 | 297 | def __init__(self, prompt, num_samples): 298 | self.prompt = prompt 299 | self.num_samples = num_samples 300 | 301 | def __len__(self): 302 | return self.num_samples 303 | 304 | def __getitem__(self, index): 305 | example = {} 306 | example["prompt"] = self.prompt 307 | example["index"] = index 308 | return example 309 | 310 | 311 | def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): 312 | if token is None: 313 | token = HfFolder.get_token() 314 | if organization is None: 315 | username = whoami(token)["name"] 316 | return f"{username}/{model_id}" 317 | else: 318 | return f"{organization}/{model_id}" 319 | 320 | 321 | def main(): 322 | args = parse_args() 323 | logging_dir = Path(args.output_dir, args.logging_dir) 324 | print("instance_prompt: ", args.instance_prompt) 325 | accelerator = Accelerator( 326 | gradient_accumulation_steps=args.gradient_accumulation_steps, 327 | mixed_precision=args.mixed_precision, 328 | log_with="tensorboard", 329 | logging_dir=logging_dir, 330 | ) 331 | 332 | if args.seed is not None: 333 | set_seed(args.seed) 334 | 335 | if args.with_prior_preservation: 336 | class_images_dir = Path(args.class_data_dir) 337 | if not class_images_dir.exists(): 338 | class_images_dir.mkdir(parents=True) 339 | cur_class_images = len(list(class_images_dir.iterdir())) 340 | 341 | if cur_class_images < args.num_class_images: 342 | torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 343 | pipeline = StableDiffusionPipeline.from_pretrained( 344 | args.pretrained_model_name_or_path, use_auth_token=args.use_auth_token, torch_dtype=torch_dtype 345 | ) 346 | pipeline.set_progress_bar_config(disable=True) 347 | 348 | num_new_images = args.num_class_images - cur_class_images 349 | logger.info(f"Number of class images to sample: {num_new_images}.") 350 | 351 | sample_dataset = PromptDataset(args.class_prompt, num_new_images) 352 | sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) 353 | 354 | sample_dataloader = accelerator.prepare(sample_dataloader) 355 | pipeline.to(accelerator.device) 356 | 357 | context = torch.autocast("cuda") if accelerator.device.type == "cuda" else nullcontext 358 | for example in tqdm( 359 | sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process 360 | ): 361 | with context: 362 | images = pipeline(example["prompt"]).images 363 | 364 | for i, image in enumerate(images): 365 | image.save(class_images_dir / f"{example['index'][i] + cur_class_images}.jpg") 366 | 367 | del pipeline 368 | if torch.cuda.is_available(): 369 | torch.cuda.empty_cache() 370 | 371 | # Handle the repository creation 372 | if accelerator.is_main_process: 373 | if args.push_to_hub: 374 | if args.hub_model_id is None: 375 | repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) 376 | else: 377 | repo_name = args.hub_model_id 378 | repo = Repository(args.output_dir, clone_from=repo_name) 379 | 380 | with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: 381 | if "step_*" not in gitignore: 382 | gitignore.write("step_*\n") 383 | if "epoch_*" not in gitignore: 384 | gitignore.write("epoch_*\n") 385 | elif args.output_dir is not None: 386 | os.makedirs(args.output_dir, exist_ok=True) 387 | 388 | # Load the tokenizer 389 | if args.tokenizer_name: 390 | tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) 391 | elif args.pretrained_model_name_or_path: 392 | 
tokenizer = CLIPTokenizer.from_pretrained( 393 | args.pretrained_model_name_or_path, subfolder="tokenizer", use_auth_token=args.use_auth_token 394 | ) 395 | 396 | # Load models and create wrapper for stable diffusion 397 | text_encoder = CLIPTextModel.from_pretrained( 398 | args.pretrained_model_name_or_path, subfolder="text_encoder", use_auth_token=args.use_auth_token 399 | ) 400 | vae = AutoencoderKL.from_pretrained( 401 | args.pretrained_model_name_or_path, subfolder="vae", use_auth_token=args.use_auth_token 402 | ) 403 | unet = UNet2DConditionModel.from_pretrained( 404 | args.pretrained_model_name_or_path, subfolder="unet", use_auth_token=args.use_auth_token 405 | ) 406 | 407 | if args.gradient_checkpointing: 408 | unet.enable_gradient_checkpointing() 409 | 410 | if args.scale_lr: 411 | args.learning_rate = ( 412 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 413 | ) 414 | 415 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs 416 | if args.use_8bit_adam: 417 | try: 418 | import bitsandbytes as bnb 419 | except ImportError: 420 | raise ImportError( 421 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 422 | ) 423 | 424 | optimizer_class = bnb.optim.AdamW8bit 425 | else: 426 | optimizer_class = torch.optim.AdamW 427 | 428 | optimizer = optimizer_class( 429 | unet.parameters(), # only optimize unet 430 | lr=args.learning_rate, 431 | betas=(args.adam_beta1, args.adam_beta2), 432 | weight_decay=args.adam_weight_decay, 433 | eps=args.adam_epsilon, 434 | ) 435 | 436 | noise_scheduler = DDPMScheduler( 437 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt" 438 | ) 439 | 440 | train_dataset = DreamBoothDataset( 441 | instance_data_root=args.instance_data_dir, 442 | instance_prompt=args.instance_prompt, 443 | class_data_root=args.class_data_dir if args.with_prior_preservation else None, 444 | class_prompt=args.class_prompt, 445 | tokenizer=tokenizer, 446 | size=args.resolution, 447 | center_crop=args.center_crop, 448 | ) 449 | 450 | def collate_fn(examples): 451 | input_ids = [example["instance_prompt_ids"] for example in examples] 452 | pixel_values = [example["instance_images"] for example in examples] 453 | 454 | # Concat class and instance examples for prior preservation. 455 | # We do this to avoid doing two forward passes. 456 | if args.with_prior_preservation: 457 | input_ids += [example["class_prompt_ids"] for example in examples] 458 | pixel_values += [example["class_images"] for example in examples] 459 | 460 | pixel_values = torch.stack(pixel_values) 461 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 462 | 463 | input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids 464 | 465 | batch = { 466 | "input_ids": input_ids, 467 | "pixel_values": pixel_values, 468 | } 469 | return batch 470 | 471 | train_dataloader = torch.utils.data.DataLoader( 472 | train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn 473 | ) 474 | 475 | # Scheduler and math around the number of training steps. 
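    # Worked example of the step bookkeeping below (illustrative numbers): with 200 batches per epoch
    # and gradient_accumulation_steps=2, num_update_steps_per_epoch = ceil(200 / 2) = 100, so passing
    # --max_train_steps=400 later resolves to num_train_epochs = ceil(400 / 100) = 4 passes over the
    # data, while leaving --max_train_steps unset would instead derive it from num_train_epochs.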
476 | overrode_max_train_steps = False 477 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 478 | if args.max_train_steps is None: 479 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 480 | overrode_max_train_steps = True 481 | 482 | lr_scheduler = get_scheduler( 483 | args.lr_scheduler, 484 | optimizer=optimizer, 485 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 486 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 487 | ) 488 | 489 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 490 | unet, optimizer, train_dataloader, lr_scheduler 491 | ) 492 | cuda_secondary=os.environ.get('DREAMBOOTH_SECONDARY', accelerator.device) 493 | # Move text_encode and vae to gpu 494 | text_encoder.to(accelerator.device) 495 | vae.to(cuda_secondary) 496 | 497 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 498 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 499 | if overrode_max_train_steps: 500 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 501 | # Afterwards we recalculate our number of training epochs 502 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 503 | 504 | # We need to initialize the trackers we use, and also store our configuration. 505 | # The trackers initializes automatically on the main process. 506 | if accelerator.is_main_process: 507 | accelerator.init_trackers("dreambooth", config=vars(args)) 508 | 509 | # Train! 510 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 511 | 512 | logger.info("***** Running training *****") 513 | logger.info(f" Num examples = {len(train_dataset)}") 514 | logger.info(f" Num batches each epoch = {len(train_dataloader)}") 515 | logger.info(f" Num Epochs = {args.num_train_epochs}") 516 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 517 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 518 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 519 | logger.info(f" Total optimization steps = {args.max_train_steps}") 520 | # Only show the progress bar once on each machine. 
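    # Device note for the loop below: each step first encodes pixel_values with the VAE on
    # cuda_secondary (the DREAMBOOTH_SECONDARY device, "cpu" by default), then moves the latents to
    # accelerator.device for the UNet forward/backward pass. This is the alternating
    # (PARALLEL_TRAINING=0) mode described in the README; the VAE stays off the main GPU unless
    # DREAMBOOTH_SECONDARY is set to "cuda".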
521 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 522 | progress_bar.set_description("Steps") 523 | global_step = 0 524 | 525 | for epoch in range(args.num_train_epochs): 526 | unet.train() 527 | for step, batch in enumerate(train_dataloader): 528 | with accelerator.accumulate(unet): 529 | # Convert images to latent space 530 | with torch.no_grad(): 531 | latents = vae.encode(batch["pixel_values"].to(cuda_secondary)).latent_dist.sample() 532 | latents = latents.to(accelerator.device) * 0.18215 533 | 534 | # Sample noise that we'll add to the latents 535 | noise = torch.randn(latents.shape).to(latents.device) 536 | bsz = latents.shape[0] 537 | # Sample a random timestep for each image 538 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) 539 | timesteps = timesteps.long() 540 | 541 | # Add noise to the latents according to the noise magnitude at each timestep 542 | # (this is the forward diffusion process) 543 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 544 | 545 | # Get the text embedding for conditioning 546 | with torch.no_grad(): 547 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] 548 | 549 | # Predict the noise residual 550 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 551 | 552 | if args.with_prior_preservation: 553 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. 554 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) 555 | noise, noise_prior = torch.chunk(noise, 2, dim=0) 556 | 557 | # Compute instance loss 558 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() 559 | 560 | # Compute prior loss 561 | prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean() 562 | 563 | # Add the prior loss to the instance loss. 564 | loss = loss + args.prior_loss_weight * prior_loss 565 | else: 566 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() 567 | 568 | accelerator.backward(loss) 569 | accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) 570 | optimizer.step() 571 | lr_scheduler.step() 572 | optimizer.zero_grad(set_to_none=True) 573 | 574 | # Checks if the accelerator has performed an optimization step behind the scenes 575 | if accelerator.sync_gradients: 576 | progress_bar.update(1) 577 | global_step += 1 578 | 579 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 580 | progress_bar.set_postfix(**logs) 581 | accelerator.log(logs, step=global_step) 582 | 583 | if global_step >= args.max_train_steps: 584 | break 585 | 586 | accelerator.wait_for_everyone() 587 | 588 | # Create the pipeline using using the trained modules and save it. 
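    # Once saved, the fine-tuned pipeline can be reloaded for inference in the usual diffusers way,
    # for example (illustrative prompt, matching the INSTANCE_PROMPT pattern from comando.sh):
    #   pipe = StableDiffusionPipeline.from_pretrained(args.output_dir, torch_dtype=torch.float16).to("cuda")
    #   image = pipe("photo of matteo person").images[0]
    #   image.save("sample.png")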
589 | if accelerator.is_main_process: 590 | pipeline = StableDiffusionPipeline.from_pretrained( 591 | args.pretrained_model_name_or_path, 592 | unet=accelerator.unwrap_model(unet), 593 | use_auth_token=args.use_auth_token, 594 | ) 595 | pipeline.save_pretrained(args.output_dir) 596 | 597 | if args.push_to_hub: 598 | repo.push_to_hub( 599 | args, pipeline, repo, commit_message="End of training", blocking=False, auto_lfs_prune=True 600 | ) 601 | 602 | accelerator.end_training() 603 | 604 | 605 | if __name__ == "__main__": 606 | main() 607 | -------------------------------------------------------------------------------- /train_dreambooth_parallel.py: -------------------------------------------------------------------------------- 1 | 2 | import multiprocessing as mp 3 | 4 | 5 | 6 | def launch1(q): 7 | import train_dreambooth_unet 8 | train_dreambooth_unet.main(q) 9 | q.get() 10 | 11 | def launch2(q): 12 | import train_dreambooth_vae 13 | train_dreambooth_vae.main(q) 14 | q.put(None) 15 | 16 | q = mp.Queue(1) 17 | 18 | p1 = mp.Process(target=launch1, args=(q,)) 19 | p2 = mp.Process(target=launch2, args=(q,)) 20 | p1.start() 21 | p2.start() 22 | 23 | 24 | p1.join() 25 | q.close() 26 | p2.terminate() 27 | p2.join() 28 | -------------------------------------------------------------------------------- /train_dreambooth_unet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import os 4 | from contextlib import nullcontext 5 | from pathlib import Path 6 | from typing import Optional 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | import torch.utils.checkpoint 11 | from torch.utils.data import Dataset 12 | 13 | from accelerate import Accelerator 14 | from accelerate.logging import get_logger 15 | from accelerate.utils import set_seed 16 | from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel 17 | from diffusers.optimization import get_scheduler 18 | from huggingface_hub import HfFolder, Repository, whoami 19 | from PIL import Image 20 | from torchvision import transforms 21 | from tqdm.auto import tqdm 22 | from transformers import CLIPTextModel, CLIPTokenizer 23 | 24 | 25 | logger = get_logger(__name__) 26 | 27 | 28 | def parse_args(): 29 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 30 | parser.add_argument( 31 | "--pretrained_model_name_or_path", 32 | type=str, 33 | default=None, 34 | required=True, 35 | help="Path to pretrained model or model identifier from huggingface.co/models.", 36 | ) 37 | parser.add_argument( 38 | "--tokenizer_name", 39 | type=str, 40 | default=None, 41 | help="Pretrained tokenizer name or path if not the same as model_name", 42 | ) 43 | parser.add_argument( 44 | "--instance_data_dir", 45 | type=str, 46 | default=None, 47 | required=True, 48 | help="A folder containing the training data of instance images.", 49 | ) 50 | parser.add_argument( 51 | "--class_data_dir", 52 | type=str, 53 | default=None, 54 | required=False, 55 | help="A folder containing the training data of class images.", 56 | ) 57 | parser.add_argument( 58 | "--instance_prompt", 59 | type=str, 60 | default=None, 61 | help="The prompt with identifier specifing the instance", 62 | ) 63 | parser.add_argument( 64 | "--class_prompt", 65 | type=str, 66 | default=None, 67 | help="The prompt to specify images in the same class as provided intance images.", 68 | ) 69 | parser.add_argument( 70 | "--with_prior_preservation", 71 | default=False, 72 
| action="store_true", 73 | help="Flag to add prior perservation loss.", 74 | ) 75 | parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") 76 | parser.add_argument( 77 | "--num_class_images", 78 | type=int, 79 | default=100, 80 | help=( 81 | "Minimal class images for prior perversation loss. If not have enough images, additional images will be" 82 | " sampled with class_prompt." 83 | ), 84 | ) 85 | parser.add_argument( 86 | "--output_dir", 87 | type=str, 88 | default="text-inversion-model", 89 | help="The output directory where the model predictions and checkpoints will be written.", 90 | ) 91 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 92 | parser.add_argument( 93 | "--resolution", 94 | type=int, 95 | default=512, 96 | help=( 97 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 98 | " resolution" 99 | ), 100 | ) 101 | parser.add_argument( 102 | "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" 103 | ) 104 | parser.add_argument( 105 | "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." 106 | ) 107 | parser.add_argument( 108 | "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." 109 | ) 110 | parser.add_argument("--num_train_epochs", type=int, default=1) 111 | parser.add_argument( 112 | "--max_train_steps", 113 | type=int, 114 | default=None, 115 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 116 | ) 117 | parser.add_argument( 118 | "--gradient_accumulation_steps", 119 | type=int, 120 | default=1, 121 | help="Number of updates steps to accumulate before performing a backward/update pass.", 122 | ) 123 | parser.add_argument( 124 | "--gradient_checkpointing", 125 | action="store_true", 126 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 127 | ) 128 | parser.add_argument( 129 | "--learning_rate", 130 | type=float, 131 | default=5e-6, 132 | help="Initial learning rate (after the potential warmup period) to use.", 133 | ) 134 | parser.add_argument( 135 | "--scale_lr", 136 | action="store_true", 137 | default=False, 138 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 139 | ) 140 | parser.add_argument( 141 | "--lr_scheduler", 142 | type=str, 143 | default="constant", 144 | help=( 145 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 146 | ' "constant", "constant_with_warmup"]' 147 | ), 148 | ) 149 | parser.add_argument( 150 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 151 | ) 152 | parser.add_argument( 153 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 
154 | ) 155 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 156 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 157 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 158 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 159 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 160 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 161 | parser.add_argument( 162 | "--use_auth_token", 163 | action="store_true", 164 | help=( 165 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script with" 166 | " private models)." 167 | ), 168 | ) 169 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 170 | parser.add_argument( 171 | "--hub_model_id", 172 | type=str, 173 | default=None, 174 | help="The name of the repository to keep in sync with the local `output_dir`.", 175 | ) 176 | parser.add_argument( 177 | "--logging_dir", 178 | type=str, 179 | default="logs", 180 | help=( 181 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 182 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 183 | ), 184 | ) 185 | parser.add_argument( 186 | "--mixed_precision", 187 | type=str, 188 | default="no", 189 | choices=["no", "fp16", "bf16"], 190 | help=( 191 | "Whether to use mixed precision. Choose" 192 | "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." 193 | "and an Nvidia Ampere GPU." 194 | ), 195 | ) 196 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 197 | 198 | args = parser.parse_args() 199 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 200 | if env_local_rank != -1 and env_local_rank != args.local_rank: 201 | args.local_rank = env_local_rank 202 | 203 | if args.instance_data_dir is None: 204 | raise ValueError("You must specify a train data directory.") 205 | 206 | if args.with_prior_preservation: 207 | if args.class_data_dir is None: 208 | raise ValueError("You must specify a data directory for class images.") 209 | if args.class_prompt is None: 210 | raise ValueError("You must specify prompt for class images.") 211 | 212 | return args 213 | 214 | 215 | class DreamBoothDataset(Dataset): 216 | """ 217 | A dataset to prepare the instance and class images with the promots for fine-tuning the model. 218 | It pre-processes the images and the tokenizes prompts. 
219 | """ 220 | 221 | def __init__( 222 | self, 223 | instance_data_root, 224 | instance_prompt, 225 | tokenizer, 226 | class_data_root=None, 227 | class_prompt=None, 228 | size=512, 229 | center_crop=False, 230 | ): 231 | self.size = size 232 | self.center_crop = center_crop 233 | self.tokenizer = tokenizer 234 | 235 | self.instance_data_root = Path(instance_data_root) 236 | if not self.instance_data_root.exists(): 237 | raise ValueError("Instance images root doesn't exists.") 238 | 239 | self.instance_images_path = list(Path(instance_data_root).iterdir()) 240 | self.num_instance_images = len(self.instance_images_path) 241 | self.instance_prompt = instance_prompt 242 | self._length = self.num_instance_images 243 | 244 | if class_data_root is not None: 245 | self.class_data_root = Path(class_data_root) 246 | self.class_data_root.mkdir(parents=True, exist_ok=True) 247 | self.class_images_path = list(Path(class_data_root).iterdir()) 248 | self.num_class_images = len(self.class_images_path) 249 | self._length = max(self.num_class_images, self.num_instance_images) 250 | self.class_prompt = class_prompt 251 | else: 252 | self.class_data_root = None 253 | 254 | self.image_transforms = transforms.Compose( 255 | [ 256 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), 257 | transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), 258 | transforms.ToTensor(), 259 | transforms.Normalize([0.5], [0.5]), 260 | ] 261 | ) 262 | 263 | def __len__(self): 264 | return self._length 265 | 266 | def __getitem__(self, index): 267 | example = {} 268 | instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) 269 | if not instance_image.mode == "RGB": 270 | instance_image = instance_image.convert("RGB") 271 | example["instance_images"] = self.image_transforms(instance_image) 272 | example["instance_prompt_ids"] = self.tokenizer( 273 | self.instance_prompt, 274 | padding="do_not_pad", 275 | truncation=True, 276 | max_length=self.tokenizer.model_max_length, 277 | ).input_ids 278 | 279 | if self.class_data_root: 280 | class_image = Image.open(self.class_images_path[index % self.num_class_images]) 281 | if not class_image.mode == "RGB": 282 | class_image = class_image.convert("RGB") 283 | example["class_images"] = self.image_transforms(class_image) 284 | example["class_prompt_ids"] = self.tokenizer( 285 | self.class_prompt, 286 | padding="do_not_pad", 287 | truncation=True, 288 | max_length=self.tokenizer.model_max_length, 289 | ).input_ids 290 | 291 | return example 292 | 293 | 294 | class PromptDataset(Dataset): 295 | "A simple dataset to prepare the prompts to generate class images on multiple GPUs." 
296 | 297 | def __init__(self, prompt, num_samples): 298 | self.prompt = prompt 299 | self.num_samples = num_samples 300 | 301 | def __len__(self): 302 | return self.num_samples 303 | 304 | def __getitem__(self, index): 305 | example = {} 306 | example["prompt"] = self.prompt 307 | example["index"] = index 308 | return example 309 | 310 | 311 | def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): 312 | if token is None: 313 | token = HfFolder.get_token() 314 | if organization is None: 315 | username = whoami(token)["name"] 316 | return f"{username}/{model_id}" 317 | else: 318 | return f"{organization}/{model_id}" 319 | 320 | 321 | def main(q): 322 | args = parse_args() 323 | print(args) 324 | logging_dir = Path(args.output_dir, args.logging_dir) 325 | print("instance_prompt: ", args.instance_prompt) 326 | accelerator = Accelerator( 327 | gradient_accumulation_steps=args.gradient_accumulation_steps, 328 | mixed_precision=args.mixed_precision, 329 | log_with="tensorboard", 330 | logging_dir=logging_dir, 331 | ) 332 | 333 | if args.seed is not None: 334 | set_seed(args.seed) 335 | 336 | if args.with_prior_preservation: 337 | class_images_dir = Path(args.class_data_dir) 338 | if not class_images_dir.exists(): 339 | class_images_dir.mkdir(parents=True) 340 | cur_class_images = len(list(class_images_dir.iterdir())) 341 | 342 | if cur_class_images < args.num_class_images: 343 | torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 344 | pipeline = StableDiffusionPipeline.from_pretrained( 345 | args.pretrained_model_name_or_path, use_auth_token=args.use_auth_token, torch_dtype=torch_dtype 346 | ) 347 | pipeline.set_progress_bar_config(disable=True) 348 | 349 | num_new_images = args.num_class_images - cur_class_images 350 | print(f"Number of class images to sample: {num_new_images}.") 351 | 352 | sample_dataset = PromptDataset(args.class_prompt, num_new_images) 353 | sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) 354 | 355 | sample_dataloader = accelerator.prepare(sample_dataloader) 356 | pipeline.to(accelerator.device) 357 | 358 | context = torch.autocast("cuda") if accelerator.device.type == "cuda" else nullcontext 359 | for example in tqdm( 360 | sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process 361 | ): 362 | with context: 363 | images = pipeline(example["prompt"]).images 364 | 365 | for i, image in enumerate(images): 366 | image.save(class_images_dir / f"{example['index'][i] + cur_class_images}.jpg") 367 | 368 | del pipeline 369 | if torch.cuda.is_available(): 370 | torch.cuda.empty_cache() 371 | 372 | # Handle the repository creation 373 | if accelerator.is_main_process: 374 | if args.push_to_hub: 375 | if args.hub_model_id is None: 376 | repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) 377 | else: 378 | repo_name = args.hub_model_id 379 | repo = Repository(args.output_dir, clone_from=repo_name) 380 | 381 | with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: 382 | if "step_*" not in gitignore: 383 | gitignore.write("step_*\n") 384 | if "epoch_*" not in gitignore: 385 | gitignore.write("epoch_*\n") 386 | elif args.output_dir is not None: 387 | os.makedirs(args.output_dir, exist_ok=True) 388 | 389 | 390 | elif args.pretrained_model_name_or_path: 391 | tokenizer = CLIPTokenizer.from_pretrained( 392 | args.pretrained_model_name_or_path, subfolder="tokenizer", 
use_auth_token=args.use_auth_token 393 | ) 394 | 395 | 396 | unet = UNet2DConditionModel.from_pretrained( 397 | args.pretrained_model_name_or_path, subfolder="unet", use_auth_token=args.use_auth_token 398 | ) 399 | 400 | if args.gradient_checkpointing: 401 | unet.enable_gradient_checkpointing() 402 | 403 | if args.scale_lr: 404 | args.learning_rate = ( 405 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 406 | ) 407 | 408 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs 409 | if args.use_8bit_adam: 410 | try: 411 | import bitsandbytes as bnb 412 | except ImportError: 413 | raise ImportError( 414 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 415 | ) 416 | 417 | optimizer_class = bnb.optim.AdamW8bit 418 | else: 419 | optimizer_class = torch.optim.AdamW 420 | 421 | optimizer = optimizer_class( 422 | unet.parameters(), # only optimize unet 423 | lr=args.learning_rate, 424 | betas=(args.adam_beta1, args.adam_beta2), 425 | weight_decay=args.adam_weight_decay, 426 | eps=args.adam_epsilon, 427 | ) 428 | 429 | noise_scheduler = DDPMScheduler( 430 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt" 431 | ) 432 | 433 | 434 | 435 | 436 | len_train_dataloader=q.get() 437 | 438 | 439 | # Scheduler and math around the number of training steps. 440 | overrode_max_train_steps = False 441 | num_update_steps_per_epoch = math.ceil(len_train_dataloader / args.gradient_accumulation_steps) 442 | if args.max_train_steps is None: 443 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 444 | overrode_max_train_steps = True 445 | 446 | lr_scheduler = get_scheduler( 447 | args.lr_scheduler, 448 | optimizer=optimizer, 449 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 450 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 451 | ) 452 | 453 | unet, optimizer, lr_scheduler = accelerator.prepare( 454 | unet, optimizer, lr_scheduler 455 | ) 456 | 457 | 458 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 459 | num_update_steps_per_epoch = math.ceil(len_train_dataloader/ args.gradient_accumulation_steps) 460 | if overrode_max_train_steps: 461 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 462 | # Afterwards we recalculate our number of training epochs 463 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 464 | 465 | # We need to initialize the trackers we use, and also store our configuration. 466 | # The trackers initializes automatically on the main process. 467 | if accelerator.is_main_process: 468 | accelerator.init_trackers("dreambooth", config=vars(args)) 469 | 470 | # Train! 471 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 472 | 473 | print("***** Running training *****") 474 | #print(f" Num examples = {len(train_dataset)}") 475 | print(f" Num batches each epoch = {len_train_dataloader}") 476 | print(f" Num Epochs = {args.num_train_epochs}") 477 | print(f" Instantaneous batch size per device = {args.train_batch_size}") 478 | print(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") 479 | print(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 480 | print(f" Total optimization steps = {args.max_train_steps}") 481 | # Only show the progress bar once on each machine. 482 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 483 | progress_bar.set_description("Steps") 484 | global_step = 0 485 | 486 | for epoch in range(args.num_train_epochs): 487 | unet.train() 488 | while True: 489 | with accelerator.accumulate(unet): 490 | # Convert images to latent space 491 | 492 | el = q.get() 493 | if el is None: 494 | break 495 | latents,encoder_hidden_states = el 496 | latents=latents.to(accelerator.device) 497 | encoder_hidden_states=encoder_hidden_states.to(accelerator.device) 498 | 499 | # Sample noise that we'll add to the latents 500 | noise = torch.randn(latents.shape).to(latents.device) 501 | bsz = latents.shape[0] 502 | # Sample a random timestep for each image 503 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) 504 | timesteps = timesteps.long() 505 | 506 | # Add noise to the latents according to the noise magnitude at each timestep 507 | # (this is the forward diffusion process) 508 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 509 | 510 | # Predict the noise residual 511 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 512 | 513 | if args.with_prior_preservation: 514 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. 515 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) 516 | noise, noise_prior = torch.chunk(noise, 2, dim=0) 517 | 518 | # Compute instance loss 519 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() 520 | 521 | # Compute prior loss 522 | prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean() 523 | 524 | # Add the prior loss to the instance loss. 525 | loss = loss + args.prior_loss_weight * prior_loss 526 | else: 527 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() 528 | 529 | accelerator.backward(loss) 530 | accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) 531 | optimizer.step() 532 | lr_scheduler.step() 533 | optimizer.zero_grad(set_to_none=True) 534 | 535 | # Checks if the accelerator has performed an optimization step behind the scenes 536 | if accelerator.sync_gradients: 537 | progress_bar.update(1) 538 | global_step += 1 539 | 540 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 541 | progress_bar.set_postfix(**logs) 542 | accelerator.log(logs, step=global_step) 543 | 544 | if global_step >= args.max_train_steps: 545 | break 546 | 547 | accelerator.wait_for_everyone() 548 | 549 | # Create the pipeline using using the trained modules and save it. 
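    # In parallel mode only this UNet process writes the trained pipeline: train_dreambooth_parallel.py
    # joins this process first and then terminates the VAE producer, so what is saved to
    # args.output_dir is the unwrapped, fine-tuned UNet plus the untouched pretrained components.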
550 | if accelerator.is_main_process: 551 | pipeline = StableDiffusionPipeline.from_pretrained( 552 | args.pretrained_model_name_or_path, 553 | unet=accelerator.unwrap_model(unet), 554 | use_auth_token=args.use_auth_token, 555 | ) 556 | pipeline.save_pretrained(args.output_dir) 557 | 558 | if args.push_to_hub: 559 | repo.push_to_hub( 560 | args, pipeline, repo, commit_message="End of training", blocking=False, auto_lfs_prune=True 561 | ) 562 | 563 | accelerator.end_training() 564 | 565 | 566 | if __name__ == "__main__": 567 | main() 568 | -------------------------------------------------------------------------------- /train_dreambooth_vae.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import os 4 | from contextlib import nullcontext 5 | from pathlib import Path 6 | from typing import Optional 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | import torch.utils.checkpoint 11 | from torch.utils.data import Dataset 12 | 13 | from accelerate import Accelerator 14 | from accelerate.logging import get_logger 15 | from accelerate.utils import set_seed 16 | from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel 17 | from diffusers.optimization import get_scheduler 18 | from huggingface_hub import HfFolder, Repository, whoami 19 | from PIL import Image 20 | from torchvision import transforms 21 | from tqdm.auto import tqdm 22 | from transformers import CLIPTextModel, CLIPTokenizer 23 | 24 | 25 | logger = get_logger(__name__) 26 | 27 | 28 | def parse_args(): 29 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 30 | parser.add_argument( 31 | "--pretrained_model_name_or_path", 32 | type=str, 33 | default=None, 34 | required=True, 35 | help="Path to pretrained model or model identifier from huggingface.co/models.", 36 | ) 37 | parser.add_argument( 38 | "--tokenizer_name", 39 | type=str, 40 | default=None, 41 | help="Pretrained tokenizer name or path if not the same as model_name", 42 | ) 43 | parser.add_argument( 44 | "--instance_data_dir", 45 | type=str, 46 | default=None, 47 | required=True, 48 | help="A folder containing the training data of instance images.", 49 | ) 50 | parser.add_argument( 51 | "--class_data_dir", 52 | type=str, 53 | default=None, 54 | required=False, 55 | help="A folder containing the training data of class images.", 56 | ) 57 | parser.add_argument( 58 | "--instance_prompt", 59 | type=str, 60 | default=None, 61 | help="The prompt with identifier specifing the instance", 62 | ) 63 | parser.add_argument( 64 | "--class_prompt", 65 | type=str, 66 | default=None, 67 | help="The prompt to specify images in the same class as provided intance images.", 68 | ) 69 | parser.add_argument( 70 | "--with_prior_preservation", 71 | default=False, 72 | action="store_true", 73 | help="Flag to add prior perservation loss.", 74 | ) 75 | parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") 76 | parser.add_argument( 77 | "--num_class_images", 78 | type=int, 79 | default=100, 80 | help=( 81 | "Minimal class images for prior perversation loss. If not have enough images, additional images will be" 82 | " sampled with class_prompt." 
83 | ), 84 | ) 85 | parser.add_argument( 86 | "--output_dir", 87 | type=str, 88 | default="text-inversion-model", 89 | help="The output directory where the model predictions and checkpoints will be written.", 90 | ) 91 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 92 | parser.add_argument( 93 | "--resolution", 94 | type=int, 95 | default=512, 96 | help=( 97 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 98 | " resolution" 99 | ), 100 | ) 101 | parser.add_argument( 102 | "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" 103 | ) 104 | parser.add_argument( 105 | "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." 106 | ) 107 | parser.add_argument( 108 | "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." 109 | ) 110 | parser.add_argument("--num_train_epochs", type=int, default=1) 111 | parser.add_argument( 112 | "--max_train_steps", 113 | type=int, 114 | default=None, 115 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 116 | ) 117 | parser.add_argument( 118 | "--gradient_accumulation_steps", 119 | type=int, 120 | default=1, 121 | help="Number of updates steps to accumulate before performing a backward/update pass.", 122 | ) 123 | parser.add_argument( 124 | "--gradient_checkpointing", 125 | action="store_true", 126 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 127 | ) 128 | parser.add_argument( 129 | "--learning_rate", 130 | type=float, 131 | default=5e-6, 132 | help="Initial learning rate (after the potential warmup period) to use.", 133 | ) 134 | parser.add_argument( 135 | "--scale_lr", 136 | action="store_true", 137 | default=False, 138 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 139 | ) 140 | parser.add_argument( 141 | "--lr_scheduler", 142 | type=str, 143 | default="constant", 144 | help=( 145 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 146 | ' "constant", "constant_with_warmup"]' 147 | ), 148 | ) 149 | parser.add_argument( 150 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 151 | ) 152 | parser.add_argument( 153 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 154 | ) 155 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 156 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 157 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 158 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 159 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 160 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 161 | parser.add_argument( 162 | "--use_auth_token", 163 | action="store_true", 164 | help=( 165 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script with" 166 | " private models)." 
167 | ), 168 | ) 169 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 170 | parser.add_argument( 171 | "--hub_model_id", 172 | type=str, 173 | default=None, 174 | help="The name of the repository to keep in sync with the local `output_dir`.", 175 | ) 176 | parser.add_argument( 177 | "--logging_dir", 178 | type=str, 179 | default="logs", 180 | help=( 181 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 182 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 183 | ), 184 | ) 185 | parser.add_argument( 186 | "--mixed_precision", 187 | type=str, 188 | default="no", 189 | choices=["no", "fp16", "bf16"], 190 | help=( 191 | "Whether to use mixed precision. Choose" 192 | "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." 193 | "and an Nvidia Ampere GPU." 194 | ), 195 | ) 196 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 197 | 198 | args = parser.parse_args() 199 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 200 | if env_local_rank != -1 and env_local_rank != args.local_rank: 201 | args.local_rank = env_local_rank 202 | 203 | if args.instance_data_dir is None: 204 | raise ValueError("You must specify a train data directory.") 205 | 206 | if args.with_prior_preservation: 207 | if args.class_data_dir is None: 208 | raise ValueError("You must specify a data directory for class images.") 209 | if args.class_prompt is None: 210 | raise ValueError("You must specify prompt for class images.") 211 | 212 | return args 213 | 214 | 215 | class DreamBoothDataset(Dataset): 216 | """ 217 | A dataset to prepare the instance and class images with the promots for fine-tuning the model. 218 | It pre-processes the images and the tokenizes prompts. 
219 | """ 220 | 221 | def __init__( 222 | self, 223 | instance_data_root, 224 | instance_prompt, 225 | tokenizer, 226 | class_data_root=None, 227 | class_prompt=None, 228 | size=512, 229 | center_crop=False, 230 | ): 231 | self.size = size 232 | self.center_crop = center_crop 233 | self.tokenizer = tokenizer 234 | 235 | self.instance_data_root = Path(instance_data_root) 236 | if not self.instance_data_root.exists(): 237 | raise ValueError("Instance images root doesn't exists.") 238 | 239 | self.instance_images_path = list(Path(instance_data_root).iterdir()) 240 | self.num_instance_images = len(self.instance_images_path) 241 | self.instance_prompt = instance_prompt 242 | self._length = self.num_instance_images 243 | 244 | if class_data_root is not None: 245 | self.class_data_root = Path(class_data_root) 246 | self.class_data_root.mkdir(parents=True, exist_ok=True) 247 | self.class_images_path = list(Path(class_data_root).iterdir()) 248 | self.num_class_images = len(self.class_images_path) 249 | self._length = max(self.num_class_images, self.num_instance_images) 250 | self.class_prompt = class_prompt 251 | else: 252 | self.class_data_root = None 253 | 254 | self.image_transforms = transforms.Compose( 255 | [ 256 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), 257 | transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), 258 | transforms.ToTensor(), 259 | transforms.Normalize([0.5], [0.5]), 260 | ] 261 | ) 262 | 263 | def __len__(self): 264 | return self._length 265 | 266 | def __getitem__(self, index): 267 | example = {} 268 | instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) 269 | if not instance_image.mode == "RGB": 270 | instance_image = instance_image.convert("RGB") 271 | example["instance_images"] = self.image_transforms(instance_image) 272 | example["instance_prompt_ids"] = self.tokenizer( 273 | self.instance_prompt, 274 | padding="do_not_pad", 275 | truncation=True, 276 | max_length=self.tokenizer.model_max_length, 277 | ).input_ids 278 | 279 | if self.class_data_root: 280 | class_image = Image.open(self.class_images_path[index % self.num_class_images]) 281 | if not class_image.mode == "RGB": 282 | class_image = class_image.convert("RGB") 283 | example["class_images"] = self.image_transforms(class_image) 284 | example["class_prompt_ids"] = self.tokenizer( 285 | self.class_prompt, 286 | padding="do_not_pad", 287 | truncation=True, 288 | max_length=self.tokenizer.model_max_length, 289 | ).input_ids 290 | 291 | return example 292 | 293 | 294 | class PromptDataset(Dataset): 295 | "A simple dataset to prepare the prompts to generate class images on multiple GPUs." 
296 | 297 | def __init__(self, prompt, num_samples): 298 | self.prompt = prompt 299 | self.num_samples = num_samples 300 | 301 | def __len__(self): 302 | return self.num_samples 303 | 304 | def __getitem__(self, index): 305 | example = {} 306 | example["prompt"] = self.prompt 307 | example["index"] = index 308 | return example 309 | 310 | 311 | def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): 312 | if token is None: 313 | token = HfFolder.get_token() 314 | if organization is None: 315 | username = whoami(token)["name"] 316 | return f"{username}/{model_id}" 317 | else: 318 | return f"{organization}/{model_id}" 319 | 320 | 321 | def main(q): 322 | args = parse_args() 323 | logging_dir = Path(args.output_dir, args.logging_dir) 324 | print("instance_prompt: ", args.instance_prompt) 325 | accelerator = Accelerator( 326 | gradient_accumulation_steps=args.gradient_accumulation_steps, 327 | mixed_precision=args.mixed_precision, 328 | log_with="tensorboard", 329 | logging_dir=logging_dir, 330 | ) 331 | 332 | if args.seed is not None: 333 | set_seed(args.seed) 334 | 335 | 336 | 337 | # Handle the repository creation 338 | if accelerator.is_main_process: 339 | if args.push_to_hub: 340 | if args.hub_model_id is None: 341 | repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) 342 | else: 343 | repo_name = args.hub_model_id 344 | repo = Repository(args.output_dir, clone_from=repo_name) 345 | 346 | with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: 347 | if "step_*" not in gitignore: 348 | gitignore.write("step_*\n") 349 | if "epoch_*" not in gitignore: 350 | gitignore.write("epoch_*\n") 351 | elif args.output_dir is not None: 352 | os.makedirs(args.output_dir, exist_ok=True) 353 | 354 | # Load the tokenizer 355 | if args.tokenizer_name: 356 | tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) 357 | elif args.pretrained_model_name_or_path: 358 | tokenizer = CLIPTokenizer.from_pretrained( 359 | args.pretrained_model_name_or_path, subfolder="tokenizer", use_auth_token=args.use_auth_token 360 | ) 361 | 362 | # Load models and create wrapper for stable diffusion 363 | text_encoder = CLIPTextModel.from_pretrained( 364 | args.pretrained_model_name_or_path, subfolder="text_encoder", use_auth_token=args.use_auth_token 365 | ) 366 | vae = AutoencoderKL.from_pretrained( 367 | args.pretrained_model_name_or_path, subfolder="vae", use_auth_token=args.use_auth_token 368 | ) 369 | 370 | if args.scale_lr: 371 | args.learning_rate = ( 372 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 373 | ) 374 | 375 | noise_scheduler = DDPMScheduler( 376 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt" 377 | ) 378 | 379 | train_dataset = DreamBoothDataset( 380 | instance_data_root=args.instance_data_dir, 381 | instance_prompt=args.instance_prompt, 382 | class_data_root=args.class_data_dir if args.with_prior_preservation else None, 383 | class_prompt=args.class_prompt, 384 | tokenizer=tokenizer, 385 | size=args.resolution, 386 | center_crop=args.center_crop, 387 | ) 388 | 389 | def collate_fn(examples): 390 | input_ids = [example["instance_prompt_ids"] for example in examples] 391 | pixel_values = [example["instance_images"] for example in examples] 392 | 393 | # Concat class and instance examples for prior preservation. 394 | # We do this to avoid doing two forward passes. 
395 |         if args.with_prior_preservation:
396 |             input_ids += [example["class_prompt_ids"] for example in examples]
397 |             pixel_values += [example["class_images"] for example in examples]
398 | 
399 |         pixel_values = torch.stack(pixel_values)
400 |         pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
401 | 
402 |         input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
403 | 
404 |         batch = {
405 |             "input_ids": input_ids,
406 |             "pixel_values": pixel_values,
407 |         }
408 |         return batch
409 | 
410 |     cuda_secondary = os.environ.get("DREAMBOOTH_SECONDARY", accelerator.device)
411 | 
412 |     train_dataloader = torch.utils.data.DataLoader(
413 |         train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn
414 |     )
415 | 
416 |     # Scheduler and math around the number of training steps.
417 |     overrode_max_train_steps = False
418 |     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
419 |     if args.max_train_steps is None:
420 |         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
421 |         overrode_max_train_steps = True
422 | 
423 |     # Move the text_encoder and vae to the secondary device (cpu or a second GPU, set via DREAMBOOTH_SECONDARY)
424 |     text_encoder.to(cuda_secondary)
425 |     vae.to(cuda_secondary)
426 | 
427 |     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
428 |     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
429 |     if overrode_max_train_steps:
430 |         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
431 |     # Afterwards we recalculate our number of training epochs
432 |     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
433 | 
434 |     # We need to initialize the trackers we use, and also store our configuration.
435 |     # The trackers initialize automatically on the main process.
436 |     if accelerator.is_main_process:
437 |         accelerator.init_trackers("dreambooth", config=vars(args))
438 | 
439 |     # Train!
440 |     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
441 | 
442 |     logger.info("***** Running training *****")
443 |     logger.info(f" Num examples = {len(train_dataset)}")
444 |     logger.info(f" Num batches each epoch = {len(train_dataloader)}")
445 |     logger.info(f" Num Epochs = {args.num_train_epochs}")
446 |     logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
447 |     logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
448 |     logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
449 |     logger.info(f" Total optimization steps = {args.max_train_steps}")
450 |     # Encode images and prompts on the secondary device and hand the results to the other training process via the queue.
451 | 
452 | 
453 |     global_step = 0
454 | 
455 |     q.put(len(train_dataloader))  # first message: how many batches each epoch yields
456 |     for epoch in range(args.num_train_epochs):
457 | 
458 |         for step, batch in enumerate(train_dataloader):
459 |             # Convert images to latent space
460 |             with torch.no_grad():
461 |                 latents = vae.encode(batch["pixel_values"].to(cuda_secondary)).latent_dist.sample()
462 |                 latents = latents.to(cuda_secondary) * 0.18215  # scale by the Stable Diffusion latent factor
463 | 
464 |             # Get the text embedding for conditioning
465 |             with torch.no_grad():
466 |                 encoder_hidden_states = text_encoder(batch["input_ids"].to(cuda_secondary))[0]
467 | 
468 |             q.put((latents.to("cpu"), encoder_hidden_states.to("cpu")))  # move to CPU and hand off through the queue
469 | 
470 |     q.put(None)  # sentinel: no more batches will be produced
471 | 
472 |     accelerator.wait_for_everyone()
473 | 
474 | 
475 | 
476 |     accelerator.end_training()
477 | 
478 | 
479 | if __name__ == "__main__":
480 |     # main() requires a queue (normally passed in by the launcher); create a local one for standalone runs.
481 |     import multiprocessing
482 |     main(multiprocessing.Queue())
483 | --------------------------------------------------------------------------------
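For reference, here is a minimal sketch of how the queue protocol implemented above (first `len(train_dataloader)`, then `(latents, encoder_hidden_states)` tuples, then a final `None`) could be driven from a parent process. Only that protocol is taken from the code above; the module name `train_dreambooth_vae`, the `consume()` helper, and the placeholder UNet step are illustrative assumptions, and the repository's actual wiring lives in train_dreambooth_parallel.py and train_dreambooth_unet.py.

```python
# Illustrative only: pairing the producer above with a UNet-training consumer.
import multiprocessing as mp

import train_dreambooth_vae as producer  # assumed name of the script above


def consume(q):
    """Drain the queue the way a UNet trainer would."""
    batches_per_epoch = q.get()  # first message: number of batches per epoch
    print(f"expecting {batches_per_epoch} batches per epoch")
    while True:
        item = q.get()
        if item is None:  # sentinel: the producer has finished
            break
        latents, encoder_hidden_states = item
        # ... move the tensors to the main GPU and run the UNet step here ...


if __name__ == "__main__":
    queue = mp.Queue()
    vae_process = mp.Process(target=producer.main, args=(queue,))
    vae_process.start()
    consume(queue)
    vae_process.join()
```

Because the producer moves both tensors to the CPU before `q.put(...)`, only pickled CPU copies cross the process boundary; the consumer decides if and when to move them back onto the training GPU.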