├── README.md
├── predict.py
├── requirements.txt
├── sd_config.py
└── train_sd1-5_naruto.py

/README.md:
--------------------------------------------------------------------------------
1 | # Stable Diffusion Examples
2 | 
3 | ## 1. Environment setup
4 | 
5 | ```bash
6 | pip install -r requirements.txt
7 | ```
8 | 
9 | ## (Optional) Model and dataset preparation
10 | 
11 | - If your network connection to HuggingFace is reliable, you can run the training code directly.
12 | - If not, download the archive from [Baidu Cloud](https://pan.baidu.com/s/1Yu5HjXnHxK0Wgymc8G-g5g?pwd=gtk8) (extraction code: gtk8), unzip it into the same directory as the code, and change the `default` values of `pretrained_model_name_or_path` and `dataset_name` in `sd_config.py`:
13 | 
14 | ```python
15 | parser.add_argument(
16 |     "--pretrained_model_name_or_path",
17 |     type=str,
18 |     default="./stable-diffusion-v1-5",
19 | )
20 | parser.add_argument(
21 |     "--dataset_name",
22 |     type=str,
23 |     default="./naruto-blip-captions",
24 | )
25 | ```
26 | 
27 | 
28 | ## 2. Training
29 | 
30 | SD1.5 + Naruto dataset:
31 | 
32 | ```bash
33 | python train_sd1-5_naruto.py \
34 |   --use_ema \
35 |   --resolution=512 --center_crop --random_flip \
36 |   --train_batch_size=1 \
37 |   --gradient_accumulation_steps=4 \
38 |   --gradient_checkpointing \
39 |   --max_train_steps=15000 \
40 |   --learning_rate=1e-05 \
41 |   --max_grad_norm=1 \
42 |   --seed=42 \
43 |   --lr_scheduler="constant" \
44 |   --lr_warmup_steps=0 \
45 |   --output_dir="sd-naruto-model"
46 | ```
47 | 
48 | ## 3. Inference
49 | 
50 | ```bash
51 | python predict.py
52 | ```
53 | 
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | from diffusers import StableDiffusionPipeline
2 | import torch
3 | 
4 | model_id = "./sd-naruto-model"
5 | pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
6 | pipe = pipe.to("cuda")
7 | 
8 | prompt = "Lebron James with a hat"
9 | image = pipe(prompt).images[0]
10 | 
11 | image.save("result.png")
12 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | swanlab
4 | transformers
5 | diffusers
6 | datasets
7 | accelerate
--------------------------------------------------------------------------------
/sd_config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | 
4 | def parse_args():
5 |     parser = argparse.ArgumentParser(description="Simple example of a training script.")
6 |     parser.add_argument(
7 |         "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
8 |     )
9 |     parser.add_argument(
10 |         "--pretrained_model_name_or_path",
11 |         type=str,
12 |         default="runwayml/stable-diffusion-v1-5",
13 |         help="Path to pretrained model or model identifier from huggingface.co/models.",
14 |     )
15 |     parser.add_argument(
16 |         "--dataset_name",
17 |         type=str,
18 |         default="lambdalabs/naruto-blip-captions",
19 |         help=(
20 |             "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
21 |             " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
22 |             " or to a folder containing files that 🤗 Datasets can understand."
23 | ), 24 | ) 25 | parser.add_argument( 26 | "--revision", 27 | type=str, 28 | default=None, 29 | required=False, 30 | help="Revision of pretrained model identifier from huggingface.co/models.", 31 | ) 32 | parser.add_argument( 33 | "--variant", 34 | type=str, 35 | default=None, 36 | help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", 37 | ) 38 | parser.add_argument( 39 | "--dataset_config_name", 40 | type=str, 41 | default=None, 42 | help="The config of the Dataset, leave as None if there's only one config.", 43 | ) 44 | parser.add_argument( 45 | "--train_data_dir", 46 | type=str, 47 | default=None, 48 | help=( 49 | "A folder containing the training data. Folder contents must follow the structure described in" 50 | " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" 51 | " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." 52 | ), 53 | ) 54 | parser.add_argument( 55 | "--image_column", type=str, default="image", help="The column of the dataset containing an image." 56 | ) 57 | parser.add_argument( 58 | "--caption_column", 59 | type=str, 60 | default="text", 61 | help="The column of the dataset containing a caption or a list of captions.", 62 | ) 63 | parser.add_argument( 64 | "--max_train_samples", 65 | type=int, 66 | default=None, 67 | help=( 68 | "For debugging purposes or quicker training, truncate the number of training examples to this " 69 | "value if set." 70 | ), 71 | ) 72 | parser.add_argument( 73 | "--validation_prompts", 74 | type=str, 75 | default=["Bill Gates with a hoodie", "John Oliver with Naruto style", "Lebron James with a hat", "Mickael Jackson as a ninja"], 76 | nargs="+", 77 | help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."), 78 | ) 79 | parser.add_argument( 80 | "--output_dir", 81 | type=str, 82 | default="sd-model-finetuned", 83 | help="The output directory where the model predictions and checkpoints will be written.", 84 | ) 85 | parser.add_argument( 86 | "--cache_dir", 87 | type=str, 88 | default=None, 89 | help="The directory where the downloaded models and datasets will be stored.", 90 | ) 91 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 92 | parser.add_argument( 93 | "--resolution", 94 | type=int, 95 | default=512, 96 | help=( 97 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 98 | " resolution" 99 | ), 100 | ) 101 | parser.add_argument( 102 | "--center_crop", 103 | default=False, 104 | action="store_true", 105 | help=( 106 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly" 107 | " cropped. The images will be resized to the resolution first before cropping." 108 | ), 109 | ) 110 | parser.add_argument( 111 | "--random_flip", 112 | action="store_true", 113 | help="whether to randomly flip images horizontally", 114 | ) 115 | parser.add_argument( 116 | "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." 117 | ) 118 | parser.add_argument("--num_train_epochs", type=int, default=100) 119 | parser.add_argument( 120 | "--max_train_steps", 121 | type=int, 122 | default=None, 123 | help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", 124 | ) 125 | parser.add_argument( 126 | "--gradient_accumulation_steps", 127 | type=int, 128 | default=1, 129 | help="Number of updates steps to accumulate before performing a backward/update pass.", 130 | ) 131 | parser.add_argument( 132 | "--gradient_checkpointing", 133 | action="store_true", 134 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 135 | ) 136 | parser.add_argument( 137 | "--learning_rate", 138 | type=float, 139 | default=1e-4, 140 | help="Initial learning rate (after the potential warmup period) to use.", 141 | ) 142 | parser.add_argument( 143 | "--scale_lr", 144 | action="store_true", 145 | default=False, 146 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 147 | ) 148 | parser.add_argument( 149 | "--lr_scheduler", 150 | type=str, 151 | default="constant", 152 | help=( 153 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 154 | ' "constant", "constant_with_warmup"]' 155 | ), 156 | ) 157 | parser.add_argument( 158 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 159 | ) 160 | parser.add_argument( 161 | "--snr_gamma", 162 | type=float, 163 | default=None, 164 | help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " 165 | "More details here: https://arxiv.org/abs/2303.09556.", 166 | ) 167 | parser.add_argument( 168 | "--dream_training", 169 | action="store_true", 170 | help=( 171 | "Use the DREAM training method, which makes training more efficient and accurate at the ", 172 | "expense of doing an extra forward pass. See: https://arxiv.org/abs/2312.00210", 173 | ), 174 | ) 175 | parser.add_argument( 176 | "--dream_detail_preservation", 177 | type=float, 178 | default=1.0, 179 | help="Dream detail preservation factor p (should be greater than 0; default=1.0, as suggested in the paper)", 180 | ) 181 | parser.add_argument( 182 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 183 | ) 184 | parser.add_argument( 185 | "--allow_tf32", 186 | action="store_true", 187 | help=( 188 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" 189 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 190 | ), 191 | ) 192 | parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") 193 | parser.add_argument( 194 | "--non_ema_revision", 195 | type=str, 196 | default=None, 197 | required=False, 198 | help=( 199 | "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or" 200 | " remote repository specified with --pretrained_model_name_or_path." 201 | ), 202 | ) 203 | parser.add_argument( 204 | "--dataloader_num_workers", 205 | type=int, 206 | default=0, 207 | help=( 208 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 
209 | ), 210 | ) 211 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 212 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 213 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 214 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 215 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 216 | parser.add_argument( 217 | "--prediction_type", 218 | type=str, 219 | default=None, 220 | help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.", 221 | ) 222 | parser.add_argument( 223 | "--logging_dir", 224 | type=str, 225 | default="logs", 226 | help=( 227 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 228 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 229 | ), 230 | ) 231 | parser.add_argument( 232 | "--mixed_precision", 233 | type=str, 234 | default=None, 235 | choices=["no", "fp16", "bf16"], 236 | help=( 237 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 238 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 239 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 240 | ), 241 | ) 242 | parser.add_argument( 243 | "--report_to", 244 | type=str, 245 | default="tensorboard", 246 | help=( 247 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 248 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 249 | ), 250 | ) 251 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 252 | parser.add_argument( 253 | "--checkpointing_steps", 254 | type=int, 255 | default=500, 256 | help=( 257 | "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" 258 | " training using `--resume_from_checkpoint`." 259 | ), 260 | ) 261 | parser.add_argument( 262 | "--checkpoints_total_limit", 263 | type=int, 264 | default=None, 265 | help=("Max number of checkpoints to store."), 266 | ) 267 | parser.add_argument( 268 | "--resume_from_checkpoint", 269 | type=str, 270 | default=None, 271 | help=( 272 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 273 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 274 | ), 275 | ) 276 | parser.add_argument( 277 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 
278 |     )
279 |     parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
280 |     parser.add_argument(
281 |         "--validation_epochs",
282 |         type=int,
283 |         default=5,
284 |         help="Run validation every X epochs.",
285 |     )
286 |     parser.add_argument(
287 |         "--tracker_project_name",
288 |         type=str,
289 |         default="text2image-fine-tune",
290 |         help=(
291 |             "The `project_name` argument passed to Accelerator.init_trackers. For more information see"
292 |             " https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
293 |         ),
294 |     )
295 | 
296 |     args = parser.parse_args()
297 |     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
298 |     if env_local_rank != -1 and env_local_rank != args.local_rank:
299 |         args.local_rank = env_local_rank
300 | 
301 |     # Sanity checks
302 |     if args.dataset_name is None and args.train_data_dir is None:
303 |         raise ValueError("Need either a dataset name or a training folder.")
304 | 
305 |     # default to using the same revision for the non-ema model if not specified
306 |     if args.non_ema_revision is None:
307 |         args.non_ema_revision = args.revision
308 | 
309 |     return args
310 | 
--------------------------------------------------------------------------------
/train_sd1-5_naruto.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 | import os
4 | import random
5 | import shutil
6 | from contextlib import nullcontext
7 | 
8 | import accelerate
9 | import datasets
10 | import numpy as np
11 | import torch
12 | import torch.nn.functional as F
13 | import torch.utils.checkpoint
14 | import transformers
15 | from accelerate import Accelerator
16 | from accelerate.logging import get_logger
17 | from accelerate.state import AcceleratorState
18 | from accelerate.utils import ProjectConfiguration, set_seed
19 | from datasets import load_dataset
20 | from packaging import version
21 | from torchvision import transforms
22 | from tqdm.auto import tqdm
23 | from transformers import CLIPTextModel, CLIPTokenizer
24 | from transformers.utils import ContextManagers
25 | 
26 | import diffusers
27 | from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
28 | from diffusers.optimization import get_scheduler
29 | from diffusers.training_utils import EMAModel, compute_dream_and_update_latents, compute_snr
30 | from diffusers.utils import check_min_version, deprecate, make_image_grid
31 | from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
32 | from diffusers.utils.torch_utils import is_compiled_module
33 | 
34 | import swanlab
35 | from swanlab.integration.accelerate import SwanLabTracker
36 | 
37 | from sd_config import parse_args
38 | 
39 | 
40 | # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
41 | check_min_version("0.29.0")
42 | 
43 | logger = get_logger(__name__, log_level="INFO")
44 | 
45 | DATASET_NAME_MAPPING = {
46 |     "lambdalabs/naruto-blip-captions": ("image", "text"),
47 |     "reach-vb/pokemon-blip-captions": ("image", "text"),
48 | }
49 | 
50 | def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):
51 |     logger.info("Running validation... ")
52 | 
53 |     pipeline = StableDiffusionPipeline.from_pretrained(
54 |         args.pretrained_model_name_or_path,
55 |         vae=accelerator.unwrap_model(vae),
56 |         text_encoder=accelerator.unwrap_model(text_encoder),
57 |         tokenizer=tokenizer,
58 |         unet=accelerator.unwrap_model(unet),
59 |         safety_checker=None,
60 |         revision=args.revision,
61 |         variant=args.variant,
62 |         torch_dtype=weight_dtype,
63 |     )
64 | 
65 |     pipeline = pipeline.to(accelerator.device)
66 |     pipeline.set_progress_bar_config(disable=True)
67 | 
68 |     if args.enable_xformers_memory_efficient_attention:
69 |         pipeline.enable_xformers_memory_efficient_attention()
70 | 
71 |     if args.seed is None:
72 |         generator = None
73 |     else:
74 |         generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
75 | 
76 |     images = []
77 |     for i in range(len(args.validation_prompts)):
78 |         if torch.backends.mps.is_available():
79 |             autocast_ctx = nullcontext()
80 |         else:
81 |             autocast_ctx = torch.autocast(accelerator.device.type)
82 | 
83 |         with autocast_ctx:
84 |             image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
85 | 
86 |         images.append(swanlab.Image(image, caption=f"{i}: {args.validation_prompts[i]}"))
87 | 
88 |     accelerator.log({"validation": images})
89 | 
90 |     del pipeline
91 |     torch.cuda.empty_cache()
92 | 
93 |     return images
94 | 
95 | 
96 | def main():
97 |     args = parse_args()
98 | 
99 |     if args.non_ema_revision is not None:
100 |         deprecate(
101 |             "non_ema_revision!=None",
102 |             "0.15.0",
103 |             message=(
104 |                 "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
105 |                 " use `--variant=non_ema` instead."
106 |             ),
107 |         )
108 | 
109 |     logging_dir = os.path.join(args.output_dir, args.logging_dir)
110 | 
111 |     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
112 | 
113 |     # Define the SwanLab tracker
114 |     swanlab_tracker = SwanLabTracker(
115 |         "SD-Naruto",
116 |         experiment_name="SD1-5_Naruto",
117 |         description="Base model: sd-v1.5; dataset: naruto-blip-captions",
118 |     )
119 | 
120 |     # Initialize the accelerator
121 |     accelerator = Accelerator(
122 |         gradient_accumulation_steps=args.gradient_accumulation_steps,
123 |         mixed_precision=args.mixed_precision,
124 |         log_with=swanlab_tracker,
125 |         project_config=accelerator_project_config,
126 |     )
127 | 
128 |     # Disable AMP for MPS.
129 |     if torch.backends.mps.is_available():
130 |         accelerator.native_amp = False
131 | 
132 |     # Make one log on every process with the configuration for debugging.
133 |     logging.basicConfig(
134 |         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
135 |         datefmt="%m/%d/%Y %H:%M:%S",
136 |         level=logging.INFO,
137 |     )
138 |     logger.info(accelerator.state, main_process_only=False)
139 |     if accelerator.is_local_main_process:
140 |         datasets.utils.logging.set_verbosity_warning()
141 |         transformers.utils.logging.set_verbosity_warning()
142 |         diffusers.utils.logging.set_verbosity_info()
143 |     else:
144 |         datasets.utils.logging.set_verbosity_error()
145 |         transformers.utils.logging.set_verbosity_error()
146 |         diffusers.utils.logging.set_verbosity_error()
147 | 
148 |     # Set the random seed for reproducibility
149 |     if args.seed is not None:
150 |         set_seed(args.seed)
151 | 
152 |     # Handle the repository creation
153 |     if accelerator.is_main_process:
154 |         if args.output_dir is not None:
155 |             os.makedirs(args.output_dir, exist_ok=True)
156 | 
157 |     # Load scheduler, tokenizer and models.
158 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") 159 | tokenizer = CLIPTokenizer.from_pretrained( 160 | args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision 161 | ) 162 | 163 | def deepspeed_zero_init_disabled_context_manager(): 164 | """ 165 | returns either a context list that includes one that will disable zero.Init or an empty context list 166 | """ 167 | deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None 168 | if deepspeed_plugin is None: 169 | return [] 170 | 171 | return [deepspeed_plugin.zero3_init_context_manager(enable=False)] 172 | 173 | # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3. 174 | # For this to work properly all models must be run through `accelerate.prepare`. But accelerate 175 | # will try to assign the same optimizer with the same weights to all models during 176 | # `deepspeed.initialize`, which of course doesn't work. 177 | # 178 | # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2 179 | # frozen models from being partitioned during `zero.Init` which gets called during 180 | # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding 181 | # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded. 182 | with ContextManagers(deepspeed_zero_init_disabled_context_manager()): 183 | text_encoder = CLIPTextModel.from_pretrained( 184 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant 185 | ) 186 | vae = AutoencoderKL.from_pretrained( 187 | args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant 188 | ) 189 | 190 | unet = UNet2DConditionModel.from_pretrained( 191 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision 192 | ) 193 | 194 | # Freeze vae and text_encoder and set unet to trainable 195 | vae.requires_grad_(False) 196 | text_encoder.requires_grad_(False) 197 | unet.train() 198 | 199 | # Create EMA for the unet. 
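    # The EMA copy below tracks an exponential moving average of the UNet weights; it is
    # updated after each optimizer step and swapped in for validation and the final export.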
200 | if args.use_ema: 201 | ema_unet = UNet2DConditionModel.from_pretrained( 202 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant 203 | ) 204 | ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config) 205 | 206 | # `accelerate` 0.16.0 will have better support for customized saving 207 | if version.parse(accelerate.__version__) >= version.parse("0.16.0"): 208 | # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format 209 | def save_model_hook(models, weights, output_dir): 210 | if accelerator.is_main_process: 211 | if args.use_ema: 212 | ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) 213 | 214 | for i, model in enumerate(models): 215 | model.save_pretrained(os.path.join(output_dir, "unet")) 216 | 217 | # make sure to pop weight so that corresponding model is not saved again 218 | weights.pop() 219 | 220 | def load_model_hook(models, input_dir): 221 | if args.use_ema: 222 | load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel) 223 | ema_unet.load_state_dict(load_model.state_dict()) 224 | ema_unet.to(accelerator.device) 225 | del load_model 226 | 227 | for _ in range(len(models)): 228 | # pop models so that they are not loaded again 229 | model = models.pop() 230 | 231 | # load diffusers style into model 232 | load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") 233 | model.register_to_config(**load_model.config) 234 | 235 | model.load_state_dict(load_model.state_dict()) 236 | del load_model 237 | 238 | accelerator.register_save_state_pre_hook(save_model_hook) 239 | accelerator.register_load_state_pre_hook(load_model_hook) 240 | 241 | if args.gradient_checkpointing: 242 | unet.enable_gradient_checkpointing() 243 | 244 | # Enable TF32 for faster training on Ampere GPUs, 245 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 246 | if args.allow_tf32: 247 | torch.backends.cuda.matmul.allow_tf32 = True 248 | 249 | if args.scale_lr: 250 | args.learning_rate = ( 251 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 252 | ) 253 | 254 | optimizer_cls = torch.optim.AdamW 255 | 256 | optimizer = optimizer_cls( 257 | unet.parameters(), 258 | lr=args.learning_rate, 259 | betas=(args.adam_beta1, args.adam_beta2), 260 | weight_decay=args.adam_weight_decay, 261 | eps=args.adam_epsilon, 262 | ) 263 | 264 | # Get the datasets: you can either provide your own training and evaluation files (see below) 265 | # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). 266 | 267 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently 268 | # download the dataset. 269 | if args.dataset_name is not None: 270 | # Downloading and loading a dataset from the hub. 
271 | dataset = load_dataset( 272 | args.dataset_name, 273 | args.dataset_config_name, 274 | cache_dir=args.cache_dir, 275 | data_dir=args.train_data_dir, 276 | ) 277 | else: 278 | data_files = {} 279 | if args.train_data_dir is not None: 280 | data_files["train"] = os.path.join(args.train_data_dir, "**") 281 | dataset = load_dataset( 282 | "imagefolder", 283 | data_files=data_files, 284 | cache_dir=args.cache_dir, 285 | ) 286 | # See more about loading custom images at 287 | # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder 288 | 289 | # Preprocessing the datasets. 290 | # We need to tokenize inputs and targets. 291 | column_names = dataset["train"].column_names 292 | 293 | # 6. Get the column names for input/target. 294 | dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) 295 | if args.image_column is None: 296 | image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] 297 | else: 298 | image_column = args.image_column 299 | if image_column not in column_names: 300 | raise ValueError( 301 | f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" 302 | ) 303 | if args.caption_column is None: 304 | caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1] 305 | else: 306 | caption_column = args.caption_column 307 | if caption_column not in column_names: 308 | raise ValueError( 309 | f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" 310 | ) 311 | 312 | # Preprocessing the datasets. 313 | # We need to tokenize input captions and transform the images. 314 | def tokenize_captions(examples, is_train=True): 315 | captions = [] 316 | for caption in examples[caption_column]: 317 | if isinstance(caption, str): 318 | captions.append(caption) 319 | elif isinstance(caption, (list, np.ndarray)): 320 | # take a random caption if there are multiple 321 | captions.append(random.choice(caption) if is_train else caption[0]) 322 | else: 323 | raise ValueError( 324 | f"Caption column `{caption_column}` should contain either strings or lists of strings." 325 | ) 326 | inputs = tokenizer( 327 | captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" 328 | ) 329 | return inputs.input_ids 330 | 331 | # Preprocessing the datasets. 
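    # Images are resized to `resolution`, cropped (center or random), optionally flipped,
    # and normalized to [-1, 1], the range expected by the VAE encoder.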
332 | train_transforms = transforms.Compose( 333 | [ 334 | transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), 335 | transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), 336 | transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), 337 | transforms.ToTensor(), 338 | transforms.Normalize([0.5], [0.5]), 339 | ] 340 | ) 341 | 342 | def preprocess_train(examples): 343 | images = [image.convert("RGB") for image in examples[image_column]] 344 | examples["pixel_values"] = [train_transforms(image) for image in images] 345 | examples["input_ids"] = tokenize_captions(examples) 346 | return examples 347 | 348 | with accelerator.main_process_first(): 349 | if args.max_train_samples is not None: 350 | dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) 351 | # Set the training transforms 352 | train_dataset = dataset["train"].with_transform(preprocess_train) 353 | 354 | def collate_fn(examples): 355 | pixel_values = torch.stack([example["pixel_values"] for example in examples]) 356 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 357 | input_ids = torch.stack([example["input_ids"] for example in examples]) 358 | return {"pixel_values": pixel_values, "input_ids": input_ids} 359 | 360 | # DataLoaders creation: 361 | train_dataloader = torch.utils.data.DataLoader( 362 | train_dataset, 363 | shuffle=True, 364 | collate_fn=collate_fn, 365 | batch_size=args.train_batch_size, 366 | num_workers=args.dataloader_num_workers, 367 | ) 368 | 369 | # Scheduler and math around the number of training steps. 370 | overrode_max_train_steps = False 371 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 372 | if args.max_train_steps is None: 373 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 374 | overrode_max_train_steps = True 375 | 376 | lr_scheduler = get_scheduler( 377 | args.lr_scheduler, 378 | optimizer=optimizer, 379 | num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, 380 | num_training_steps=args.max_train_steps * accelerator.num_processes, 381 | ) 382 | 383 | # Prepare everything with our `accelerator`. 384 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 385 | unet, optimizer, train_dataloader, lr_scheduler 386 | ) 387 | 388 | if args.use_ema: 389 | ema_unet.to(accelerator.device) 390 | 391 | # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision 392 | # as these weights are only used for inference, keeping weights in full precision is not required. 393 | weight_dtype = torch.float32 394 | if accelerator.mixed_precision == "fp16": 395 | weight_dtype = torch.float16 396 | args.mixed_precision = accelerator.mixed_precision 397 | elif accelerator.mixed_precision == "bf16": 398 | weight_dtype = torch.bfloat16 399 | args.mixed_precision = accelerator.mixed_precision 400 | 401 | # Move text_encode and vae to gpu and cast to weight_dtype 402 | text_encoder.to(accelerator.device, dtype=weight_dtype) 403 | vae.to(accelerator.device, dtype=weight_dtype) 404 | 405 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
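    # `accelerator.prepare` shards the dataloader across processes, so the number of update
    # steps per epoch (and therefore the epoch count) is derived again from its new length.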
406 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 407 | if overrode_max_train_steps: 408 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 409 | # Afterwards we recalculate our number of training epochs 410 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 411 | 412 | # We need to initialize the trackers we use, and also store our configuration. 413 | # The trackers initializes automatically on the main process. 414 | if accelerator.is_main_process: 415 | tracker_config = dict(vars(args)) 416 | tracker_config.pop("validation_prompts") 417 | accelerator.init_trackers(args.tracker_project_name, tracker_config) 418 | 419 | # Function for unwrapping if model was compiled with `torch.compile`. 420 | def unwrap_model(model): 421 | model = accelerator.unwrap_model(model) 422 | model = model._orig_mod if is_compiled_module(model) else model 423 | return model 424 | 425 | # Train! 426 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 427 | 428 | logger.info("***** Running training *****") 429 | logger.info(f" Num examples = {len(train_dataset)}") 430 | logger.info(f" Num Epochs = {args.num_train_epochs}") 431 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 432 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 433 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 434 | logger.info(f" Total optimization steps = {args.max_train_steps}") 435 | global_step = 0 436 | first_epoch = 0 437 | 438 | # Potentially load in the weights and states from a previous save 439 | if args.resume_from_checkpoint: 440 | if args.resume_from_checkpoint != "latest": 441 | path = os.path.basename(args.resume_from_checkpoint) 442 | else: 443 | # Get the most recent checkpoint 444 | dirs = os.listdir(args.output_dir) 445 | dirs = [d for d in dirs if d.startswith("checkpoint")] 446 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 447 | path = dirs[-1] if len(dirs) > 0 else None 448 | 449 | if path is None: 450 | accelerator.print( 451 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 452 | ) 453 | args.resume_from_checkpoint = None 454 | initial_global_step = 0 455 | else: 456 | accelerator.print(f"Resuming from checkpoint {path}") 457 | accelerator.load_state(os.path.join(args.output_dir, path)) 458 | global_step = int(path.split("-")[1]) 459 | 460 | initial_global_step = global_step 461 | first_epoch = global_step // num_update_steps_per_epoch 462 | 463 | else: 464 | initial_global_step = 0 465 | 466 | progress_bar = tqdm( 467 | range(0, args.max_train_steps), 468 | initial=initial_global_step, 469 | desc="Steps", 470 | # Only show the progress bar once on each machine. 
471 | disable=not accelerator.is_local_main_process, 472 | ) 473 | 474 | for epoch in range(first_epoch, args.num_train_epochs): 475 | train_loss = 0.0 476 | accelerator.log({"epoch": epoch}, step=epoch) 477 | 478 | for step, batch in enumerate(train_dataloader): 479 | with accelerator.accumulate(unet): 480 | # Convert images to latent space 481 | latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample() 482 | latents = latents * vae.config.scaling_factor 483 | 484 | # Sample noise that we'll add to the latents 485 | noise = torch.randn_like(latents) 486 | if args.noise_offset: 487 | # https://www.crosslabs.org//blog/diffusion-with-offset-noise 488 | noise += args.noise_offset * torch.randn( 489 | (latents.shape[0], latents.shape[1], 1, 1), device=latents.device 490 | ) 491 | if args.input_perturbation: 492 | new_noise = noise + args.input_perturbation * torch.randn_like(noise) 493 | bsz = latents.shape[0] 494 | # Sample a random timestep for each image 495 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) 496 | timesteps = timesteps.long() 497 | 498 | # Add noise to the latents according to the noise magnitude at each timestep 499 | # (this is the forward diffusion process) 500 | if args.input_perturbation: 501 | noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps) 502 | else: 503 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 504 | 505 | # Get the text embedding for conditioning 506 | encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0] 507 | 508 | # Get the target for loss depending on the prediction type 509 | if args.prediction_type is not None: 510 | # set prediction_type of scheduler if defined 511 | noise_scheduler.register_to_config(prediction_type=args.prediction_type) 512 | 513 | if noise_scheduler.config.prediction_type == "epsilon": 514 | target = noise 515 | elif noise_scheduler.config.prediction_type == "v_prediction": 516 | target = noise_scheduler.get_velocity(latents, noise, timesteps) 517 | else: 518 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 519 | 520 | if args.dream_training: 521 | noisy_latents, target = compute_dream_and_update_latents( 522 | unet, 523 | noise_scheduler, 524 | timesteps, 525 | noise, 526 | noisy_latents, 527 | target, 528 | encoder_hidden_states, 529 | args.dream_detail_preservation, 530 | ) 531 | 532 | # Predict the noise residual and compute loss 533 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0] 534 | 535 | if args.snr_gamma is None: 536 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 537 | else: 538 | # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. 539 | # Since we predict the noise instead of x_0, the original formulation is slightly changed. 540 | # This is discussed in Section 4.2 of the same paper. 
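                    # Per-timestep weight: min(SNR, snr_gamma), divided by SNR for epsilon
                    # prediction or by SNR + 1 for v-prediction, applied to the per-sample MSE.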
541 | snr = compute_snr(noise_scheduler, timesteps) 542 | mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( 543 | dim=1 544 | )[0] 545 | if noise_scheduler.config.prediction_type == "epsilon": 546 | mse_loss_weights = mse_loss_weights / snr 547 | elif noise_scheduler.config.prediction_type == "v_prediction": 548 | mse_loss_weights = mse_loss_weights / (snr + 1) 549 | 550 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") 551 | loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights 552 | loss = loss.mean() 553 | 554 | # Gather the losses across all processes for logging (if we use distributed training). 555 | avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() 556 | train_loss += avg_loss.item() / args.gradient_accumulation_steps 557 | 558 | # Backpropagate 559 | accelerator.backward(loss) 560 | if accelerator.sync_gradients: 561 | accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) 562 | optimizer.step() 563 | lr_scheduler.step() 564 | optimizer.zero_grad() 565 | 566 | # Checks if the accelerator has performed an optimization step behind the scenes 567 | if accelerator.sync_gradients: 568 | if args.use_ema: 569 | ema_unet.step(unet.parameters()) 570 | progress_bar.update(1) 571 | global_step += 1 572 | accelerator.log({"train_loss": train_loss}, step=global_step) 573 | accelerator.log({"lr": lr_scheduler.get_last_lr()[0]}, step=global_step) 574 | train_loss = 0.0 575 | 576 | if global_step % args.checkpointing_steps == 0: 577 | if accelerator.is_main_process: 578 | # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` 579 | if args.checkpoints_total_limit is not None: 580 | checkpoints = os.listdir(args.output_dir) 581 | checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] 582 | checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) 583 | 584 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints 585 | if len(checkpoints) >= args.checkpoints_total_limit: 586 | num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 587 | removing_checkpoints = checkpoints[0:num_to_remove] 588 | 589 | logger.info( 590 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" 591 | ) 592 | logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") 593 | 594 | for removing_checkpoint in removing_checkpoints: 595 | removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) 596 | shutil.rmtree(removing_checkpoint) 597 | 598 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 599 | accelerator.save_state(save_path) 600 | logger.info(f"Saved state to {save_path}") 601 | 602 | logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 603 | 604 | progress_bar.set_postfix(**logs) 605 | 606 | if global_step >= args.max_train_steps: 607 | break 608 | 609 | if accelerator.is_main_process: 610 | if args.validation_prompts is not None and epoch % args.validation_epochs == 0: 611 | if args.use_ema: 612 | # Store the UNet parameters temporarily and load the EMA parameters to perform inference. 
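                    # store() keeps the current training weights, copy_to() swaps in the EMA
                    # weights for inference, and restore() below puts the training weights back.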
613 | ema_unet.store(unet.parameters()) 614 | ema_unet.copy_to(unet.parameters()) 615 | log_validation( 616 | vae, 617 | text_encoder, 618 | tokenizer, 619 | unet, 620 | args, 621 | accelerator, 622 | weight_dtype, 623 | global_step, 624 | ) 625 | if args.use_ema: 626 | # Switch back to the original UNet parameters. 627 | ema_unet.restore(unet.parameters()) 628 | 629 | # Create the pipeline using the trained modules and save it. 630 | accelerator.wait_for_everyone() 631 | if accelerator.is_main_process: 632 | unet = unwrap_model(unet) 633 | if args.use_ema: 634 | ema_unet.copy_to(unet.parameters()) 635 | 636 | pipeline = StableDiffusionPipeline.from_pretrained( 637 | args.pretrained_model_name_or_path, 638 | text_encoder=text_encoder, 639 | vae=vae, 640 | unet=unet, 641 | revision=args.revision, 642 | variant=args.variant, 643 | ) 644 | 645 | pipeline.save_pretrained(args.output_dir) 646 | 647 | # Run a final round of inference. 648 | images = [] 649 | if args.validation_prompts is not None: 650 | logger.info("Running inference for collecting generated images...") 651 | pipeline = pipeline.to(accelerator.device) 652 | pipeline.torch_dtype = weight_dtype 653 | pipeline.set_progress_bar_config(disable=True) 654 | 655 | if args.enable_xformers_memory_efficient_attention: 656 | pipeline.enable_xformers_memory_efficient_attention() 657 | 658 | if args.seed is None: 659 | generator = None 660 | else: 661 | generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) 662 | 663 | for i in range(len(args.validation_prompts)): 664 | with torch.autocast("cuda"): 665 | image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0] 666 | images.append(image) 667 | 668 | accelerator.end_training() 669 | 670 | 671 | if __name__ == "__main__": 672 | main() 673 | --------------------------------------------------------------------------------