├── README.md └── train_dreambooth.py /README.md: -------------------------------------------------------------------------------- 1 | # Dreambooth for Audio-Gen (WIP) 2 | -------------------------------------------------------------------------------- /train_dreambooth.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | 16 | import argparse 17 | import itertools 18 | import logging 19 | import math 20 | import os 21 | import warnings 22 | from pathlib import Path 23 | 24 | import torch 25 | import torch.nn.functional as F 26 | import torch.utils.checkpoint 27 | import transformers 28 | from accelerate import Accelerator 29 | from accelerate.logging import get_logger 30 | from accelerate.utils import ProjectConfiguration, set_seed 31 | from huggingface_hub import create_repo, upload_folder 32 | from torch.utils.data import Dataset 33 | from transformers import T5EncoderModel 34 | from tqdm.auto import tqdm 35 | from transformers import AutoTokenizer, get_scheduler 36 | from transformers.utils import check_min_version 37 | import wandb 38 | 39 | check_min_version("4.31.0.dev0") 40 | 41 | logger = get_logger(__name__) 42 | 43 | 44 | def save_model_card( 45 | repo_id: str, 46 | audios=None, 47 | base_model=str, 48 | train_text_encoder=False, 49 | prompt=str, 50 | repo_folder=None, 51 | ): 52 | img_str = "" 53 | for i, audio in enumerate(audios): 54 | audio.save(os.path.join(repo_folder, f"audio_{i}.png")) 55 | img_str += f"![img_{i}](./audio_{i}.png)\n" 56 | 57 | yaml = f""" 58 | --- 59 | license: creativeml-openrail-m 60 | base_model: {base_model} 61 | instance_prompt: {prompt} 62 | tags: 63 | - text-to-audio 64 | - transformers 65 | - dreambooth 66 | inference: true 67 | --- 68 | """ 69 | model_card = f""" 70 | # DreamBooth - {repo_id} 71 | 72 | This is a dreambooth model derived from {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). 73 | You can find some example audios in the following. \n 74 | {img_str} 75 | 76 | DreamBooth for the text encoder was enabled: {train_text_encoder}. 77 | """ 78 | with open(os.path.join(repo_folder, "README.md"), "w") as f: 79 | f.write(yaml + model_card) 80 | 81 | 82 | def log_validation(text_encoder, args, accelerator, weight_dtype, epoch): 83 | logger.info( 84 | f"Running validation... \n Generating {args.num_validation_audios} audios with prompt:" 85 | f" {args.validation_prompt}." 
86 | ) 87 | 88 | if text_encoder is not None: 89 | text_encoder = accelerator.unwrap_model(text_encoder) 90 | 91 | # run inference 92 | generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed) 93 | audios = [] 94 | 95 | for audio in args.validation_audios: 96 | # audio = audio.open(audio) 97 | # audio = pipeline(**pipeline_args, audio=audio, generator=generator).audios[0] 98 | #audios.append(audio) 99 | pass 100 | 101 | for tracker in accelerator.trackers: 102 | if tracker.name == "wandb": 103 | tracker.log( 104 | { 105 | "validation": [ 106 | wandb.Audio(audio, caption=f"{i}: {args.validation_prompt}") for i, audio in enumerate(audios) 107 | ] 108 | } 109 | ) 110 | 111 | return audio 112 | 113 | 114 | def parse_args(input_args=None): 115 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 116 | parser.add_argument( 117 | "--pretrained_model_name_or_path", 118 | type=str, 119 | default=None, 120 | required=True, 121 | help="Path to pretrained model or model identifier from huggingface.co/models.", 122 | ) 123 | parser.add_argument( 124 | "--revision", 125 | type=str, 126 | default=None, 127 | required=False, 128 | help=( 129 | "Revision of pretrained model identifier from huggingface.co/models." 130 | ), 131 | ) 132 | parser.add_argument( 133 | "--tokenizer_name", 134 | type=str, 135 | default=None, 136 | help="Pretrained tokenizer name or path if not the same as model_name", 137 | ) 138 | parser.add_argument( 139 | "--instance_data_dir", 140 | type=str, 141 | default=None, 142 | required=True, 143 | help="A folder containing the training data of instance audios.", 144 | ) 145 | parser.add_argument( 146 | "--instance_prompt", 147 | type=str, 148 | default=None, 149 | required=True, 150 | help="The prompt with identifier specifying the instance", 151 | ) 152 | parser.add_argument( 153 | "--num_class_audios", 154 | type=int, 155 | default=100, 156 | help=( 157 | "Minimal class audios for prior preservation loss. If there are not enough audios already present in" 158 | " class_data_dir, additional audios will be sampled with class_prompt." 159 | ), 160 | ) 161 | parser.add_argument( 162 | "--output_dir", 163 | type=str, 164 | default="text-inversion-model", 165 | help="The output directory where the model predictions and checkpoints will be written.", 166 | ) 167 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 168 | parser.add_argument( 169 | "--sampling_rate", 170 | type=int, 171 | default=None, 172 | help=( 173 | "The sampling_rate for input audios, all the audios in the train/validation dataset will be resampled to this" 174 | " sampling_rate. IF undefined, will default to sampling rate of `vq_model`." 175 | ), 176 | ) 177 | parser.add_argument( 178 | "--train_text_encoder", 179 | action="store_true", 180 | help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", 181 | ) 182 | parser.add_argument( 183 | "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." 184 | ) 185 | parser.add_argument( 186 | "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling audios." 187 | ) 188 | parser.add_argument("--num_train_epochs", type=int, default=1) 189 | parser.add_argument( 190 | "--max_train_steps", 191 | type=int, 192 | default=None, 193 | help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", 194 | ) 195 | parser.add_argument( 196 | "--checkpointing_steps", 197 | type=int, 198 | default=500, 199 | help=( 200 | "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. " 201 | "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." 202 | "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components." 203 | "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step" 204 | "instructions." 205 | ), 206 | ) 207 | parser.add_argument( 208 | "--checkpoints_total_limit", 209 | type=int, 210 | default=None, 211 | help=( 212 | "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." 213 | " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" 214 | " for more details" 215 | ), 216 | ) 217 | parser.add_argument( 218 | "--resume_from_checkpoint", 219 | type=str, 220 | default=None, 221 | help=( 222 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 223 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 224 | ), 225 | ) 226 | parser.add_argument( 227 | "--gradient_accumulation_steps", 228 | type=int, 229 | default=1, 230 | help="Number of updates steps to accumulate before performing a backward/update pass.", 231 | ) 232 | parser.add_argument( 233 | "--gradient_checkpointing", 234 | action="store_true", 235 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 236 | ) 237 | parser.add_argument( 238 | "--learning_rate", 239 | type=float, 240 | default=5e-6, 241 | help="Initial learning rate (after the potential warmup period) to use.", 242 | ) 243 | parser.add_argument( 244 | "--scale_lr", 245 | action="store_true", 246 | default=False, 247 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 248 | ) 249 | parser.add_argument( 250 | "--lr_scheduler", 251 | type=str, 252 | default="constant", 253 | help=( 254 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 255 | ' "constant", "constant_with_warmup"]' 256 | ), 257 | ) 258 | parser.add_argument( 259 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 260 | ) 261 | parser.add_argument( 262 | "--lr_num_cycles", 263 | type=int, 264 | default=1, 265 | help="Number of hard resets of the lr in cosine_with_restarts scheduler.", 266 | ) 267 | parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") 268 | parser.add_argument( 269 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 270 | ) 271 | parser.add_argument( 272 | "--dataloader_num_workers", 273 | type=int, 274 | default=0, 275 | help=( 276 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 
277 | ), 278 | ) 279 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 280 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 281 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 282 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 283 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 284 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 285 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 286 | parser.add_argument( 287 | "--hub_model_id", 288 | type=str, 289 | default=None, 290 | help="The name of the repository to keep in sync with the local `output_dir`.", 291 | ) 292 | parser.add_argument( 293 | "--logging_dir", 294 | type=str, 295 | default="logs", 296 | help=( 297 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 298 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 299 | ), 300 | ) 301 | parser.add_argument( 302 | "--allow_tf32", 303 | action="store_true", 304 | help=( 305 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" 306 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 307 | ), 308 | ) 309 | parser.add_argument( 310 | "--report_to", 311 | type=str, 312 | default="wandb", 313 | help=( 314 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 315 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 316 | ), 317 | ) 318 | parser.add_argument( 319 | "--validation_prompt", 320 | type=str, 321 | default=None, 322 | help="A prompt that is used during validation to verify that the model is learning.", 323 | ) 324 | parser.add_argument( 325 | "--num_validation_audios", 326 | type=int, 327 | default=4, 328 | help="Number of audios that should be generated during validation with `validation_prompt`.", 329 | ) 330 | parser.add_argument( 331 | "--validation_steps", 332 | type=int, 333 | default=100, 334 | help=( 335 | "Run validation every X steps. Validation consists of running the prompt" 336 | " `args.validation_prompt` multiple times: `args.num_validation_audios`" 337 | " and logging the audios." 338 | ), 339 | ) 340 | parser.add_argument( 341 | "--mixed_precision", 342 | type=str, 343 | default=None, 344 | choices=["no", "fp16", "bf16"], 345 | help=( 346 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 347 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 348 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 349 | ), 350 | ) 351 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 352 | parser.add_argument( 353 | "--set_grads_to_none", 354 | action="store_true", 355 | help=( 356 | "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" 357 | " behaviors, so disable this argument if it causes any problems. 
More info:" 358 | " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" 359 | ), 360 | ) 361 | parser.add_argument( 362 | "--tokenizer_max_length", 363 | type=int, 364 | default=None, 365 | required=False, 366 | help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.", 367 | ) 368 | parser.add_argument( 369 | "--validation_audios", 370 | required=False, 371 | default=None, 372 | nargs="+", 373 | help="Optional set of audios to use for validation. Used when the target pipeline takes an initial audio as input such as when training audio variation or supersampling_rate.", 374 | ) 375 | 376 | if input_args is not None: 377 | args = parser.parse_args(input_args) 378 | else: 379 | args = parser.parse_args() 380 | 381 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 382 | if env_local_rank != -1 and env_local_rank != args.local_rank: 383 | args.local_rank = env_local_rank 384 | 385 | if args.with_prior_preservation: 386 | if args.class_data_dir is None: 387 | raise ValueError("You must specify a data directory for class audios.") 388 | if args.class_prompt is None: 389 | raise ValueError("You must specify prompt for class audios.") 390 | else: 391 | # logger is not available yet 392 | if args.class_data_dir is not None: 393 | warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") 394 | if args.class_prompt is not None: 395 | warnings.warn("You need not use --class_prompt without --with_prior_preservation.") 396 | 397 | 398 | return args 399 | 400 | 401 | class DreamBoothDataset(Dataset): 402 | """ 403 | A dataset to prepare the instance and class audios with the prompts for fine-tuning the model. 404 | It pre-processes the audios and the tokenizes prompts. 
405 | """ 406 | 407 | def __init__( 408 | self, 409 | instance_data_root, 410 | instance_prompt, 411 | tokenizer, 412 | size=512, 413 | tokenizer_max_length=None, 414 | ): 415 | self.size = size 416 | self.tokenizer = tokenizer 417 | self.tokenizer_max_length = tokenizer_max_length 418 | 419 | self.instance_data_root = Path(instance_data_root) 420 | if not self.instance_data_root.exists(): 421 | raise ValueError(f"Instance {self.instance_data_root} audios root doesn't exists.") 422 | 423 | self.instance_audios_path = list(Path(instance_data_root).iterdir()) 424 | self.num_instance_audios = len(self.instance_audios_path) 425 | self.instance_prompt = instance_prompt 426 | self._length = self.num_instance_audios 427 | 428 | def __len__(self): 429 | return self._length 430 | 431 | def __getitem__(self, index): 432 | example = {} 433 | instance_audio = audio.open(self.instance_audios_path[index % self.num_instance_audios]) 434 | instance_audio = exif_transpose(instance_audio) 435 | 436 | if not instance_audio.mode == "RGB": 437 | instance_audio = instance_audio.convert("RGB") 438 | 439 | example["instance_audios"] = self.audio_transforms(instance_audio) 440 | 441 | text_inputs = tokenize_prompt( 442 | self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length 443 | ) 444 | example["instance_prompt_ids"] = text_inputs.input_ids 445 | example["instance_attention_mask"] = text_inputs.attention_mask 446 | 447 | return example 448 | 449 | 450 | def collate_fn(examples, with_prior_preservation=False): 451 | has_attention_mask = "instance_attention_mask" in examples[0] 452 | 453 | input_ids = [example["instance_prompt_ids"] for example in examples] 454 | pixel_values = [example["instance_audios"] for example in examples] 455 | 456 | if has_attention_mask: 457 | attention_mask = [example["instance_attention_mask"] for example in examples] 458 | 459 | # Concat class and instance examples for prior preservation. 460 | # We do this to avoid doing two forward passes. 
461 | pixel_values = torch.stack(pixel_values) 462 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 463 | 464 | input_ids = torch.cat(input_ids, dim=0) 465 | 466 | batch = { 467 | "input_ids": input_ids, 468 | "pixel_values": pixel_values, 469 | } 470 | 471 | if has_attention_mask: 472 | attention_mask = torch.cat(attention_mask, dim=0) 473 | batch["attention_mask"] = attention_mask 474 | 475 | return batch 476 | 477 | 478 | def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None): 479 | if tokenizer_max_length is not None: 480 | max_length = tokenizer_max_length 481 | else: 482 | max_length = tokenizer.model_max_length 483 | 484 | text_inputs = tokenizer( 485 | prompt, 486 | truncation=True, 487 | padding="max_length", 488 | max_length=max_length, 489 | return_tensors="pt", 490 | ) 491 | 492 | return text_inputs 493 | 494 | 495 | def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None): 496 | text_input_ids = input_ids.to(text_encoder.device) 497 | 498 | if text_encoder_use_attention_mask: 499 | attention_mask = attention_mask.to(text_encoder.device) 500 | else: 501 | attention_mask = None 502 | 503 | prompt_embeds = text_encoder( 504 | text_input_ids, 505 | attention_mask=attention_mask, 506 | ) 507 | prompt_embeds = prompt_embeds[0] 508 | 509 | return prompt_embeds 510 | 511 | 512 | def main(args): 513 | logging_dir = Path(args.output_dir, args.logging_dir) 514 | 515 | accelerator_project_config = ProjectConfiguration( 516 | total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir 517 | ) 518 | 519 | accelerator = Accelerator( 520 | gradient_accumulation_steps=args.gradient_accumulation_steps, 521 | mixed_precision=args.mixed_precision, 522 | log_with=args.report_to, 523 | project_config=accelerator_project_config, 524 | ) 525 | 526 | if args.report_to == "wandb": 527 | if not is_wandb_available(): 528 | raise ImportError("Make sure to install wandb if you want to use it for logging during training.") 529 | 530 | # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate 531 | # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. 532 | # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. 533 | if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: 534 | raise ValueError( 535 | "Gradient accumulation is not supported when training the text encoder in distributed training. " 536 | "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." 537 | ) 538 | 539 | # Make one log on every process with the configuration for debugging. 540 | logging.basicConfig( 541 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 542 | datefmt="%m/%d/%Y %H:%M:%S", 543 | level=logging.INFO, 544 | ) 545 | logger.info(accelerator.state, main_process_only=False) 546 | if accelerator.is_local_main_process: 547 | transformers.utils.logging.set_verbosity_warning() 548 | else: 549 | transformers.utils.logging.set_verbosity_error() 550 | 551 | # If passed along, set the training seed now. 
552 | if args.seed is not None: 553 | set_seed(args.seed) 554 | 555 | # Handle the repository creation 556 | if accelerator.is_main_process: 557 | if args.output_dir is not None: 558 | os.makedirs(args.output_dir, exist_ok=True) 559 | 560 | if args.push_to_hub: 561 | repo_id = create_repo( 562 | repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token 563 | ).repo_id 564 | 565 | # Load the tokenizer 566 | if args.tokenizer_name: 567 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) 568 | elif args.pretrained_model_name_or_path: 569 | tokenizer = AutoTokenizer.from_pretrained( 570 | args.pretrained_model_name_or_path, 571 | revision=args.revision, 572 | use_fast=False, 573 | ) 574 | 575 | # import correct text encoder class 576 | model = None # TODO 577 | text_encoder = T5EncoderModel.from_pretrained(args.pretrained_model_name_or_path) 578 | vq_model = None 579 | 580 | if not args.train_text_encoder: 581 | text_encoder.requires_grad_(False) 582 | 583 | if args.gradient_checkpointing: 584 | model.enable_gradient_checkpointing() 585 | if args.train_text_encoder: 586 | text_encoder.gradient_checkpointing_enable() 587 | 588 | # Check that all trainable models are in full precision 589 | low_precision_error_string = ( 590 | "Please make sure to always have all model weights in full float32 precision when starting training - even if" 591 | " doing mixed precision training. copy of the weights should still be float32." 592 | ) 593 | 594 | if accelerator.unwrap_model(model).dtype != torch.float32: 595 | raise ValueError( 596 | f"Model loaded as datatype {accelerator.unwrap_model(model).dtype}. {low_precision_error_string}" 597 | ) 598 | 599 | if args.train_text_encoder and accelerator.unwrap_model(text_encoder).dtype != torch.float32: 600 | raise ValueError( 601 | f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}." 602 | f" {low_precision_error_string}" 603 | ) 604 | 605 | # Enable TF32 for faster training on Ampere GPUs, 606 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 607 | if args.allow_tf32: 608 | torch.backends.cuda.matmul.allow_tf32 = True 609 | 610 | if args.scale_lr: 611 | args.learning_rate = ( 612 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 613 | ) 614 | 615 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs 616 | if args.use_8bit_adam: 617 | try: 618 | import bitsandbytes as bnb 619 | except ImportError: 620 | raise ImportError( 621 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 
            )

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    # Optimizer creation
    params_to_optimize = (
        itertools.chain(model.parameters(), text_encoder.parameters())
        if args.train_text_encoder
        else model.parameters()
    )
    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Dataset and DataLoaders creation:
    train_dataset = DreamBoothDataset(
        instance_data_root=args.instance_data_dir,
        instance_prompt=args.instance_prompt,
        tokenizer=tokenizer,
        size=args.sampling_rate,
        tokenizer_max_length=args.tokenizer_max_length,
    )

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=args.dataloader_num_workers,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )

    # Prepare everything with our `accelerator`.
    if args.train_text_encoder:
        model, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            model, text_encoder, optimizer, train_dataloader, lr_scheduler
        )
    else:
        model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            model, optimizer, train_dataloader, lr_scheduler
        )

    # For mixed precision training we cast the text_encoder and vq_model weights to half-precision,
    # as these models are only used for inference; keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move vq_model and text_encoder to device and cast to weight_dtype
    if vq_model is not None:
        vq_model.to(accelerator.device, dtype=weight_dtype)

    if not args.train_text_encoder and text_encoder is not None:
        text_encoder.to(accelerator.device, dtype=weight_dtype)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
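    # Hedged sketch: the training loop below needs a `noise_scheduler` (it reads
    # `config.num_train_timesteps` and `config.prediction_type`, and calls `add_noise`), but this WIP
    # script never constructs one. Assuming a diffusers-style DDPM scheduler stored in a "scheduler"
    # subfolder of the pretrained checkpoint; adjust to the actual audio pipeline layout.
    from diffusers import DDPMScheduler  # local import: diffusers is an assumed extra dependency

    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")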
    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers are initialized automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("dreambooth", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num batches each epoch = {len(train_dataloader)}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            resume_global_step = global_step * args.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    for epoch in range(first_epoch, args.num_train_epochs):
        model.train()
        if args.train_text_encoder:
            text_encoder.train()

        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                if step % args.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue

            with accelerator.accumulate(model):
                pixel_values = batch["pixel_values"].to(dtype=weight_dtype)

                # Convert audios to latent space
                model_input = vq_model.encode(pixel_values).latent_dist.sample()
                model_input = model_input * vq_model.config.scaling_factor

                # Sample noise that we'll add to the model input
                if args.offset_noise:
                    noise = torch.randn_like(model_input) + 0.1 * torch.randn(
                        model_input.shape[0], model_input.shape[1], 1, 1, device=model_input.device
                    )
                else:
                    noise = torch.randn_like(model_input)
                bsz = model_input.shape[0]
                # Sample a random timestep for each audio
                timesteps = torch.randint(
                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
                )
                timesteps = timesteps.long()

                # Add noise to the model input according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)

                # Get the text embedding for conditioning
                encoder_hidden_states = encode_prompt(
                    text_encoder,
                    batch["input_ids"],
                    batch["attention_mask"],
                    text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
                )

                # Predict the noise residual. Class conditioning is not used in this WIP script.
                class_labels = None
                model_pred = model(
                    noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels
                ).sample

                if model_pred.shape[1] == 6:
                    model_pred, _ = torch.chunk(model_pred, 2, dim=1)

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(model_input, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = (
                        itertools.chain(model.parameters(), text_encoder.parameters())
                        if args.train_text_encoder
                        else model.parameters()
                    )
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=args.set_grads_to_none)

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if accelerator.is_main_process:
                    audios = []
                    if global_step % args.checkpointing_steps == 0:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

                    if args.validation_prompt is not None and global_step % args.validation_steps == 0:
                        # `log_validation` currently takes only these five arguments.
                        audios = log_validation(
                            text_encoder,
                            args,
                            accelerator,
                            weight_dtype,
                            epoch,
                        )

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        # TODO(PVP): run inference

        if args.push_to_hub:
            save_model_card(
                repo_id,
                audios=audios,
                base_model=args.pretrained_model_name_or_path,
                train_text_encoder=args.train_text_encoder,
                prompt=args.instance_prompt,
                repo_folder=args.output_dir,
            )
            upload_folder(
                repo_id=repo_id,
                folder_path=args.output_dir,
                commit_message="End of training",
                ignore_patterns=["step_*", "epoch_*"],
            )

    accelerator.end_training()


if __name__ == "__main__":
    args = parse_args()
    main(args)

--------------------------------------------------------------------------------