├── LoRA ├── .gitignore ├── cfg │ ├── mountain.json │ └── mountain_up.json ├── dataset │ ├── mountain │ │ ├── metadata.jsonl │ │ └── mountain.jpg │ └── mountain_up │ │ ├── metadata.jsonl │ │ └── mountain_up.jpg ├── inversion_pipeline.py ├── morph.py ├── style_transfer.py ├── test_dataset.py ├── train_lora.py └── training_cfg.py ├── ReplaceAttn ├── .gitignore ├── replace_attn.py └── video_editing_pipeline.py └── TrainingScript ├── .gitignore ├── cfg_0.json ├── cfg_1.json ├── cfg_lora.json ├── ddpm_trainer.py ├── sd_lora_trainer.py ├── train_0.py ├── train_1.py ├── train_official.py ├── trainer.py ├── training_cfg_0.py ├── training_cfg_1.py └── unet_cfg └── config.json /LoRA/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | ckpt -------------------------------------------------------------------------------- /LoRA/cfg/mountain.json: -------------------------------------------------------------------------------- 1 | { 2 | "log_dir": "log", 3 | "output_dir": "ckpt", 4 | "data_dir": "dataset/mountain", 5 | "ckpt_name": "mountain", 6 | "gradient_accumulation_steps": 1, 7 | "pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5", 8 | "rank": 8, 9 | "enable_xformers_memory_efficient_attention": true, 10 | "learning_rate": 1e-4, 11 | "adam_beta1": 0.9, 12 | "adam_beta2": 0.999, 13 | "adam_weight_decay": 1e-2, 14 | "adam_epsilon": 1e-08, 15 | "resolution": 512, 16 | "n_epochs": 200, 17 | "checkpointing_steps": 500, 18 | "train_batch_size": 1, 19 | "dataloader_num_workers": 1, 20 | "lr_scheduler_name": "constant", 21 | "resume_from_checkpoint": false, 22 | "noise_offset": 0.1, 23 | "max_grad_norm": 1.0 24 | } -------------------------------------------------------------------------------- /LoRA/cfg/mountain_up.json: -------------------------------------------------------------------------------- 1 | { 2 | "log_dir": "log", 3 | "output_dir": "ckpt", 4 | "data_dir": "dataset/mountain_up", 5 | "ckpt_name": "mountain_up", 6 | "gradient_accumulation_steps": 1, 7 | "pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5", 8 | "rank": 8, 9 | "enable_xformers_memory_efficient_attention": true, 10 | "learning_rate": 1e-4, 11 | "adam_beta1": 0.9, 12 | "adam_beta2": 0.999, 13 | "adam_weight_decay": 1e-2, 14 | "adam_epsilon": 1e-08, 15 | "resolution": 512, 16 | "n_epochs": 200, 17 | "checkpointing_steps": 500, 18 | "train_batch_size": 1, 19 | "dataloader_num_workers": 1, 20 | "lr_scheduler_name": "constant", 21 | "resume_from_checkpoint": false, 22 | "noise_offset": 0.1, 23 | "max_grad_norm": 1.0 24 | } -------------------------------------------------------------------------------- /LoRA/dataset/mountain/metadata.jsonl: -------------------------------------------------------------------------------- 1 | {"file_name": "mountain.jpg", "text": "mountain"} -------------------------------------------------------------------------------- /LoRA/dataset/mountain/mountain.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SingleZombie/DiffusersExample/edd72750d63ed13d395b7be1e46f7fc134b01265/LoRA/dataset/mountain/mountain.jpg -------------------------------------------------------------------------------- /LoRA/dataset/mountain_up/metadata.jsonl: -------------------------------------------------------------------------------- 1 | {"file_name": "mountain_up.jpg", "text": "mountain"} -------------------------------------------------------------------------------- 
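Note: each folder under dataset/ follows the Hugging Face "imagefolder" convention: the images themselves plus a metadata.jsonl whose "file_name" and "text" fields supply the caption for each image. A minimal sanity check, mirroring LoRA/test_dataset.py further below:

    from datasets import load_dataset

    dataset = load_dataset("imagefolder", data_dir="dataset/mountain")
    print(dataset["train"].column_names)  # expected: ['image', 'text']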
/LoRA/dataset/mountain_up/mountain_up.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SingleZombie/DiffusersExample/edd72750d63ed13d395b7be1e46f7fc134b01265/LoRA/dataset/mountain_up/mountain_up.jpg -------------------------------------------------------------------------------- /LoRA/inversion_pipeline.py: -------------------------------------------------------------------------------- 1 | from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel 2 | from diffusers.schedulers import KarrasDiffusionSchedulers 3 | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker 4 | from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection 5 | import torch 6 | from PIL import Image 7 | import numpy as np 8 | from tqdm import tqdm 9 | from torchvision import transforms 10 | 11 | 12 | def get_img(img_path, resolution=512): 13 | img = Image.open(img_path).convert("RGB") 14 | norm_mean = [0.5, 0.5, 0.5] 15 | norm_std = [0.5, 0.5, 0.5] 16 | transform = transforms.Compose([ 17 | transforms.Resize((resolution, resolution)), 18 | transforms.ToTensor(), 19 | transforms.Normalize(norm_mean, norm_std) 20 | ]) 21 | img = transform(img) 22 | return img.unsqueeze(0) 23 | 24 | 25 | class InversionPipeline(StableDiffusionPipeline): 26 | def __init__( 27 | self, 28 | vae: AutoencoderKL, 29 | text_encoder: CLIPTextModel, 30 | tokenizer: CLIPTokenizer, 31 | unet: UNet2DConditionModel, 32 | scheduler: KarrasDiffusionSchedulers, 33 | safety_checker: StableDiffusionSafetyChecker, 34 | feature_extractor: CLIPImageProcessor, 35 | image_encoder: CLIPVisionModelWithProjection = None, 36 | requires_safety_checker: bool = True, 37 | ): 38 | super().__init__(vae, text_encoder, tokenizer, unet, scheduler, 39 | safety_checker, feature_extractor, image_encoder, requires_safety_checker) 40 | 41 | @torch.no_grad() 42 | def image2latent(self, image): 43 | DEVICE = torch.device( 44 | "cuda") if torch.cuda.is_available() else torch.device("cpu") 45 | if type(image) is Image: 46 | image = np.array(image) 47 | image = torch.from_numpy(image).float() / 127.5 - 1 48 | image = image.permute(2, 0, 1).unsqueeze(0) 49 | # input image density range [-1, 1] 50 | latents = self.vae.encode(image.to(DEVICE))['latent_dist'].mean 51 | latents = latents * 0.18215 52 | return latents 53 | 54 | @torch.no_grad() 55 | def get_text_embeddings(self, prompt, guidance_scale, neg_prompt, batch_size): 56 | DEVICE = torch.device( 57 | "cuda") if torch.cuda.is_available() else torch.device("cpu") 58 | # text embeddings 59 | text_input = self.tokenizer( 60 | prompt, 61 | padding="max_length", 62 | max_length=77, 63 | return_tensors="pt" 64 | ) 65 | text_embeddings = self.text_encoder(text_input.input_ids.cuda())[0] 66 | 67 | if guidance_scale > 1.: 68 | if neg_prompt: 69 | uc_text = neg_prompt 70 | else: 71 | uc_text = "" 72 | unconditional_input = self.tokenizer( 73 | [uc_text] * batch_size, 74 | padding="max_length", 75 | max_length=77, 76 | return_tensors="pt" 77 | ) 78 | unconditional_embeddings = self.text_encoder( 79 | unconditional_input.input_ids.to(DEVICE))[0] 80 | text_embeddings = torch.cat( 81 | [unconditional_embeddings, text_embeddings], dim=0) 82 | 83 | return text_embeddings 84 | 85 | @torch.no_grad() 86 | def ddim_inversion(self, latent, cond): 87 | timesteps = reversed(self.scheduler.timesteps) 88 | with torch.autocast(device_type='cuda', dtype=torch.float32): 89 | for 
i, t in enumerate(tqdm(timesteps, desc="DDIM inversion")): 90 | cond_batch = cond.repeat(latent.shape[0], 1, 1) 91 | 92 | alpha_prod_t = self.scheduler.alphas_cumprod[t] 93 | alpha_prod_t_prev = ( 94 | self.scheduler.alphas_cumprod[timesteps[i - 1]] 95 | if i > 0 else self.scheduler.final_alpha_cumprod 96 | ) 97 | 98 | mu = alpha_prod_t ** 0.5 99 | mu_prev = alpha_prod_t_prev ** 0.5 100 | sigma = (1 - alpha_prod_t) ** 0.5 101 | sigma_prev = (1 - alpha_prod_t_prev) ** 0.5 102 | eps = self.unet( 103 | latent, t, encoder_hidden_states=cond_batch).sample 104 | 105 | pred_x0 = (latent - sigma_prev * eps) / mu_prev 106 | latent = mu * pred_x0 + sigma * eps 107 | return latent 108 | 109 | @torch.no_grad() 110 | def inverse(self, img_path, prompt, n_steps, neg_prompt='', guidance_scale=7.5, batch_size=1): 111 | self.scheduler.set_timesteps(n_steps) 112 | text_embeddings = self.get_text_embeddings( 113 | prompt, guidance_scale, neg_prompt, batch_size) 114 | img = get_img(img_path) 115 | img_noise = self.ddim_inversion( 116 | self.image2latent(img), text_embeddings) 117 | return img_noise 118 | -------------------------------------------------------------------------------- /LoRA/morph.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from inversion_pipeline import InversionPipeline 3 | 4 | lora_path = 'ckpt/mountain.safetensor' 5 | lora_path2 = 'ckpt/mountain_up.safetensor' 6 | sd_path = 'runwayml/stable-diffusion-v1-5' 7 | 8 | 9 | @torch.no_grad() 10 | def slerp(p0, p1, fract_mixing: float): 11 | if p0.dtype == torch.float16: 12 | recast_to = 'fp16' 13 | else: 14 | recast_to = 'fp32' 15 | 16 | p0 = p0.double() 17 | p1 = p1.double() 18 | 19 | norm = torch.linalg.norm(p0) * torch.linalg.norm(p1) 20 | epsilon = 1e-7 21 | dot = torch.sum(p0 * p1) / norm 22 | dot = dot.clamp(-1+epsilon, 1-epsilon) 23 | 24 | theta_0 = torch.arccos(dot) 25 | sin_theta_0 = torch.sin(theta_0) 26 | theta_t = theta_0 * fract_mixing 27 | s0 = torch.sin(theta_0 - theta_t) / sin_theta_0 28 | s1 = torch.sin(theta_t) / sin_theta_0 29 | interp = p0*s0 + p1*s1 30 | 31 | if recast_to == 'fp16': 32 | interp = interp.half() 33 | elif recast_to == 'fp32': 34 | interp = interp.float() 35 | 36 | return interp 37 | 38 | 39 | pipeline: InversionPipeline = InversionPipeline.from_pretrained( 40 | sd_path).to("cuda") 41 | pipeline.load_lora_weights(lora_path, adapter_name='a') 42 | pipeline.load_lora_weights(lora_path2, adapter_name='b') 43 | 44 | img1_path = 'dataset/mountain/mountain.jpg' 45 | img2_path = 'dataset/mountain_up/mountain_up.jpg' 46 | prompt = 'mountain' 47 | latent1 = pipeline.inverse(img1_path, prompt, 50, guidance_scale=1) 48 | latent2 = pipeline.inverse(img2_path, prompt, 50, guidance_scale=1) 49 | n_frames = 10 50 | images = [] 51 | for i in range(n_frames + 1): 52 | alpha = i / n_frames 53 | pipeline.set_adapters(["a", "b"], adapter_weights=[1 - alpha, alpha]) 54 | latent = slerp(latent1, latent2, alpha) 55 | output = pipeline(prompt=prompt, latents=latent, 56 | guidance_scale=1.0).images[0] 57 | images.append(output) 58 | 59 | images[0].save("output.gif", save_all=True, 60 | append_images=images[1:], duration=100, loop=0) 61 | -------------------------------------------------------------------------------- /LoRA/style_transfer.py: -------------------------------------------------------------------------------- 1 | from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel 2 | from PIL import Image 3 | import cv2 4 | import numpy as np 5 | 6 | 
lora_path = '...' 7 | sd_path = 'runwayml/stable-diffusion-v1-5' 8 | controlnet_canny_path = 'lllyasviel/sd-controlnet-canny' 9 | 10 | prompt = '1 man, look at right, side face, Ace Attorney, Phoenix Wright, best quality, danganronpa' 11 | neg_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, {multiple people}' 12 | img_path = '...' 13 | init_image = Image.open(img_path).convert("RGB") 14 | init_image = init_image.resize((768, 512)) 15 | np_image = np.array(init_image) 16 | 17 | # get canny image 18 | np_image = cv2.Canny(np_image, 100, 200) 19 | np_image = np_image[:, :, None] 20 | np_image = np.concatenate([np_image, np_image, np_image], axis=2) 21 | canny_image = Image.fromarray(np_image) 22 | canny_image.save('tmp_edge.png') 23 | 24 | controlnet = ControlNetModel.from_pretrained(controlnet_canny_path) 25 | pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained( 26 | sd_path, controlnet=controlnet 27 | ) 28 | pipe.load_lora_weights(lora_path) 29 | 30 | output = pipe( 31 | prompt=prompt, 32 | negative_prompt=neg_prompt, 33 | strength=0.5, 34 | guidance_scale=7.5, 35 | controlnet_conditioning_scale=0.5, 36 | num_inference_steps=50, 37 | image=init_image, 38 | cross_attention_kwargs={"scale": 1.0}, 39 | control_image=canny_image, 40 | ).images[0] 41 | output.save("tmp.png") 42 | -------------------------------------------------------------------------------- /LoRA/test_dataset.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | 3 | dataset = load_dataset("imagefolder", data_dir="dataset/mountain") 4 | print(dataset) 5 | print(dataset["train"].column_names) 6 | print(dataset["train"]['image']) 7 | print(dataset["train"]['text']) -------------------------------------------------------------------------------- /LoRA/train_lora.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
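# Usage sketch (assumption: run from the LoRA/ directory with one of the configs under cfg/):
#   accelerate launch train_lora.py cfg/mountain.json
# A plain `python train_lora.py cfg/mountain.json` also works for a single-process run.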
15 | """Fine-tuning script for Stable Diffusion for text2image with support for LoRA.""" 16 | 17 | import argparse 18 | import logging 19 | import math 20 | import os 21 | import random 22 | import shutil 23 | from pathlib import Path 24 | 25 | import datasets 26 | import numpy as np 27 | import torch 28 | import torch.nn.functional as F 29 | import torch.utils.checkpoint 30 | import transformers 31 | from accelerate import Accelerator 32 | from accelerate.logging import get_logger 33 | from accelerate.utils import ProjectConfiguration, set_seed 34 | from datasets import load_dataset 35 | from huggingface_hub import upload_folder 36 | from packaging import version 37 | from peft import LoraConfig 38 | from peft.utils import get_peft_model_state_dict 39 | from torchvision import transforms 40 | from tqdm.auto import tqdm 41 | from transformers import CLIPTextModel, CLIPTokenizer 42 | 43 | import diffusers 44 | from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel 45 | from diffusers.optimization import get_scheduler 46 | from diffusers.training_utils import compute_snr 47 | from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available 48 | from diffusers.utils.import_utils import is_xformers_available 49 | from training_cfg import load_training_config 50 | 51 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 52 | check_min_version("0.26.0.dev0") 53 | 54 | logger = get_logger(__name__, log_level="INFO") 55 | 56 | 57 | def main(): 58 | parser = argparse.ArgumentParser() 59 | parser.add_argument('cfg', type=str) 60 | args = parser.parse_args() 61 | 62 | cfg_path = args.cfg 63 | 64 | cfg = load_training_config(cfg_path) 65 | logging_dir = Path(cfg.log_dir) 66 | 67 | accelerator_project_config = ProjectConfiguration( 68 | project_dir=cfg.output_dir, logging_dir=logging_dir) 69 | 70 | accelerator = Accelerator( 71 | gradient_accumulation_steps=cfg.gradient_accumulation_steps, 72 | mixed_precision=cfg.mixed_precision, 73 | log_with="tensorboard", 74 | project_config=accelerator_project_config, 75 | ) 76 | 77 | # Make one log on every process with the configuration for debugging. 78 | logging.basicConfig( 79 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 80 | datefmt="%m/%d/%Y %H:%M:%S", 81 | level=logging.INFO, 82 | ) 83 | logger.info(accelerator.state, main_process_only=False) 84 | if accelerator.is_local_main_process: 85 | datasets.utils.logging.set_verbosity_warning() 86 | transformers.utils.logging.set_verbosity_warning() 87 | diffusers.utils.logging.set_verbosity_info() 88 | else: 89 | datasets.utils.logging.set_verbosity_error() 90 | transformers.utils.logging.set_verbosity_error() 91 | diffusers.utils.logging.set_verbosity_error() 92 | 93 | # If passed along, set the training seed now. 94 | if cfg.seed is not None: 95 | set_seed(cfg.seed) 96 | 97 | # Handle the repository creation 98 | if accelerator.is_main_process: 99 | if cfg.output_dir is not None: 100 | os.makedirs(cfg.output_dir, exist_ok=True) 101 | # Load scheduler, tokenizer and models. 
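# (all of the following components are loaded from subfolders of cfg.pretrained_model_name_or_path:
#  scheduler, tokenizer, text_encoder, vae and unet; only the UNet receives LoRA adapters later on)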
102 | noise_scheduler = DDPMScheduler.from_pretrained( 103 | cfg.pretrained_model_name_or_path, subfolder="scheduler") 104 | tokenizer = CLIPTokenizer.from_pretrained( 105 | cfg.pretrained_model_name_or_path, subfolder="tokenizer" 106 | ) 107 | text_encoder = CLIPTextModel.from_pretrained( 108 | cfg.pretrained_model_name_or_path, subfolder="text_encoder" 109 | ) 110 | vae = AutoencoderKL.from_pretrained( 111 | cfg.pretrained_model_name_or_path, subfolder="vae", 112 | ) 113 | unet = UNet2DConditionModel.from_pretrained( 114 | cfg.pretrained_model_name_or_path, subfolder="unet" 115 | ) 116 | # freeze parameters of models to save more memory 117 | unet.requires_grad_(False) 118 | vae.requires_grad_(False) 119 | text_encoder.requires_grad_(False) 120 | 121 | # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision 122 | # as these weights are only used for inference, keeping weights in full precision is not required. 123 | weight_dtype = torch.float32 124 | if accelerator.mixed_precision == "fp16": 125 | weight_dtype = torch.float16 126 | elif accelerator.mixed_precision == "bf16": 127 | weight_dtype = torch.bfloat16 128 | 129 | # Freeze the unet parameters before adding adapters 130 | for param in unet.parameters(): 131 | param.requires_grad_(False) 132 | 133 | unet_lora_config = LoraConfig( 134 | r=cfg.rank, 135 | lora_alpha=cfg.rank, 136 | init_lora_weights="gaussian", 137 | target_modules=["to_k", "to_q", "to_v", "to_out.0"], 138 | ) 139 | 140 | # Move unet, vae and text_encoder to device and cast to weight_dtype 141 | unet.to(accelerator.device, dtype=weight_dtype) 142 | vae.to(accelerator.device, dtype=weight_dtype) 143 | text_encoder.to(accelerator.device, dtype=weight_dtype) 144 | 145 | # Add adapter and make sure the trainable params are in float32. 146 | unet.add_adapter(unet_lora_config) 147 | if cfg.mixed_precision == "fp16": 148 | for param in unet.parameters(): 149 | # only upcast trainable parameters (LoRA) into fp32 150 | if param.requires_grad: 151 | param.data = param.to(torch.float32) 152 | 153 | if cfg.enable_xformers_memory_efficient_attention: 154 | if is_xformers_available(): 155 | import xformers 156 | 157 | xformers_version = version.parse(xformers.__version__) 158 | if xformers_version == version.parse("0.0.16"): 159 | logger.warn( 160 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 161 | ) 162 | unet.enable_xformers_memory_efficient_attention() 163 | else: 164 | raise ValueError( 165 | "xformers is not available. Make sure it is installed correctly") 166 | 167 | lora_layers = filter(lambda p: p.requires_grad, unet.parameters()) 168 | optimizer_cls = torch.optim.AdamW 169 | 170 | optimizer = optimizer_cls( 171 | lora_layers, 172 | lr=cfg.learning_rate, 173 | betas=(cfg.adam_beta1, cfg.adam_beta2), 174 | weight_decay=cfg.adam_weight_decay, 175 | eps=cfg.adam_epsilon, 176 | ) 177 | 178 | # Get the datasets: you can either provide your own training and evaluation files (see below) 179 | # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). 180 | 181 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently 182 | # download the dataset. 
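# (with the imagefolder layout used here the dataset exposes an "image" column and a "text"
#  column taken from metadata.jsonl, which is what column_names[0] / column_names[1] pick up below)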
183 | dataset = load_dataset( 184 | "imagefolder", 185 | data_dir=cfg.data_dir 186 | ) 187 | # See more about loading custom images at 188 | # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder 189 | 190 | # Preprocessing the datasets. 191 | # We need to tokenize inputs and targets. 192 | column_names = dataset["train"].column_names 193 | image_column = column_names[0] 194 | caption_column = column_names[1] 195 | 196 | # Preprocessing the datasets. 197 | # We need to tokenize input captions and transform the images. 198 | def tokenize_captions(examples, is_train=True): 199 | captions = [] 200 | for caption in examples[caption_column]: 201 | if isinstance(caption, str): 202 | captions.append(caption) 203 | elif isinstance(caption, (list, np.ndarray)): 204 | # take a random caption if there are multiple 205 | captions.append(random.choice(caption) 206 | if is_train else caption[0]) 207 | else: 208 | raise ValueError( 209 | f"Caption column `{caption_column}` should contain either strings or lists of strings." 210 | ) 211 | inputs = tokenizer( 212 | captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" 213 | ) 214 | return inputs.input_ids 215 | 216 | # Preprocessing the datasets. 217 | train_transforms = transforms.Compose( 218 | [ 219 | transforms.Resize( 220 | cfg.resolution, interpolation=transforms.InterpolationMode.BILINEAR), 221 | transforms.RandomCrop(cfg.resolution), 222 | transforms.ToTensor(), 223 | transforms.Normalize([0.5], [0.5]), 224 | ] 225 | ) 226 | 227 | def preprocess_train(examples): 228 | images = [image.convert("RGB") for image in examples[image_column]] 229 | examples["pixel_values"] = [ 230 | train_transforms(image) for image in images] 231 | examples["input_ids"] = tokenize_captions(examples) 232 | return examples 233 | 234 | with accelerator.main_process_first(): 235 | # Set the training transforms 236 | train_dataset = dataset["train"].with_transform(preprocess_train) 237 | 238 | def collate_fn(examples): 239 | pixel_values = torch.stack([example["pixel_values"] 240 | for example in examples]) 241 | pixel_values = pixel_values.to( 242 | memory_format=torch.contiguous_format).float() 243 | input_ids = torch.stack([example["input_ids"] for example in examples]) 244 | return {"pixel_values": pixel_values, "input_ids": input_ids} 245 | 246 | # DataLoaders creation: 247 | train_dataloader = torch.utils.data.DataLoader( 248 | train_dataset, 249 | shuffle=True, 250 | collate_fn=collate_fn, 251 | batch_size=cfg.train_batch_size, 252 | num_workers=cfg.dataloader_num_workers, 253 | ) 254 | 255 | # Scheduler and math around the number of training steps. 256 | num_update_steps_per_epoch = math.ceil( 257 | len(train_dataloader) / cfg.gradient_accumulation_steps) 258 | max_train_steps = cfg.n_epochs * num_update_steps_per_epoch 259 | 260 | lr_scheduler = get_scheduler( 261 | cfg.lr_scheduler_name, 262 | optimizer=optimizer 263 | ) 264 | 265 | # Prepare everything with our `accelerator`. 266 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 267 | unet, optimizer, train_dataloader, lr_scheduler 268 | ) 269 | 270 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
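# (accelerator.prepare() wraps the dataloader and, in multi-process runs, shards it across
#  processes, so len(train_dataloader) may differ from its value before the prepare() call)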
271 | num_update_steps_per_epoch = math.ceil( 272 | len(train_dataloader) / cfg.gradient_accumulation_steps) 273 | # Afterwards we recalculate our number of training epochs 274 | num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch) 275 | 276 | # We need to initialize the trackers we use, and also store our configuration. 277 | # The trackers initializes automatically on the main process. 278 | # if accelerator.is_main_process: 279 | # accelerator.init_trackers("text2image-fine-tune", config=vars(args)) 280 | 281 | # Train! 282 | total_batch_size = cfg.train_batch_size * \ 283 | accelerator.num_processes * cfg.gradient_accumulation_steps 284 | 285 | logger.info("***** Running training *****") 286 | logger.info(f" Num examples = {len(train_dataset)}") 287 | logger.info(f" Num Epochs = {num_train_epochs}") 288 | logger.info( 289 | f" Instantaneous batch size per device = {cfg.train_batch_size}") 290 | logger.info( 291 | f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 292 | logger.info( 293 | f" Gradient Accumulation steps = {cfg.gradient_accumulation_steps}") 294 | logger.info(f" Total optimization steps = {max_train_steps}") 295 | global_step = 0 296 | first_epoch = 0 297 | 298 | # Potentially load in the weights and states from a previous save 299 | if cfg.resume_from_checkpoint: 300 | if cfg.resume_from_checkpoint != "latest": 301 | path = os.path.basename(cfg.resume_from_checkpoint) 302 | else: 303 | # Get the most recent checkpoint 304 | dirs = os.listdir(cfg.output_dir) 305 | dirs = [d for d in dirs if d.startswith("checkpoint")] 306 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 307 | path = dirs[-1] if len(dirs) > 0 else None 308 | 309 | if path is None: 310 | accelerator.print( 311 | f"Checkpoint '{cfg.resume_from_checkpoint}' does not exist. Starting a new training run." 312 | ) 313 | initial_global_step = 0 314 | else: 315 | accelerator.print(f"Resuming from checkpoint {path}") 316 | accelerator.load_state(os.path.join(cfg.output_dir, path)) 317 | global_step = int(path.split("-")[1]) 318 | 319 | initial_global_step = global_step 320 | first_epoch = global_step // num_update_steps_per_epoch 321 | else: 322 | initial_global_step = 0 323 | 324 | progress_bar = tqdm( 325 | range(0, max_train_steps), 326 | initial=initial_global_step, 327 | desc="Steps", 328 | # Only show the progress bar once on each machine. 
329 | disable=not accelerator.is_local_main_process, 330 | ) 331 | 332 | for epoch in range(first_epoch, num_train_epochs): 333 | unet.train() 334 | train_loss = 0.0 335 | for step, batch in enumerate(train_dataloader): 336 | with accelerator.accumulate(unet): 337 | # Convert images to latent space 338 | latents = vae.encode(batch["pixel_values"].to( 339 | dtype=weight_dtype)).latent_dist.sample() 340 | latents = latents * vae.config.scaling_factor 341 | 342 | # Sample noise that we'll add to the latents 343 | noise = torch.randn_like(latents) 344 | if cfg.noise_offset: 345 | # https://www.crosslabs.org//blog/diffusion-with-offset-noise 346 | noise += cfg.noise_offset * torch.randn( 347 | (latents.shape[0], latents.shape[1], 1, 1), device=latents.device 348 | ) 349 | 350 | bsz = latents.shape[0] 351 | # Sample a random timestep for each image 352 | timesteps = torch.randint( 353 | 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) 354 | timesteps = timesteps.long() 355 | 356 | # Add noise to the latents according to the noise magnitude at each timestep 357 | # (this is the forward diffusion process) 358 | noisy_latents = noise_scheduler.add_noise( 359 | latents, noise, timesteps) 360 | 361 | # Get the text embedding for conditioning 362 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] 363 | 364 | if noise_scheduler.config.prediction_type == "epsilon": 365 | target = noise 366 | elif noise_scheduler.config.prediction_type == "v_prediction": 367 | target = noise_scheduler.get_velocity( 368 | latents, noise, timesteps) 369 | else: 370 | raise ValueError( 371 | f"Unknown prediction type {noise_scheduler.config.prediction_type}") 372 | 373 | # Predict the noise residual and compute loss 374 | model_pred = unet(noisy_latents, timesteps, 375 | encoder_hidden_states).sample 376 | 377 | loss = F.mse_loss(model_pred.float(), 378 | target.float(), reduction="mean") 379 | 380 | # Gather the losses across all processes for logging (if we use distributed training). 
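# (loss is a scalar, so it is repeated train_batch_size times to give one value per sample;
#  gather() collects these from every process and their mean is accumulated into train_loss)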
381 | avg_loss = accelerator.gather( 382 | loss.repeat(cfg.train_batch_size)).mean() 383 | train_loss += avg_loss.item() / cfg.gradient_accumulation_steps 384 | 385 | # Backpropagate 386 | accelerator.backward(loss) 387 | if accelerator.sync_gradients: 388 | params_to_clip = lora_layers 389 | accelerator.clip_grad_norm_( 390 | params_to_clip, cfg.max_grad_norm) 391 | optimizer.step() 392 | lr_scheduler.step() 393 | optimizer.zero_grad() 394 | 395 | # Checks if the accelerator has performed an optimization step behind the scenes 396 | if accelerator.sync_gradients: 397 | progress_bar.update(1) 398 | global_step += 1 399 | accelerator.log({"train_loss": train_loss}, step=global_step) 400 | train_loss = 0.0 401 | 402 | if global_step % cfg.checkpointing_steps == 0: 403 | if accelerator.is_main_process: 404 | save_path = os.path.join( 405 | cfg.output_dir, f"checkpoint-{global_step}") 406 | accelerator.save_state(save_path) 407 | 408 | unwrapped_unet = accelerator.unwrap_model(unet) 409 | unet_lora_state_dict = convert_state_dict_to_diffusers( 410 | get_peft_model_state_dict(unwrapped_unet) 411 | ) 412 | 413 | StableDiffusionPipeline.save_lora_weights( 414 | save_directory=save_path, 415 | unet_lora_layers=unet_lora_state_dict, 416 | safe_serialization=True, 417 | ) 418 | 419 | logger.info(f"Saved state to {save_path}") 420 | 421 | logs = {"step_loss": loss.detach().item( 422 | ), "lr": lr_scheduler.get_last_lr()[0]} 423 | progress_bar.set_postfix(**logs) 424 | 425 | if global_step >= max_train_steps: 426 | break 427 | 428 | # Save the lora layers 429 | accelerator.wait_for_everyone() 430 | if accelerator.is_main_process: 431 | unet = unet.to(torch.float32) 432 | 433 | unwrapped_unet = accelerator.unwrap_model(unet) 434 | unet_lora_state_dict = convert_state_dict_to_diffusers( 435 | get_peft_model_state_dict(unwrapped_unet)) 436 | StableDiffusionPipeline.save_lora_weights( 437 | save_directory=cfg.output_dir, 438 | unet_lora_layers=unet_lora_state_dict, 439 | safe_serialization=True, 440 | weight_name=cfg.ckpt_name + '.safetensor' 441 | ) 442 | 443 | accelerator.end_training() 444 | 445 | 446 | if __name__ == "__main__": 447 | main() 448 | -------------------------------------------------------------------------------- /LoRA/training_cfg.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from omegaconf import OmegaConf 3 | 4 | 5 | @dataclass 6 | class TrainingConfig: 7 | log_dir: str 8 | output_dir: str 9 | data_dir: str 10 | ckpt_name: str 11 | rank: int 12 | gradient_accumulation_steps: int = 1 13 | mixed_precision: str = None 14 | seed: int = None 15 | pretrained_model_name_or_path: str = 'ckpt/v1-5' 16 | enable_xformers_memory_efficient_attention: bool = True 17 | 18 | # AdamW 19 | learning_rate: float = 1e-4 20 | adam_beta1: float = 0.9 21 | adam_beta2: float = 0.999 22 | adam_weight_decay: float = 1e-2 23 | adam_epsilon: float = 1e-08 24 | 25 | resolution: int = 512 26 | n_epochs: int = 200 27 | checkpointing_steps: int = 500 28 | train_batch_size: int = 1 29 | dataloader_num_workers: int = 1 30 | 31 | lr_scheduler_name: str = 'constant' 32 | 33 | resume_from_checkpoint: bool = False 34 | noise_offset: float = 0.1 35 | max_grad_norm: float = 1.0 36 | 37 | 38 | def load_training_config(config_path: str) -> TrainingConfig: 39 | data_dict = OmegaConf.load(config_path) 40 | return TrainingConfig(**data_dict) 41 | 42 | 43 | if __name__ == '__main__': 44 | config = load_training_config('config/train/mountain.json') 45 
| print(config) 46 | -------------------------------------------------------------------------------- /ReplaceAttn/.gitignore: -------------------------------------------------------------------------------- 1 | *.mp4 2 | *.gif 3 | __pycache__ -------------------------------------------------------------------------------- /ReplaceAttn/replace_attn.py: -------------------------------------------------------------------------------- 1 | from video_editing_pipeline import VideoEditingPipeline 2 | import cv2 3 | from PIL import Image 4 | import numpy as np 5 | from diffusers import ControlNetModel 6 | import torch 7 | 8 | 9 | def video_to_frame(video_path: str, interval: int): 10 | vidcap = cv2.VideoCapture(video_path) 11 | success = True 12 | 13 | count = 0 14 | res = [] 15 | while success: 16 | count += 1 17 | success, image = vidcap.read() 18 | if count % interval != 1: 19 | continue 20 | if image is not None: 21 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 22 | image = cv2.resize(image[:, 100:800], (512, 512)) 23 | res.append(image) 24 | 25 | vidcap.release() 26 | return res 27 | 28 | 29 | input_video_path = 'woman.mp4' 30 | input_interval = 10 31 | frames = video_to_frame( 32 | input_video_path, input_interval) 33 | frames = frames[:10] 34 | 35 | control_frames = [] 36 | # get canny image 37 | for frame in frames: 38 | np_image = cv2.Canny(frame, 50, 100) 39 | np_image = np_image[:, :, None] 40 | np_image = np.concatenate([np_image, np_image, np_image], axis=2) 41 | canny_image = Image.fromarray(np_image) 42 | control_frames.append(canny_image) 43 | 44 | controlnet = ControlNetModel.from_pretrained( 45 | "lllyasviel/sd-controlnet-canny").to('cuda') 46 | 47 | pipeline = VideoEditingPipeline.from_pretrained( 48 | 'runwayml/stable-diffusion-v1-5', controlnet=controlnet).to('cuda') 49 | pipeline.safety_checker = None 50 | 51 | generator = torch.manual_seed(0) 52 | frames = [Image.fromarray(frame) for frame in frames] 53 | 54 | output_frames = pipeline(images=frames, 55 | control_images=control_frames, 56 | prompt='a beautiful woman with red hair', 57 | num_inference_steps=20, 58 | controlnet_conditioning_scale=0.7, 59 | strength=0.9, 60 | generator=generator, 61 | negative_prompt='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality') 62 | 63 | output_frames[0].save("output.gif", save_all=True, 64 | append_images=output_frames[1:], duration=100, loop=0) 65 | -------------------------------------------------------------------------------- /ReplaceAttn/video_editing_pipeline.py: -------------------------------------------------------------------------------- 1 | from diffusers import StableDiffusionControlNetImg2ImgPipeline, AutoencoderKL, UNet2DConditionModel 2 | from diffusers.schedulers import KarrasDiffusionSchedulers 3 | from diffusers.models.attention_processor import Attention, AttnProcessor 4 | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker 5 | from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection 6 | import torch 7 | 8 | 9 | class AttnState: 10 | STORE = 0 11 | LOAD = 1 12 | 13 | def __init__(self): 14 | self.reset() 15 | 16 | @property 17 | def state(self): 18 | return self.__state 19 | 20 | @property 21 | def timestep(self): 22 | return self.__timestep 23 | 24 | def set_timestep(self, t): 25 | self.__timestep = t 26 | 27 | def reset(self): 28 | self.__state = AttnState.STORE 29 | self.__timestep = 0 30 | 31 | def 
to_load(self): 32 | self.__state = AttnState.LOAD 33 | 34 | 35 | class CrossFrameAttnProcessor(AttnProcessor): 36 | """ 37 | Cross frame attention processor. Each frame attends the first frame and previous frame. 38 | 39 | Args: 40 | attn_state: Whether the model is processing the first frame or an intermediate frame 41 | """ 42 | 43 | def __init__(self, attn_state: AttnState): 44 | super().__init__() 45 | self.attn_state = attn_state 46 | self.cur_timestep = 0 47 | self.first_maps = {} 48 | self.prev_maps = {} 49 | 50 | def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, **kwargs): 51 | 52 | if encoder_hidden_states is None: 53 | # Is self attention 54 | 55 | tot_timestep = self.attn_state.timestep 56 | if self.attn_state.state == AttnState.STORE: 57 | self.first_maps[self.cur_timestep] = hidden_states.detach() 58 | self.prev_maps[self.cur_timestep] = hidden_states.detach() 59 | res = super().__call__(attn, hidden_states, encoder_hidden_states, **kwargs) 60 | else: 61 | tmp = hidden_states.detach() 62 | cross_map = torch.cat( 63 | (self.first_maps[self.cur_timestep], self.prev_maps[self.cur_timestep]), dim=1) 64 | res = super().__call__(attn, hidden_states, cross_map, **kwargs) 65 | self.prev_maps[self.cur_timestep] = tmp 66 | 67 | self.cur_timestep += 1 68 | if self.cur_timestep == tot_timestep: 69 | self.cur_timestep = 0 70 | else: 71 | # Is cross attention 72 | res = super().__call__(attn, hidden_states, encoder_hidden_states, **kwargs) 73 | 74 | return res 75 | 76 | 77 | class VideoEditingPipeline(StableDiffusionControlNetImg2ImgPipeline): 78 | def __init__( 79 | self, 80 | vae: AutoencoderKL, 81 | text_encoder: CLIPTextModel, 82 | tokenizer: CLIPTokenizer, 83 | unet: UNet2DConditionModel, 84 | controlnet, 85 | scheduler: KarrasDiffusionSchedulers, 86 | safety_checker: StableDiffusionSafetyChecker, 87 | feature_extractor: CLIPImageProcessor, 88 | image_encoder: CLIPVisionModelWithProjection = None, 89 | requires_safety_checker: bool = True, 90 | ): 91 | super().__init__(vae, text_encoder, tokenizer, unet, controlnet, scheduler, 92 | safety_checker, feature_extractor, image_encoder, requires_safety_checker) 93 | self.attn_state = AttnState() 94 | attn_processor_dict = {} 95 | for k in unet.attn_processors.keys(): 96 | if k.startswith("up"): 97 | attn_processor_dict[k] = CrossFrameAttnProcessor( 98 | self.attn_state) 99 | else: 100 | attn_processor_dict[k] = AttnProcessor() 101 | 102 | self.unet.set_attn_processor(attn_processor_dict) 103 | 104 | def __call__(self, *args, images=None, control_images=None, **kwargs): 105 | self.attn_state.reset() 106 | self.attn_state.set_timestep( 107 | int(kwargs['num_inference_steps'] * kwargs['strength'])) 108 | outputs = [super().__call__( 109 | *args, **kwargs, image=images[0], control_image=control_images[0]).images[0]] 110 | self.attn_state.to_load() 111 | for i in range(1, len(images)): 112 | image = images[i] 113 | control_image = control_images[i] 114 | outputs.append(super().__call__( 115 | *args, **kwargs, image=image, control_image=control_image).images[0]) 116 | return outputs 117 | -------------------------------------------------------------------------------- /TrainingScript/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | models 3 | -------------------------------------------------------------------------------- /TrainingScript/cfg_0.json: -------------------------------------------------------------------------------- 1 | { 2 | "logging_dir": 
"logs", 3 | "output_dir": "models/ddpm_0", 4 | 5 | "model_config": "unet_cfg", 6 | "num_epochs": 10, 7 | "train_batch_size": 64, 8 | "checkpointing_steps": 5000, 9 | "valid_epochs": 1, 10 | "valid_batch_size": 4, 11 | "dataset_name": "ylecun/mnist", 12 | "resolution": 32, 13 | "learning_rate": 1e-4 14 | } -------------------------------------------------------------------------------- /TrainingScript/cfg_1.json: -------------------------------------------------------------------------------- 1 | { 2 | "base": { 3 | "logging_dir": "logs", 4 | "output_dir": "models/ddpm_1", 5 | "checkpointing_steps": 5000, 6 | "valid_epochs": 1, 7 | "dataset_name": "ylecun/mnist", 8 | "resolution": 32, 9 | "train_batch_size": 64, 10 | "num_epochs": 10 11 | }, 12 | "ddpm": { 13 | "model_config": "unet_cfg", 14 | "learning_rate": 1e-4, 15 | "valid_batch_size": 4 16 | } 17 | } -------------------------------------------------------------------------------- /TrainingScript/cfg_lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "base": { 3 | "logging_dir": "logs", 4 | "output_dir": "models/lora", 5 | "checkpointing_steps": 5000, 6 | "valid_epochs": 1, 7 | "dataset_name": "ylecun/mnist", 8 | "resolution": 128, 9 | "train_batch_size": 16, 10 | "num_epochs": 2 11 | }, 12 | "lora": { 13 | "pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5", 14 | "learning_rate": 1e-4, 15 | "valid_batch_size": 4 16 | } 17 | } -------------------------------------------------------------------------------- /TrainingScript/ddpm_trainer.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import os 3 | 4 | from dataclasses import dataclass 5 | from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel 6 | from diffusers.optimization import get_scheduler 7 | from diffusers.training_utils import EMAModel 8 | from diffusers.utils import is_accelerate_version 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | from trainer import Trainer 13 | 14 | 15 | @dataclass 16 | class DDPMTrainingConfig: 17 | # Diffuion Models 18 | model_config: str 19 | ddpm_num_steps: int = 1000 20 | ddpm_beta_schedule: str = 'linear' 21 | prediction_type: str = 'epsilon' 22 | ddpm_num_inference_steps: int = 100 23 | 24 | # Validation 25 | valid_batch_size: int = 1 26 | 27 | # EMA 28 | use_ema: bool = False 29 | ema_max_decay: float = 0.9999 30 | ema_inv_gamma: float = 1.0 31 | ema_power: float = 3 / 4 32 | 33 | # AdamW 34 | scale_lr = False 35 | learning_rate: float = 1e-4 36 | adam_beta1: float = 0.9 37 | adam_beta2: float = 0.999 38 | adam_weight_decay: float = 1e-2 39 | adam_epsilon: float = 1e-08 40 | 41 | # LR Scheduler 42 | lr_scheduler: str = 'constant' 43 | lr_warmup_steps: int = 500 44 | 45 | 46 | def _extract_into_tensor(arr, timesteps, broadcast_shape): 47 | """ 48 | Extract values from a 1-D numpy array for a batch of indices. 49 | 50 | :param arr: the 1-D numpy array. 51 | :param timesteps: a tensor of indices into the array to extract. 52 | :param broadcast_shape: a larger shape of K dimensions with the batch 53 | dimension equal to the length of timesteps. 54 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 
55 | """ 56 | if not isinstance(arr, torch.Tensor): 57 | arr = torch.from_numpy(arr) 58 | res = arr[timesteps].float().to(timesteps.device) 59 | while len(res.shape) < len(broadcast_shape): 60 | res = res[..., None] 61 | return res.expand(broadcast_shape) 62 | 63 | 64 | class DDPMTrainer(Trainer): 65 | def __init__(self, weight_dtype, accelerator, logger, cfg: DDPMTrainingConfig): 66 | super().__init__(weight_dtype, accelerator, logger, cfg) 67 | 68 | def init_modules(self, 69 | enable_xformer=False, 70 | gradient_checkpointing=False): 71 | if self.cfg.model_config is None: 72 | self.model = UNet2DModel( 73 | in_channels=3, 74 | out_channels=3, 75 | layers_per_block=2, 76 | block_out_channels=(128, 128, 256, 256, 512, 512), 77 | down_block_types=( 78 | "DownBlock2D", 79 | "DownBlock2D", 80 | "DownBlock2D", 81 | "DownBlock2D", 82 | "AttnDownBlock2D", 83 | "DownBlock2D", 84 | ), 85 | up_block_types=( 86 | "UpBlock2D", 87 | "AttnUpBlock2D", 88 | "UpBlock2D", 89 | "UpBlock2D", 90 | "UpBlock2D", 91 | "UpBlock2D", 92 | ), 93 | ) 94 | else: 95 | config = UNet2DModel.load_config(self.cfg.model_config) 96 | self.model = UNet2DModel.from_config(config) 97 | 98 | # Create EMA for the model. 99 | if self.cfg.use_ema: 100 | self.ema_model = EMAModel( 101 | self.model.parameters(), 102 | decay=self.cfg.ema_max_decay, 103 | use_ema_warmup=True, 104 | inv_gamma=self.cfg.ema_inv_gamma, 105 | power=self.cfg.ema_power, 106 | model_cls=UNet2DModel, 107 | model_config=self.model.config, 108 | ) 109 | 110 | if enable_xformer: 111 | self.model.enable_xformers_memory_efficient_attention() 112 | 113 | accepts_prediction_type = "prediction_type" in set( 114 | inspect.signature(DDPMScheduler.__init__).parameters.keys()) 115 | if accepts_prediction_type: 116 | self.noise_scheduler = DDPMScheduler( 117 | num_train_timesteps=self.cfg.ddpm_num_steps, 118 | beta_schedule=self.cfg.ddpm_beta_schedule, 119 | prediction_type=self.cfg.prediction_type, 120 | ) 121 | else: 122 | self.noise_scheduler = DDPMScheduler( 123 | num_train_timesteps=self.cfg.ddpm_num_steps, 124 | beta_schedule=self.cfg.ddpm_beta_schedule) 125 | 126 | if gradient_checkpointing: 127 | self.model.enable_gradient_checkpointing() 128 | 129 | def init_optimizers(self, train_batch_size): 130 | if self.cfg.scale_lr: 131 | self.cfg.learning_rate = ( 132 | self.cfg.learning_rate * self.cfg.gradient_accumulation_steps * 133 | train_batch_size * self.accelerator.num_processes 134 | ) 135 | self.optimizer = torch.optim.AdamW( 136 | self.model.parameters(), 137 | lr=self.cfg.learning_rate, 138 | betas=(self.cfg.adam_beta1, self.cfg.adam_beta2), 139 | weight_decay=self.cfg.adam_weight_decay, 140 | eps=self.cfg.adam_epsilon, 141 | ) 142 | 143 | def init_lr_schedulers(self, gradient_accumulation_steps, num_epochs): 144 | self.lr_scheduler = get_scheduler( 145 | self.cfg.lr_scheduler, 146 | optimizer=self.optimizer, 147 | num_warmup_steps=self.cfg.lr_warmup_steps * 148 | gradient_accumulation_steps, 149 | num_training_steps=(len(self.train_dataloader) 150 | * num_epochs) 151 | ) 152 | 153 | def prepare_modules(self): 154 | self.model, self.optimizer, self.train_dataloader, self.lr_scheduler = self.accelerator.prepare( 155 | self.model, self.optimizer, self.train_dataloader, self.lr_scheduler 156 | ) 157 | if self.cfg.use_ema: 158 | self.ema_model.to(self.accelerator.device) 159 | 160 | def models_to_train(self): 161 | self.model.train() 162 | 163 | def training_step(self, global_step, batch) -> dict: 164 | weight_dtype = self.weight_dtype 165 | clean_images = 
batch["input"].to(weight_dtype) 166 | # Sample noise that we'll add to the images 167 | noise = torch.randn(clean_images.shape, 168 | dtype=weight_dtype, device=clean_images.device) 169 | bsz = clean_images.shape[0] 170 | # Sample a random timestep for each image 171 | timesteps = torch.randint( 172 | 0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=clean_images.device 173 | ).long() 174 | 175 | # Add noise to the clean images according to the noise magnitude at each timestep 176 | # (this is the forward diffusion process) 177 | noisy_images = self.noise_scheduler.add_noise( 178 | clean_images, noise, timesteps) 179 | 180 | with self.accelerator.accumulate(self.model): 181 | # Predict the noise residual 182 | model_output = self.model(noisy_images, timesteps).sample 183 | 184 | if self.cfg.prediction_type == "epsilon": 185 | # this could have different weights! 186 | loss = F.mse_loss(model_output.float(), noise.float()) 187 | elif self.cfg.prediction_type == "sample": 188 | alpha_t = _extract_into_tensor( 189 | self.noise_scheduler.alphas_cumprod, timesteps, ( 190 | clean_images.shape[0], 1, 1, 1) 191 | ) 192 | snr_weights = alpha_t / (1 - alpha_t) 193 | # use SNR weighting from distillation paper 194 | loss = snr_weights * \ 195 | F.mse_loss(model_output.float(), 196 | clean_images.float(), reduction="none") 197 | loss = loss.mean() 198 | else: 199 | raise ValueError( 200 | f"Unsupported prediction type: {self.cfg.prediction_type}") 201 | 202 | self.accelerator.backward(loss) 203 | 204 | if self.accelerator.sync_gradients: 205 | self.accelerator.clip_grad_norm_(self.model.parameters(), 1.0) 206 | self.optimizer.step() 207 | self.lr_scheduler.step() 208 | self.optimizer.zero_grad() 209 | 210 | if self.accelerator.sync_gradients: 211 | if self.cfg.use_ema: 212 | self.ema_model.step(self.model.parameters()) 213 | 214 | logs = {"loss": loss.detach().item(), "lr": self.lr_scheduler.get_last_lr()[ 215 | 0], "step": global_step} 216 | if self.cfg.use_ema: 217 | logs["ema_decay"] = self.ema_model.cur_decay_value 218 | 219 | return logs 220 | 221 | def validate(self, epoch, global_step): 222 | unet = self.accelerator.unwrap_model(self.model) 223 | 224 | if self.cfg.use_ema: 225 | self.ema_model.store(unet.parameters()) 226 | self.ema_model.copy_to(unet.parameters()) 227 | 228 | pipeline = DDPMPipeline( 229 | unet=unet, 230 | scheduler=self.noise_scheduler, 231 | ) 232 | 233 | generator = torch.Generator( 234 | device=pipeline.device).manual_seed(0) 235 | # run pipeline in inference (sample random noise and denoise) 236 | images = pipeline( 237 | generator=generator, 238 | batch_size=self.cfg.valid_batch_size, 239 | num_inference_steps=self.cfg.ddpm_num_inference_steps, 240 | output_type="np", 241 | ).images 242 | 243 | if self.cfg.use_ema: 244 | self.ema_model.restore(unet.parameters()) 245 | 246 | # denormalize the images and save to tensorboard 247 | images_processed = (images * 255).round().astype("uint8") 248 | 249 | if self.logger == "tensorboard": 250 | if is_accelerate_version(">=", "0.17.0.dev0"): 251 | tracker = self.accelerator.get_tracker( 252 | "tensorboard", unwrap=True) 253 | else: 254 | tracker = self.accelerator.get_tracker("tensorboard") 255 | tracker.add_images( 256 | "test_samples", images_processed.transpose(0, 3, 1, 2), epoch) 257 | elif self.logger == "wandb": 258 | # Upcoming `log_images` helper coming in https://github.com/huggingface/accelerate/pull/962/files 259 | import wandb 260 | self.accelerator.get_tracker("wandb").log( 261 | {"test_samples": 
[wandb.Image( 262 | img) for img in images_processed], "epoch": epoch}, 263 | step=global_step, 264 | ) 265 | 266 | def save_pipeline(self, output_dir): 267 | unet = self.accelerator.unwrap_model(self.model) 268 | 269 | if self.cfg.use_ema: 270 | self.ema_model.store(unet.parameters()) 271 | self.ema_model.copy_to(unet.parameters()) 272 | 273 | pipeline = DDPMPipeline( 274 | unet=unet, 275 | scheduler=self.noise_scheduler, 276 | ) 277 | 278 | pipeline.save_pretrained(output_dir) 279 | 280 | if self.cfg.use_ema: 281 | self.ema_model.restore(unet.parameters()) 282 | 283 | def save_model_hook(self, models, weights, output_dir): 284 | if self.accelerator.is_main_process: 285 | if self.cfg.use_ema: 286 | self.ema_model.save_pretrained( 287 | os.path.join(output_dir, "unet_ema")) 288 | 289 | for i, model in enumerate(models): 290 | model.save_pretrained(os.path.join(output_dir, "unet")) 291 | 292 | # make sure to pop weight so that corresponding model is not saved again 293 | weights.pop() 294 | 295 | def load_model_hook(self, models, input_dir): 296 | if self.cfg.use_ema: 297 | load_model = EMAModel.from_pretrained( 298 | os.path.join(input_dir, "unet_ema"), UNet2DModel) 299 | self.ema_model.load_state_dict(load_model.state_dict()) 300 | self.ema_model.to(self.accelerator.device) 301 | del load_model 302 | 303 | for i in range(len(models)): 304 | # pop models so that they are not loaded again 305 | model = models.pop() 306 | 307 | # load diffusers style into model 308 | load_model = UNet2DModel.from_pretrained( 309 | input_dir, subfolder="unet") 310 | model.register_to_config(**load_model.config) 311 | 312 | model.load_state_dict(load_model.state_dict()) 313 | del load_model 314 | -------------------------------------------------------------------------------- /TrainingScript/sd_lora_trainer.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import os 3 | from dataclasses import dataclass 4 | 5 | from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel 6 | from diffusers.optimization import get_scheduler 7 | from diffusers.training_utils import cast_training_params, compute_snr 8 | from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available 9 | import numpy as np 10 | import torch 11 | import torch.nn.functional as F 12 | from transformers import CLIPTextModel, CLIPTokenizer 13 | from peft import LoraConfig 14 | from peft.utils import get_peft_model_state_dict 15 | 16 | from trainer import Trainer 17 | 18 | 19 | @dataclass 20 | class LoraTrainingConfig: 21 | # Diffuion Models 22 | pretrained_model_name_or_path: str 23 | revision: str = None 24 | variant: str = None 25 | rank: int = 4 26 | ddpm_num_steps: int = 1000 27 | ddpm_beta_schedule: str = 'linear' 28 | prediction_type: str = 'epsilon' 29 | ddpm_num_inference_steps: int = 100 30 | 31 | max_grad_norm = 0.1 32 | 33 | # Validation 34 | valid_seed = 0 35 | valid_batch_size: int = 1 36 | 37 | # AdamW 38 | scale_lr = False 39 | learning_rate: float = 1e-4 40 | adam_beta1: float = 0.9 41 | adam_beta2: float = 0.999 42 | adam_weight_decay: float = 1e-2 43 | adam_epsilon: float = 1e-08 44 | 45 | # LR Scheduler 46 | lr_scheduler: str = 'constant' 47 | lr_warmup_steps: int = 500 48 | 49 | 50 | def log_validation( 51 | pipeline, 52 | seed, 53 | num_validation_images, 54 | accelerator, 55 | epoch, 56 | is_final_validation=False, 57 | ): 58 | pipeline = pipeline.to(accelerator.device) 59 | 
pipeline.set_progress_bar_config(disable=True) 60 | generator = torch.Generator(device=accelerator.device) 61 | if seed is not None: 62 | generator = generator.manual_seed(seed) 63 | images = [] 64 | 65 | autocast_ctx = torch.autocast(accelerator.device.type) 66 | 67 | with autocast_ctx: 68 | for _ in range(num_validation_images): 69 | images.append(pipeline("", num_inference_steps=30, 70 | generator=generator).images[0]) 71 | 72 | for tracker in accelerator.trackers: 73 | phase_name = "test" if is_final_validation else "validation" 74 | if tracker.name == "tensorboard": 75 | np_images = np.stack([np.asarray(img) for img in images]) 76 | tracker.writer.add_images( 77 | phase_name, np_images, epoch, dataformats="NHWC") 78 | if tracker.name == "wandb": 79 | import wandb 80 | tracker.log( 81 | { 82 | phase_name: [ 83 | wandb.Image(image, caption=f"{i}: {''}") for i, image in enumerate(images) 84 | ] 85 | } 86 | ) 87 | return images 88 | 89 | 90 | class LoraTrainer(Trainer): 91 | def __init__(self, weight_dtype, accelerator, logger, cfg: LoraTrainingConfig): 92 | super().__init__(weight_dtype, accelerator, logger, cfg) 93 | 94 | def init_modules(self, 95 | enable_xformer=False, 96 | gradient_checkpointing=False): 97 | cfg = self.cfg 98 | # Load scheduler, tokenizer and models. 99 | self.noise_scheduler = DDPMScheduler.from_pretrained( 100 | cfg.pretrained_model_name_or_path, subfolder="scheduler") 101 | self.tokenizer = CLIPTokenizer.from_pretrained( 102 | cfg.pretrained_model_name_or_path, subfolder="tokenizer", revision=cfg.revision 103 | ) 104 | self.text_encoder = CLIPTextModel.from_pretrained( 105 | cfg.pretrained_model_name_or_path, subfolder="text_encoder", revision=cfg.revision 106 | ) 107 | self.vae = AutoencoderKL.from_pretrained( 108 | cfg.pretrained_model_name_or_path, subfolder="vae", revision=cfg.revision, variant=cfg.variant 109 | ) 110 | self.unet = UNet2DConditionModel.from_pretrained( 111 | cfg.pretrained_model_name_or_path, subfolder="unet", revision=cfg.revision, variant=cfg.variant 112 | ) 113 | # freeze parameters of models to save more memory 114 | self.unet.requires_grad_(False) 115 | self.vae.requires_grad_(False) 116 | self.text_encoder.requires_grad_(False) 117 | 118 | for param in self.unet.parameters(): 119 | param.requires_grad_(False) 120 | 121 | unet_lora_config = LoraConfig( 122 | r=cfg.rank, 123 | lora_alpha=cfg.rank, 124 | init_lora_weights="gaussian", 125 | target_modules=["to_k", "to_q", "to_v", "to_out.0"], 126 | ) 127 | 128 | self.unet.to(self.accelerator.device, dtype=self.weight_dtype) 129 | self.vae.to(self.accelerator.device, dtype=self.weight_dtype) 130 | self.text_encoder.to(self.accelerator.device, dtype=self.weight_dtype) 131 | 132 | self.unet.add_adapter(unet_lora_config) 133 | if self.accelerator.mixed_precision == "fp16": 134 | # only upcast trainable parameters (LoRA) into fp32 135 | cast_training_params(self.unet, dtype=torch.float32) 136 | 137 | if enable_xformer: 138 | self.unet.enable_xformers_memory_efficient_attention() 139 | 140 | self.lora_layers = filter( 141 | lambda p: p.requires_grad, self.unet.parameters()) 142 | 143 | if gradient_checkpointing: 144 | self.unet.enable_gradient_checkpointing() 145 | 146 | self.empty_ids = self.tokenizer( 147 | '', max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" 148 | ).input_ids 149 | 150 | def init_optimizers(self, train_batch_size): 151 | if self.cfg.scale_lr: 152 | self.cfg.learning_rate = ( 153 | self.cfg.learning_rate * 
self.cfg.gradient_accumulation_steps * 154 | train_batch_size * self.accelerator.num_processes 155 | ) 156 | self.optimizer = torch.optim.AdamW( 157 | self.lora_layers, 158 | lr=self.cfg.learning_rate, 159 | betas=(self.cfg.adam_beta1, self.cfg.adam_beta2), 160 | weight_decay=self.cfg.adam_weight_decay, 161 | eps=self.cfg.adam_epsilon, 162 | ) 163 | 164 | def init_lr_schedulers(self, gradient_accumulation_steps, num_epochs): 165 | self.lr_scheduler = get_scheduler( 166 | self.cfg.lr_scheduler, 167 | optimizer=self.optimizer, 168 | num_warmup_steps=self.cfg.lr_warmup_steps * 169 | gradient_accumulation_steps, 170 | num_training_steps=(len(self.train_dataloader) 171 | * num_epochs) 172 | ) 173 | 174 | def prepare_modules(self): 175 | self.unet, self.optimizer, self.train_dataloader, self.lr_scheduler = self.accelerator.prepare( 176 | self.unet, self.optimizer, self.train_dataloader, self.lr_scheduler 177 | ) 178 | 179 | def models_to_train(self): 180 | self.unet.train() 181 | 182 | def training_step(self, global_step, batch) -> dict: 183 | train_loss = 0.0 184 | with self.accelerator.accumulate(self.unet): 185 | # Convert images to latent space 186 | latents = self.vae.encode(batch["input"].to( 187 | dtype=self.weight_dtype)).latent_dist.sample() 188 | latents = latents * self.vae.config.scaling_factor 189 | input_ids = self.empty_ids.repeat(latents.shape[0], 1) 190 | input_ids = input_ids.to(latents.device) 191 | 192 | # Sample noise that we'll add to the latents 193 | noise = torch.randn_like(latents) 194 | 195 | bsz = latents.shape[0] 196 | # Sample a random timestep for each image 197 | timesteps = torch.randint( 198 | 0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) 199 | timesteps = timesteps.long() 200 | 201 | # Add noise to the latents according to the noise magnitude at each timestep 202 | # (this is the forward diffusion process) 203 | noisy_latents = self.noise_scheduler.add_noise( 204 | latents, noise, timesteps) 205 | 206 | # Get the text embedding for conditioning 207 | encoder_hidden_states = self.text_encoder( 208 | input_ids, return_dict=False)[0] 209 | 210 | # Get the target for loss depending on the prediction type 211 | if self.cfg.prediction_type is not None: 212 | # set prediction_type of scheduler if defined 213 | self.noise_scheduler.register_to_config( 214 | prediction_type=self.cfg.prediction_type) 215 | 216 | if self.noise_scheduler.config.prediction_type == "epsilon": 217 | target = noise 218 | elif self.noise_scheduler.config.prediction_type == "v_prediction": 219 | target = self.noise_scheduler.get_velocity( 220 | latents, noise, timesteps) 221 | else: 222 | raise ValueError( 223 | f"Unknown prediction type {self.noise_scheduler.config.prediction_type}") 224 | 225 | # Predict the noise residual and compute loss 226 | model_pred = self.unet(noisy_latents, timesteps, 227 | encoder_hidden_states, return_dict=False)[0] 228 | 229 | loss = F.mse_loss(model_pred.float(), 230 | target.float(), reduction="mean") 231 | 232 | train_batch_size = latents.shape[0] 233 | 234 | # Gather the losses across all processes for logging (if we use distributed training). 
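# (note: train_loss is re-initialised at the top of each training_step call, so with gradient
#  accumulation the logged value only reflects the current micro-batch, unlike LoRA/train_lora.py
#  where the loss is accumulated across the whole accumulation window)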
235 | avg_loss = self.accelerator.gather( 236 | loss.repeat(train_batch_size)).mean() 237 | train_loss += avg_loss.item() / self.accelerator.gradient_accumulation_steps 238 | 239 | # Backpropagate 240 | self.accelerator.backward(loss) 241 | if self.accelerator.sync_gradients: 242 | params_to_clip = self.lora_layers 243 | self.accelerator.clip_grad_norm_( 244 | params_to_clip, self.cfg.max_grad_norm) 245 | self.optimizer.step() 246 | self.lr_scheduler.step() 247 | self.optimizer.zero_grad() 248 | 249 | if self.accelerator.sync_gradients: 250 | logs = {"train_loss": train_loss} 251 | 252 | return logs 253 | 254 | def validate(self, epoch, global_step): 255 | pipeline = DiffusionPipeline.from_pretrained( 256 | self.cfg.pretrained_model_name_or_path, 257 | unet=self.accelerator.unwrap_model(self.unet), 258 | revision=self.cfg.revision, 259 | variant=self.cfg.variant, 260 | torch_dtype=self.weight_dtype, 261 | ) 262 | log_validation( 263 | pipeline, self.cfg.valid_seed, self.cfg.valid_batch_size, self.accelerator, epoch) 264 | 265 | del pipeline 266 | torch.cuda.empty_cache() 267 | 268 | def save_pipeline(self, output_dir): 269 | self.unet = self.unet.to(torch.float32) 270 | 271 | unwrapped_unet = self.accelerator.unwrap_model(self.unet) 272 | unet_lora_state_dict = convert_state_dict_to_diffusers( 273 | get_peft_model_state_dict(unwrapped_unet) 274 | ) 275 | 276 | StableDiffusionPipeline.save_lora_weights( 277 | save_directory=output_dir, 278 | unet_lora_layers=unet_lora_state_dict, 279 | safe_serialization=True, 280 | ) 281 | 282 | def save_model_hook(self, models, weights, output_dir): 283 | if self.accelerator.is_main_process: 284 | for i, model in enumerate(models): 285 | unwrapped_unet = self.accelerator.unwrap_model(model) 286 | unet_lora_state_dict = convert_state_dict_to_diffusers( 287 | get_peft_model_state_dict(unwrapped_unet) 288 | ) 289 | 290 | StableDiffusionPipeline.save_lora_weights( 291 | save_directory=output_dir, 292 | unet_lora_layers=unet_lora_state_dict, 293 | safe_serialization=True, 294 | ) 295 | 296 | def load_model_hook(self, models, input_dir): 297 | pass 298 | -------------------------------------------------------------------------------- /TrainingScript/train_0.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import inspect 3 | import logging 4 | import math 5 | import os 6 | import shutil 7 | from pathlib import Path 8 | 9 | import accelerate 10 | import datasets 11 | import torch 12 | import torch.nn.functional as F 13 | from accelerate import Accelerator 14 | from accelerate.logging import get_logger 15 | from accelerate.utils import ProjectConfiguration, set_seed 16 | from datasets import load_dataset 17 | from huggingface_hub import create_repo, upload_folder 18 | from packaging import version 19 | from torchvision import transforms 20 | from tqdm.auto import tqdm 21 | from omegaconf import OmegaConf 22 | 23 | import diffusers 24 | from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel 25 | from diffusers.optimization import get_scheduler 26 | from diffusers.training_utils import EMAModel 27 | from diffusers.utils import check_min_version, is_accelerate_version, is_tensorboard_available, is_wandb_available 28 | from diffusers.utils.import_utils import is_xformers_available 29 | 30 | from training_cfg_0 import BaseTrainingConfig 31 | 32 | 33 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
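# Usage sketch (assuming the accompanying cfg_0.json is the intended config):
#   accelerate launch train_0.py cfg_0.json
# The positional `cfg` argument is loaded with OmegaConf and unpacked into a
# BaseTrainingConfig inside main().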
34 | check_min_version("0.30.0.dev0") 35 | 36 | logger = get_logger(__name__, log_level="INFO") 37 | 38 | 39 | def _extract_into_tensor(arr, timesteps, broadcast_shape): 40 | """ 41 | Extract values from a 1-D numpy array for a batch of indices. 42 | 43 | :param arr: the 1-D numpy array. 44 | :param timesteps: a tensor of indices into the array to extract. 45 | :param broadcast_shape: a larger shape of K dimensions with the batch 46 | dimension equal to the length of timesteps. 47 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 48 | """ 49 | if not isinstance(arr, torch.Tensor): 50 | arr = torch.from_numpy(arr) 51 | res = arr[timesteps].float().to(timesteps.device) 52 | while len(res.shape) < len(broadcast_shape): 53 | res = res[..., None] 54 | return res.expand(broadcast_shape) 55 | 56 | 57 | def main(): 58 | parser = argparse.ArgumentParser() 59 | parser.add_argument('cfg', type=str) 60 | args = parser.parse_args() 61 | 62 | data_dict = OmegaConf.load(args.cfg) 63 | cfg = BaseTrainingConfig(**data_dict) 64 | 65 | logging_dir = os.path.join(cfg.output_dir, cfg.logging_dir) 66 | accelerator_project_config = ProjectConfiguration( 67 | project_dir=cfg.output_dir, logging_dir=logging_dir) 68 | 69 | accelerator = Accelerator( 70 | gradient_accumulation_steps=cfg.gradient_accumulation_steps, 71 | mixed_precision=cfg.mixed_precision, 72 | log_with=cfg.logger, 73 | project_config=accelerator_project_config 74 | ) 75 | 76 | if cfg.logger == "tensorboard": 77 | if not is_tensorboard_available(): 78 | raise ImportError( 79 | "Make sure to install tensorboard if you want to use it for logging during training.") 80 | 81 | elif cfg.logger == "wandb": 82 | if not is_wandb_available(): 83 | raise ImportError( 84 | "Make sure to install wandb if you want to use it for logging during training.") 85 | import wandb 86 | 87 | # `accelerate` 0.16.0 will have better support for customized saving 88 | if version.parse(accelerate.__version__) >= version.parse("0.16.0"): 89 | # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format 90 | def save_model_hook(models, weights, output_dir): 91 | if accelerator.is_main_process: 92 | if cfg.use_ema: 93 | ema_model.save_pretrained( 94 | os.path.join(output_dir, "unet_ema")) 95 | 96 | for i, model in enumerate(models): 97 | model.save_pretrained(os.path.join(output_dir, "unet")) 98 | 99 | # make sure to pop weight so that corresponding model is not saved again 100 | weights.pop() 101 | 102 | def load_model_hook(models, input_dir): 103 | if cfg.use_ema: 104 | load_model = EMAModel.from_pretrained( 105 | os.path.join(input_dir, "unet_ema"), UNet2DModel) 106 | ema_model.load_state_dict(load_model.state_dict()) 107 | ema_model.to(accelerator.device) 108 | del load_model 109 | 110 | for i in range(len(models)): 111 | # pop models so that they are not loaded again 112 | model = models.pop() 113 | 114 | # load diffusers style into model 115 | load_model = UNet2DModel.from_pretrained( 116 | input_dir, subfolder="unet") 117 | model.register_to_config(**load_model.config) 118 | 119 | model.load_state_dict(load_model.state_dict()) 120 | del load_model 121 | 122 | accelerator.register_save_state_pre_hook(save_model_hook) 123 | accelerator.register_load_state_pre_hook(load_model_hook) 124 | 125 | # Make one log on every process with the configuration for debugging. 
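# (The save/load hooks registered above make `accelerator.save_state` write the
# UNet, and optionally the EMA weights, in diffusers' save_pretrained layout
# instead of a raw state dict, and restore them symmetrically on resume.)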
126 | logging.basicConfig( 127 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 128 | datefmt="%m/%d/%Y %H:%M:%S", 129 | level=logging.INFO, 130 | ) 131 | logger.info(accelerator.state, main_process_only=False) 132 | if accelerator.is_local_main_process: 133 | datasets.utils.logging.set_verbosity_warning() 134 | diffusers.utils.logging.set_verbosity_info() 135 | else: 136 | datasets.utils.logging.set_verbosity_error() 137 | diffusers.utils.logging.set_verbosity_error() 138 | 139 | # If passed along, set the training seed now. 140 | if cfg.seed is not None: 141 | set_seed(cfg.seed) 142 | 143 | # Handle the repository creation 144 | if accelerator.is_main_process: 145 | if cfg.output_dir is not None: 146 | os.makedirs(cfg.output_dir, exist_ok=True) 147 | 148 | if cfg.push_to_hub: 149 | repo_id = create_repo( 150 | repo_id=cfg.hub_model_id or Path(cfg.output_dir).name, exist_ok=True, token=cfg.hub_token 151 | ).repo_id 152 | 153 | # Initialize the model 154 | if cfg.model_config is None: 155 | model = UNet2DModel( 156 | sample_size=cfg.resolution, 157 | in_channels=3, 158 | out_channels=3, 159 | layers_per_block=2, 160 | block_out_channels=(128, 128, 256, 256, 512, 512), 161 | down_block_types=( 162 | "DownBlock2D", 163 | "DownBlock2D", 164 | "DownBlock2D", 165 | "DownBlock2D", 166 | "AttnDownBlock2D", 167 | "DownBlock2D", 168 | ), 169 | up_block_types=( 170 | "UpBlock2D", 171 | "AttnUpBlock2D", 172 | "UpBlock2D", 173 | "UpBlock2D", 174 | "UpBlock2D", 175 | "UpBlock2D", 176 | ), 177 | ) 178 | else: 179 | config = UNet2DModel.load_config(cfg.model_config) 180 | model = UNet2DModel.from_config(config) 181 | 182 | # Create EMA for the model. 183 | if cfg.use_ema: 184 | ema_model = EMAModel( 185 | model.parameters(), 186 | decay=cfg.ema_max_decay, 187 | use_ema_warmup=True, 188 | inv_gamma=cfg.ema_inv_gamma, 189 | power=cfg.ema_power, 190 | model_cls=UNet2DModel, 191 | model_config=model.config, 192 | ) 193 | 194 | weight_dtype = torch.float32 195 | if accelerator.mixed_precision == "fp16": 196 | weight_dtype = torch.float16 197 | cfg.mixed_precision = accelerator.mixed_precision 198 | elif accelerator.mixed_precision == "bf16": 199 | weight_dtype = torch.bfloat16 200 | cfg.mixed_precision = accelerator.mixed_precision 201 | 202 | if cfg.enable_xformers_memory_efficient_attention: 203 | if is_xformers_available(): 204 | import xformers 205 | 206 | xformers_version = version.parse(xformers.__version__) 207 | if xformers_version == version.parse("0.0.16"): 208 | logger.warning( 209 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 210 | ) 211 | model.enable_xformers_memory_efficient_attention() 212 | else: 213 | raise ValueError( 214 | "xformers is not available. 
Make sure it is installed correctly") 215 | 216 | # Initialize the scheduler 217 | accepts_prediction_type = "prediction_type" in set( 218 | inspect.signature(DDPMScheduler.__init__).parameters.keys()) 219 | if accepts_prediction_type: 220 | noise_scheduler = DDPMScheduler( 221 | num_train_timesteps=cfg.ddpm_num_steps, 222 | beta_schedule=cfg.ddpm_beta_schedule, 223 | prediction_type=cfg.prediction_type, 224 | ) 225 | else: 226 | noise_scheduler = DDPMScheduler( 227 | num_train_timesteps=cfg.ddpm_num_steps, beta_schedule=cfg.ddpm_beta_schedule) 228 | 229 | if cfg.gradient_checkpointing: 230 | model.enable_gradient_checkpointing() 231 | 232 | if cfg.scale_lr: 233 | cfg.learning_rate = ( 234 | cfg.learning_rate * cfg.gradient_accumulation_steps * 235 | cfg.train_batch_size * accelerator.num_processes 236 | ) 237 | 238 | # Initialize the optimizer 239 | optimizer = torch.optim.AdamW( 240 | model.parameters(), 241 | lr=cfg.learning_rate, 242 | betas=(cfg.adam_beta1, cfg.adam_beta2), 243 | weight_decay=cfg.adam_weight_decay, 244 | eps=cfg.adam_epsilon, 245 | ) 246 | 247 | # Get the datasets: you can either provide your own training and evaluation files (see below) 248 | # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). 249 | 250 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently 251 | # download the dataset. 252 | if cfg.dataset_name is not None: 253 | dataset = load_dataset( 254 | cfg.dataset_name, 255 | cfg.dataset_config_name, 256 | cache_dir=cfg.cache_dir, 257 | split="train", 258 | ) 259 | else: 260 | dataset = load_dataset( 261 | "imagefolder", data_dir=cfg.train_data_dir, cache_dir=cfg.cache_dir, split="train") 262 | # See more about loading custom images at 263 | # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder 264 | 265 | # Preprocessing the datasets and DataLoaders creation. 266 | augmentations = transforms.Compose( 267 | [ 268 | transforms.Resize( 269 | cfg.resolution, interpolation=transforms.InterpolationMode.BILINEAR), 270 | transforms.CenterCrop( 271 | cfg.resolution) if cfg.center_crop else transforms.RandomCrop(cfg.resolution), 272 | transforms.RandomHorizontalFlip() if cfg.random_flip else transforms.Lambda(lambda x: x), 273 | transforms.ToTensor(), 274 | transforms.Normalize([0.5], [0.5]), 275 | ] 276 | ) 277 | 278 | def transform_images(examples): 279 | images = [augmentations(image.convert("RGB")) 280 | for image in examples["image"]] 281 | return {"input": images} 282 | 283 | logger.info(f"Dataset size: {len(dataset)}") 284 | 285 | dataset.set_transform(transform_images) 286 | train_dataloader = torch.utils.data.DataLoader( 287 | dataset, batch_size=cfg.train_batch_size, shuffle=True, num_workers=cfg.dataloader_num_workers 288 | ) 289 | 290 | # Initialize the learning rate scheduler 291 | lr_scheduler = get_scheduler( 292 | cfg.lr_scheduler, 293 | optimizer=optimizer, 294 | num_warmup_steps=cfg.lr_warmup_steps * cfg.gradient_accumulation_steps, 295 | num_training_steps=(len(train_dataloader) * cfg.num_epochs), 296 | ) 297 | 298 | # Prepare everything with our `accelerator`. 299 | model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 300 | model, optimizer, train_dataloader, lr_scheduler 301 | ) 302 | 303 | if cfg.use_ema: 304 | ema_model.to(accelerator.device) 305 | 306 | # We need to initialize the trackers we use, and also store our configuration. 307 | # The trackers initializes automatically on the main process. 
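# The effective batch size per optimizer step is
# train_batch_size * num_processes * gradient_accumulation_steps;
# e.g. a per-device batch of 16 on 2 GPUs with 4 accumulation steps gives 128
# samples per parameter update (illustrative numbers, not repository defaults).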
308 | if accelerator.is_main_process: 309 | run = os.path.split(__file__)[-1].split(".")[0] 310 | accelerator.init_trackers(run) 311 | 312 | total_batch_size = cfg.train_batch_size * \ 313 | accelerator.num_processes * cfg.gradient_accumulation_steps 314 | num_update_steps_per_epoch = math.ceil( 315 | len(train_dataloader) / cfg.gradient_accumulation_steps) 316 | max_train_steps = cfg.num_epochs * num_update_steps_per_epoch 317 | 318 | logger.info("***** Running training *****") 319 | logger.info(f" Num examples = {len(dataset)}") 320 | logger.info(f" Num Epochs = {cfg.num_epochs}") 321 | logger.info( 322 | f" Instantaneous batch size per device = {cfg.train_batch_size}") 323 | logger.info( 324 | f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 325 | logger.info( 326 | f" Gradient Accumulation steps = {cfg.gradient_accumulation_steps}") 327 | logger.info(f" Total optimization steps = {max_train_steps}") 328 | 329 | global_step = 0 330 | first_epoch = 0 331 | 332 | # Potentially load in the weights and states from a previous save 333 | if cfg.resume_from_checkpoint: 334 | if cfg.resume_from_checkpoint != "latest": 335 | path = os.path.basename(cfg.resume_from_checkpoint) 336 | else: 337 | # Get the most recent checkpoint 338 | dirs = os.listdir(cfg.output_dir) 339 | dirs = [d for d in dirs if d.startswith("checkpoint")] 340 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 341 | path = dirs[-1] if len(dirs) > 0 else None 342 | 343 | if path is None: 344 | accelerator.print( 345 | f"Checkpoint '{cfg.resume_from_checkpoint}' does not exist. Starting a new training run." 346 | ) 347 | cfg.resume_from_checkpoint = None 348 | else: 349 | accelerator.print(f"Resuming from checkpoint {path}") 350 | accelerator.load_state(os.path.join(cfg.output_dir, path)) 351 | global_step = int(path.split("-")[1]) 352 | 353 | resume_global_step = global_step * cfg.gradient_accumulation_steps 354 | first_epoch = global_step // num_update_steps_per_epoch 355 | resume_step = resume_global_step % ( 356 | num_update_steps_per_epoch * cfg.gradient_accumulation_steps) 357 | 358 | # Train! 
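# Loop outline: each epoch iterates the dataloader, skips batches already
# consumed when resuming mid-epoch, noises the clean images at a random
# timestep, and trains the UNet to predict either the noise ("epsilon") or the
# denoised sample ("sample", with SNR weighting). On every true optimizer step
# the EMA weights are updated, old checkpoints are rotated, and loss and
# learning rate are logged.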
359 | for epoch in range(first_epoch, cfg.num_epochs): 360 | model.train() 361 | progress_bar = tqdm(total=num_update_steps_per_epoch, 362 | disable=not accelerator.is_local_main_process) 363 | progress_bar.set_description(f"Epoch {epoch}") 364 | for step, batch in enumerate(train_dataloader): 365 | # Skip steps until we reach the resumed step 366 | if cfg.resume_from_checkpoint and epoch == first_epoch and step < resume_step: 367 | if step % cfg.gradient_accumulation_steps == 0: 368 | progress_bar.update(1) 369 | continue 370 | 371 | clean_images = batch["input"].to(weight_dtype) 372 | # Sample noise that we'll add to the images 373 | noise = torch.randn(clean_images.shape, 374 | dtype=weight_dtype, device=clean_images.device) 375 | bsz = clean_images.shape[0] 376 | # Sample a random timestep for each image 377 | timesteps = torch.randint( 378 | 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=clean_images.device 379 | ).long() 380 | 381 | # Add noise to the clean images according to the noise magnitude at each timestep 382 | # (this is the forward diffusion process) 383 | noisy_images = noise_scheduler.add_noise( 384 | clean_images, noise, timesteps) 385 | 386 | with accelerator.accumulate(model): 387 | # Predict the noise residual 388 | model_output = model(noisy_images, timesteps).sample 389 | 390 | if cfg.prediction_type == "epsilon": 391 | # this could have different weights! 392 | loss = F.mse_loss(model_output.float(), noise.float()) 393 | elif cfg.prediction_type == "sample": 394 | alpha_t = _extract_into_tensor( 395 | noise_scheduler.alphas_cumprod, timesteps, ( 396 | clean_images.shape[0], 1, 1, 1) 397 | ) 398 | snr_weights = alpha_t / (1 - alpha_t) 399 | # use SNR weighting from distillation paper 400 | loss = snr_weights * \ 401 | F.mse_loss(model_output.float(), 402 | clean_images.float(), reduction="none") 403 | loss = loss.mean() 404 | else: 405 | raise ValueError( 406 | f"Unsupported prediction type: {cfg.prediction_type}") 407 | 408 | accelerator.backward(loss) 409 | 410 | if accelerator.sync_gradients: 411 | accelerator.clip_grad_norm_(model.parameters(), 1.0) 412 | optimizer.step() 413 | lr_scheduler.step() 414 | optimizer.zero_grad() 415 | 416 | # Checks if the accelerator has performed an optimization step behind the scenes 417 | if accelerator.sync_gradients: 418 | if cfg.use_ema: 419 | ema_model.step(model.parameters()) 420 | progress_bar.update(1) 421 | global_step += 1 422 | 423 | if accelerator.is_main_process: 424 | if global_step % cfg.checkpointing_steps == 0: 425 | # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` 426 | if cfg.checkpoints_total_limit is not None: 427 | checkpoints = os.listdir(cfg.output_dir) 428 | checkpoints = [ 429 | d for d in checkpoints if d.startswith("checkpoint")] 430 | checkpoints = sorted( 431 | checkpoints, key=lambda x: int(x.split("-")[1])) 432 | 433 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints 434 | if len(checkpoints) >= cfg.checkpoints_total_limit: 435 | num_to_remove = len( 436 | checkpoints) - cfg.checkpoints_total_limit + 1 437 | removing_checkpoints = checkpoints[0:num_to_remove] 438 | 439 | logger.info( 440 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" 441 | ) 442 | logger.info( 443 | f"removing checkpoints: {', '.join(removing_checkpoints)}") 444 | 445 | for removing_checkpoint in removing_checkpoints: 446 | removing_checkpoint = os.path.join( 447 | 
cfg.output_dir, removing_checkpoint) 448 | shutil.rmtree(removing_checkpoint) 449 | 450 | save_path = os.path.join( 451 | cfg.output_dir, f"checkpoint-{global_step}") 452 | accelerator.save_state(save_path) 453 | logger.info(f"Saved state to {save_path}") 454 | 455 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[ 456 | 0], "step": global_step} 457 | if cfg.use_ema: 458 | logs["ema_decay"] = ema_model.cur_decay_value 459 | progress_bar.set_postfix(**logs) 460 | accelerator.log(logs, step=global_step) 461 | progress_bar.close() 462 | 463 | accelerator.wait_for_everyone() 464 | 465 | # Generate sample images for visual inspection 466 | if accelerator.is_main_process: 467 | if epoch % cfg.valid_epochs == 0 or epoch == cfg.num_epochs - 1: 468 | unet = accelerator.unwrap_model(model) 469 | 470 | if cfg.use_ema: 471 | ema_model.store(unet.parameters()) 472 | ema_model.copy_to(unet.parameters()) 473 | 474 | pipeline = DDPMPipeline( 475 | unet=unet, 476 | scheduler=noise_scheduler, 477 | ) 478 | 479 | generator = torch.Generator( 480 | device=pipeline.device).manual_seed(0) 481 | # run pipeline in inference (sample random noise and denoise) 482 | images = pipeline( 483 | generator=generator, 484 | batch_size=cfg.valid_batch_size, 485 | num_inference_steps=cfg.ddpm_num_inference_steps, 486 | output_type="np", 487 | ).images 488 | 489 | if cfg.use_ema: 490 | ema_model.restore(unet.parameters()) 491 | 492 | # denormalize the images and save to tensorboard 493 | images_processed = (images * 255).round().astype("uint8") 494 | 495 | if cfg.logger == "tensorboard": 496 | if is_accelerate_version(">=", "0.17.0.dev0"): 497 | tracker = accelerator.get_tracker( 498 | "tensorboard", unwrap=True) 499 | else: 500 | tracker = accelerator.get_tracker("tensorboard") 501 | tracker.add_images( 502 | "test_samples", images_processed.transpose(0, 3, 1, 2), epoch) 503 | elif cfg.logger == "wandb": 504 | # Upcoming `log_images` helper coming in https://github.com/huggingface/accelerate/pull/962/files 505 | accelerator.get_tracker("wandb").log( 506 | {"test_samples": [wandb.Image( 507 | img) for img in images_processed], "epoch": epoch}, 508 | step=global_step, 509 | ) 510 | 511 | if epoch % cfg.save_model_epochs == 0 or epoch == cfg.num_epochs - 1: 512 | # save the model 513 | unet = accelerator.unwrap_model(model) 514 | 515 | if cfg.use_ema: 516 | ema_model.store(unet.parameters()) 517 | ema_model.copy_to(unet.parameters()) 518 | 519 | pipeline = DDPMPipeline( 520 | unet=unet, 521 | scheduler=noise_scheduler, 522 | ) 523 | 524 | pipeline.save_pretrained(cfg.output_dir) 525 | 526 | if cfg.use_ema: 527 | ema_model.restore(unet.parameters()) 528 | 529 | if cfg.push_to_hub: 530 | upload_folder( 531 | repo_id=repo_id, 532 | folder_path=cfg.output_dir, 533 | commit_message=f"Epoch {epoch}", 534 | ignore_patterns=["step_*", "epoch_*"], 535 | ) 536 | 537 | accelerator.end_training() 538 | 539 | 540 | if __name__ == "__main__": 541 | main() 542 | -------------------------------------------------------------------------------- /TrainingScript/train_1.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import inspect 3 | import logging 4 | import math 5 | import os 6 | import shutil 7 | from datetime import datetime 8 | from pathlib import Path 9 | 10 | import accelerate 11 | import datasets 12 | import torch 13 | 14 | from accelerate import Accelerator 15 | from accelerate.logging import get_logger 16 | from accelerate.utils import 
ProjectConfiguration, set_seed 17 | from datasets import load_dataset 18 | from huggingface_hub import create_repo, upload_folder 19 | from packaging import version 20 | from torchvision import transforms 21 | from tqdm.auto import tqdm 22 | 23 | import diffusers 24 | from diffusers.utils import check_min_version, is_tensorboard_available, is_wandb_available 25 | from diffusers.utils.import_utils import is_xformers_available 26 | 27 | from training_cfg_1 import BaseTrainingConfig, load_training_config 28 | from trainer import Trainer, create_trainer 29 | 30 | 31 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 32 | check_min_version("0.30.0.dev0") 33 | 34 | logger = get_logger(__name__, log_level="INFO") 35 | 36 | 37 | def main(): 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument('cfg', type=str) 40 | args = parser.parse_args() 41 | 42 | cfgs = load_training_config(args.cfg) 43 | cfg: BaseTrainingConfig = cfgs.pop('base') 44 | trainer_type = next(iter(cfgs)) 45 | trainer_cfg_dict = cfgs[trainer_type] 46 | 47 | logging_dir = os.path.join(cfg.output_dir, cfg.logging_dir) 48 | accelerator_project_config = ProjectConfiguration( 49 | project_dir=cfg.output_dir, logging_dir=logging_dir) 50 | 51 | accelerator = Accelerator( 52 | gradient_accumulation_steps=cfg.gradient_accumulation_steps, 53 | mixed_precision=cfg.mixed_precision, 54 | log_with=cfg.logger, 55 | project_config=accelerator_project_config 56 | ) 57 | 58 | weight_dtype = torch.float32 59 | if accelerator.mixed_precision == "fp16": 60 | weight_dtype = torch.float16 61 | cfg.mixed_precision = accelerator.mixed_precision 62 | elif accelerator.mixed_precision == "bf16": 63 | weight_dtype = torch.bfloat16 64 | cfg.mixed_precision = accelerator.mixed_precision 65 | 66 | trainer: Trainer = create_trainer( 67 | trainer_type, weight_dtype, accelerator, cfg.logger, trainer_cfg_dict) 68 | 69 | if cfg.logger == "tensorboard": 70 | if not is_tensorboard_available(): 71 | raise ImportError( 72 | "Make sure to install tensorboard if you want to use it for logging during training.") 73 | 74 | elif cfg.logger == "wandb": 75 | if not is_wandb_available(): 76 | raise ImportError( 77 | "Make sure to install wandb if you want to use it for logging during training.") 78 | import wandb 79 | 80 | # `accelerate` 0.16.0 will have better support for customized saving 81 | if version.parse(accelerate.__version__) >= version.parse("0.16.0"): 82 | 83 | accelerator.register_save_state_pre_hook(trainer.save_model_hook) 84 | accelerator.register_load_state_pre_hook(trainer.load_model_hook) 85 | 86 | # Make one log on every process with the configuration for debugging. 87 | logging.basicConfig( 88 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 89 | datefmt="%m/%d/%Y %H:%M:%S", 90 | level=logging.INFO, 91 | ) 92 | logger.info(accelerator.state, main_process_only=False) 93 | if accelerator.is_local_main_process: 94 | datasets.utils.logging.set_verbosity_warning() 95 | diffusers.utils.logging.set_verbosity_info() 96 | else: 97 | datasets.utils.logging.set_verbosity_error() 98 | diffusers.utils.logging.set_verbosity_error() 99 | 100 | # If passed along, set the training seed now. 
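# Usage sketch (assuming cfg_1.json / cfg_lora.json are the intended configs):
#   accelerate launch train_1.py cfg_lora.json
# load_training_config returns a dict of sub-configs: 'base' drives this
# script, and the remaining key names the Trainer subclass (e.g. the LoRA
# trainer) that receives the rest of the configuration.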
101 | if cfg.seed is not None: 102 | set_seed(cfg.seed) 103 | 104 | # Handle the repository creation 105 | if accelerator.is_main_process: 106 | if cfg.output_dir is not None: 107 | os.makedirs(cfg.output_dir, exist_ok=True) 108 | 109 | if cfg.push_to_hub: 110 | repo_id = create_repo( 111 | repo_id=cfg.hub_model_id or Path(cfg.output_dir).name, exist_ok=True, token=cfg.hub_token 112 | ).repo_id 113 | 114 | # Initialize the model 115 | enable_xformers = False 116 | if cfg.enable_xformers_memory_efficient_attention: 117 | if is_xformers_available(): 118 | import xformers 119 | 120 | xformers_version = version.parse(xformers.__version__) 121 | if xformers_version == version.parse("0.0.16"): 122 | logger.warning( 123 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 124 | ) 125 | enable_xformers = True 126 | 127 | else: 128 | raise ValueError( 129 | "xformers is not available. Make sure it is installed correctly") 130 | 131 | trainer.init_modules(enable_xformers, cfg.gradient_checkpointing) 132 | 133 | # Initialize the optimizer 134 | trainer.init_optimizers(cfg.train_batch_size) 135 | 136 | # Get the datasets: you can either provide your own training and evaluation files (see below) 137 | # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). 138 | 139 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently 140 | # download the dataset. 141 | if cfg.dataset_name is not None: 142 | dataset = load_dataset( 143 | cfg.dataset_name, 144 | cfg.dataset_config_name, 145 | cache_dir=cfg.cache_dir, 146 | split="train", 147 | ) 148 | else: 149 | dataset = load_dataset( 150 | "imagefolder", data_dir=cfg.train_data_dir, cache_dir=cfg.cache_dir, split="train") 151 | # See more about loading custom images at 152 | # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder 153 | 154 | # Preprocessing the datasets and DataLoaders creation. 155 | augmentations = transforms.Compose( 156 | [ 157 | transforms.Resize( 158 | cfg.resolution, interpolation=transforms.InterpolationMode.BILINEAR), 159 | transforms.CenterCrop( 160 | cfg.resolution) if cfg.center_crop else transforms.RandomCrop(cfg.resolution), 161 | transforms.RandomHorizontalFlip() if cfg.random_flip else transforms.Lambda(lambda x: x), 162 | transforms.ToTensor(), 163 | transforms.Normalize([0.5], [0.5]), 164 | ] 165 | ) 166 | 167 | def transform_images(examples): 168 | images = [augmentations(image.convert("RGB")) 169 | for image in examples["image"]] 170 | return {"input": images} 171 | 172 | logger.info(f"Dataset size: {len(dataset)}") 173 | 174 | dataset.set_transform(transform_images) 175 | train_dataloader = torch.utils.data.DataLoader( 176 | dataset, batch_size=cfg.train_batch_size, shuffle=True, num_workers=cfg.dataloader_num_workers 177 | ) 178 | 179 | trainer.set_dataset(dataset, train_dataloader) 180 | 181 | # Initialize the learning rate scheduler 182 | trainer.init_lr_schedulers(cfg.gradient_accumulation_steps, cfg.num_epochs) 183 | 184 | # Prepare everything with our `accelerator`. 185 | trainer.prepare_modules() 186 | train_dataloader = trainer.train_dataloader 187 | 188 | # We need to initialize the trackers we use, and also store our configuration. 189 | # The trackers initializes automatically on the main process. 
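# Unlike train_0.py, the tracker run name here is a timestamp, so repeated
# experiments get distinct TensorBoard / W&B run names.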
190 | if accelerator.is_main_process: 191 | now = datetime.now() 192 | formatted_now = now.strftime('%Y%m%d%H%M%S') 193 | accelerator.init_trackers( 194 | formatted_now, config=vars(args)) 195 | 196 | total_batch_size = cfg.train_batch_size * \ 197 | accelerator.num_processes * cfg.gradient_accumulation_steps 198 | num_update_steps_per_epoch = math.ceil( 199 | len(train_dataloader) / cfg.gradient_accumulation_steps) 200 | max_train_steps = cfg.num_epochs * num_update_steps_per_epoch 201 | 202 | logger.info("***** Running training *****") 203 | logger.info(f" Num examples = {len(dataset)}") 204 | logger.info(f" Num Epochs = {cfg.num_epochs}") 205 | logger.info( 206 | f" Instantaneous batch size per device = {cfg.train_batch_size}") 207 | logger.info( 208 | f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 209 | logger.info( 210 | f" Gradient Accumulation steps = {cfg.gradient_accumulation_steps}") 211 | logger.info(f" Total optimization steps = {max_train_steps}") 212 | 213 | global_step = 0 214 | first_epoch = 0 215 | 216 | # Potentially load in the weights and states from a previous save 217 | if cfg.resume_from_checkpoint: 218 | if cfg.resume_from_checkpoint != "latest": 219 | path = os.path.basename(cfg.resume_from_checkpoint) 220 | else: 221 | # Get the most recent checkpoint 222 | dirs = os.listdir(cfg.output_dir) 223 | dirs = [d for d in dirs if d.startswith("checkpoint")] 224 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 225 | path = dirs[-1] if len(dirs) > 0 else None 226 | 227 | if path is None: 228 | accelerator.print( 229 | f"Checkpoint '{cfg.resume_from_checkpoint}' does not exist. Starting a new training run." 230 | ) 231 | cfg.resume_from_checkpoint = None 232 | else: 233 | accelerator.print(f"Resuming from checkpoint {path}") 234 | accelerator.load_state(os.path.join(cfg.output_dir, path)) 235 | global_step = int(path.split("-")[1]) 236 | 237 | resume_global_step = global_step * cfg.gradient_accumulation_steps 238 | first_epoch = global_step // num_update_steps_per_epoch 239 | resume_step = resume_global_step % ( 240 | num_update_steps_per_epoch * cfg.gradient_accumulation_steps) 241 | 242 | # Train! 
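# Same loop skeleton as train_0.py, but the per-batch work (noising, forward
# pass, loss, backward, optimizer step) is delegated to trainer.training_step,
# and validation / pipeline saving go through the Trainer subclass, keeping
# this script model-agnostic.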
243 | for epoch in range(first_epoch, cfg.num_epochs): 244 | trainer.models_to_train() 245 | progress_bar = tqdm(total=num_update_steps_per_epoch, 246 | disable=not accelerator.is_local_main_process) 247 | progress_bar.set_description(f"Epoch {epoch}") 248 | for step, batch in enumerate(train_dataloader): 249 | # Skip steps until we reach the resumed step 250 | if cfg.resume_from_checkpoint and epoch == first_epoch and step < resume_step: 251 | if step % cfg.gradient_accumulation_steps == 0: 252 | progress_bar.update(1) 253 | continue 254 | logs = trainer.training_step(global_step, batch) 255 | 256 | # Checks if the accelerator has performed an optimization step behind the scenes 257 | if accelerator.sync_gradients: 258 | progress_bar.update(1) 259 | global_step += 1 260 | 261 | if accelerator.is_main_process: 262 | if global_step % cfg.checkpointing_steps == 0: 263 | # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` 264 | if cfg.checkpoints_total_limit is not None: 265 | checkpoints = os.listdir(cfg.output_dir) 266 | checkpoints = [ 267 | d for d in checkpoints if d.startswith("checkpoint")] 268 | checkpoints = sorted( 269 | checkpoints, key=lambda x: int(x.split("-")[1])) 270 | 271 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints 272 | if len(checkpoints) >= cfg.checkpoints_total_limit: 273 | num_to_remove = len( 274 | checkpoints) - cfg.checkpoints_total_limit + 1 275 | removing_checkpoints = checkpoints[0:num_to_remove] 276 | 277 | logger.info( 278 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" 279 | ) 280 | logger.info( 281 | f"removing checkpoints: {', '.join(removing_checkpoints)}") 282 | 283 | for removing_checkpoint in removing_checkpoints: 284 | removing_checkpoint = os.path.join( 285 | cfg.output_dir, removing_checkpoint) 286 | shutil.rmtree(removing_checkpoint) 287 | 288 | save_path = os.path.join( 289 | cfg.output_dir, f"checkpoint-{global_step}") 290 | accelerator.save_state(save_path) 291 | logger.info(f"Saved state to {save_path}") 292 | 293 | progress_bar.set_postfix(**logs) 294 | accelerator.log(logs, step=global_step) 295 | progress_bar.close() 296 | 297 | accelerator.wait_for_everyone() 298 | 299 | # Generate sample images for visual inspection 300 | if accelerator.is_main_process: 301 | if epoch % cfg.valid_epochs == 0 or epoch == cfg.num_epochs - 1: 302 | trainer.validate(epoch, global_step) 303 | 304 | if epoch % cfg.save_model_epochs == 0 or epoch == cfg.num_epochs - 1: 305 | trainer.save_pipeline(cfg.output_dir) 306 | 307 | if cfg.push_to_hub: 308 | upload_folder( 309 | repo_id=repo_id, 310 | folder_path=cfg.output_dir, 311 | commit_message=f"Epoch {epoch}", 312 | ignore_patterns=["step_*", "epoch_*"], 313 | ) 314 | 315 | accelerator.end_training() 316 | 317 | 318 | if __name__ == "__main__": 319 | main() 320 | -------------------------------------------------------------------------------- /TrainingScript/train_official.py: -------------------------------------------------------------------------------- 1 | # examples/unconditional_image_generation/train_unconditional.py 2 | import argparse 3 | import inspect 4 | import logging 5 | import math 6 | import os 7 | import shutil 8 | from datetime import timedelta 9 | from pathlib import Path 10 | 11 | import accelerate 12 | import datasets 13 | import torch 14 | import torch.nn.functional as F 15 | from accelerate import Accelerator, InitProcessGroupKwargs 16 | from 
accelerate.logging import get_logger 17 | from accelerate.utils import ProjectConfiguration 18 | from datasets import load_dataset 19 | from huggingface_hub import create_repo, upload_folder 20 | from packaging import version 21 | from torchvision import transforms 22 | from tqdm.auto import tqdm 23 | 24 | import diffusers 25 | from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel 26 | from diffusers.optimization import get_scheduler 27 | from diffusers.training_utils import EMAModel 28 | from diffusers.utils import check_min_version, is_accelerate_version, is_tensorboard_available, is_wandb_available 29 | from diffusers.utils.import_utils import is_xformers_available 30 | 31 | 32 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 33 | check_min_version("0.30.0.dev0") 34 | 35 | logger = get_logger(__name__, log_level="INFO") 36 | 37 | 38 | def _extract_into_tensor(arr, timesteps, broadcast_shape): 39 | """ 40 | Extract values from a 1-D numpy array for a batch of indices. 41 | 42 | :param arr: the 1-D numpy array. 43 | :param timesteps: a tensor of indices into the array to extract. 44 | :param broadcast_shape: a larger shape of K dimensions with the batch 45 | dimension equal to the length of timesteps. 46 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 47 | """ 48 | if not isinstance(arr, torch.Tensor): 49 | arr = torch.from_numpy(arr) 50 | res = arr[timesteps].float().to(timesteps.device) 51 | while len(res.shape) < len(broadcast_shape): 52 | res = res[..., None] 53 | return res.expand(broadcast_shape) 54 | 55 | 56 | def parse_args(): 57 | parser = argparse.ArgumentParser( 58 | description="Simple example of a training script.") 59 | parser.add_argument( 60 | "--dataset_name", 61 | type=str, 62 | default=None, 63 | help=( 64 | "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," 65 | " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," 66 | " or to a folder containing files that HF Datasets can understand." 67 | ), 68 | ) 69 | parser.add_argument( 70 | "--dataset_config_name", 71 | type=str, 72 | default=None, 73 | help="The config of the Dataset, leave as None if there's only one config.", 74 | ) 75 | parser.add_argument( 76 | "--model_config_name_or_path", 77 | type=str, 78 | default=None, 79 | help="The config of the UNet model to train, leave as None to use standard DDPM configuration.", 80 | ) 81 | parser.add_argument( 82 | "--train_data_dir", 83 | type=str, 84 | default=None, 85 | help=( 86 | "A folder containing the training data. Folder contents must follow the structure described in" 87 | " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" 88 | " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." 
89 | ), 90 | ) 91 | parser.add_argument( 92 | "--output_dir", 93 | type=str, 94 | default="ddpm-model-64", 95 | help="The output directory where the model predictions and checkpoints will be written.", 96 | ) 97 | parser.add_argument("--overwrite_output_dir", action="store_true") 98 | parser.add_argument( 99 | "--cache_dir", 100 | type=str, 101 | default=None, 102 | help="The directory where the downloaded models and datasets will be stored.", 103 | ) 104 | parser.add_argument( 105 | "--resolution", 106 | type=int, 107 | default=64, 108 | help=( 109 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 110 | " resolution" 111 | ), 112 | ) 113 | parser.add_argument( 114 | "--center_crop", 115 | default=False, 116 | action="store_true", 117 | help=( 118 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly" 119 | " cropped. The images will be resized to the resolution first before cropping." 120 | ), 121 | ) 122 | parser.add_argument( 123 | "--random_flip", 124 | default=False, 125 | action="store_true", 126 | help="whether to randomly flip images horizontally", 127 | ) 128 | parser.add_argument( 129 | "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." 130 | ) 131 | parser.add_argument( 132 | "--eval_batch_size", type=int, default=16, help="The number of images to generate for evaluation." 133 | ) 134 | parser.add_argument( 135 | "--dataloader_num_workers", 136 | type=int, 137 | default=0, 138 | help=( 139 | "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main" 140 | " process." 141 | ), 142 | ) 143 | parser.add_argument("--num_epochs", type=int, default=100) 144 | parser.add_argument("--save_images_epochs", type=int, default=10, 145 | help="How often to save images during training.") 146 | parser.add_argument( 147 | "--save_model_epochs", type=int, default=10, help="How often to save the model during training." 148 | ) 149 | parser.add_argument( 150 | "--gradient_accumulation_steps", 151 | type=int, 152 | default=1, 153 | help="Number of updates steps to accumulate before performing a backward/update pass.", 154 | ) 155 | parser.add_argument( 156 | "--learning_rate", 157 | type=float, 158 | default=1e-4, 159 | help="Initial learning rate (after the potential warmup period) to use.", 160 | ) 161 | parser.add_argument( 162 | "--lr_scheduler", 163 | type=str, 164 | default="cosine", 165 | help=( 166 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 167 | ' "constant", "constant_with_warmup"]' 168 | ), 169 | ) 170 | parser.add_argument( 171 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 172 | ) 173 | parser.add_argument("--adam_beta1", type=float, default=0.95, 174 | help="The beta1 parameter for the Adam optimizer.") 175 | parser.add_argument("--adam_beta2", type=float, default=0.999, 176 | help="The beta2 parameter for the Adam optimizer.") 177 | parser.add_argument( 178 | "--adam_weight_decay", type=float, default=1e-6, help="Weight decay magnitude for the Adam optimizer." 
179 | ) 180 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, 181 | help="Epsilon value for the Adam optimizer.") 182 | parser.add_argument( 183 | "--use_ema", 184 | action="store_true", 185 | help="Whether to use Exponential Moving Average for the final model weights.", 186 | ) 187 | parser.add_argument("--ema_inv_gamma", type=float, default=1.0, 188 | help="The inverse gamma value for the EMA decay.") 189 | parser.add_argument("--ema_power", type=float, default=3 / 4, 190 | help="The power value for the EMA decay.") 191 | parser.add_argument("--ema_max_decay", type=float, default=0.9999, 192 | help="The maximum decay magnitude for EMA.") 193 | parser.add_argument("--push_to_hub", action="store_true", 194 | help="Whether or not to push the model to the Hub.") 195 | parser.add_argument("--hub_token", type=str, default=None, 196 | help="The token to use to push to the Model Hub.") 197 | parser.add_argument( 198 | "--hub_model_id", 199 | type=str, 200 | default=None, 201 | help="The name of the repository to keep in sync with the local `output_dir`.", 202 | ) 203 | parser.add_argument( 204 | "--hub_private_repo", action="store_true", help="Whether or not to create a private repository." 205 | ) 206 | parser.add_argument( 207 | "--logger", 208 | type=str, 209 | default="tensorboard", 210 | choices=["tensorboard", "wandb"], 211 | help=( 212 | "Whether to use [tensorboard](https://www.tensorflow.org/tensorboard) or [wandb](https://www.wandb.ai)" 213 | " for experiment tracking and logging of model metrics and model checkpoints" 214 | ), 215 | ) 216 | parser.add_argument( 217 | "--logging_dir", 218 | type=str, 219 | default="logs", 220 | help=( 221 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 222 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 223 | ), 224 | ) 225 | parser.add_argument("--local_rank", type=int, default=-1, 226 | help="For distributed training: local_rank") 227 | parser.add_argument( 228 | "--mixed_precision", 229 | type=str, 230 | default="no", 231 | choices=["no", "fp16", "bf16"], 232 | help=( 233 | "Whether to use mixed precision. Choose" 234 | "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." 235 | "and an Nvidia Ampere GPU." 236 | ), 237 | ) 238 | parser.add_argument( 239 | "--prediction_type", 240 | type=str, 241 | default="epsilon", 242 | choices=["epsilon", "sample"], 243 | help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.", 244 | ) 245 | parser.add_argument("--ddpm_num_steps", type=int, default=1000) 246 | parser.add_argument("--ddpm_num_inference_steps", type=int, default=1000) 247 | parser.add_argument("--ddpm_beta_schedule", type=str, default="linear") 248 | parser.add_argument( 249 | "--checkpointing_steps", 250 | type=int, 251 | default=500, 252 | help=( 253 | "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" 254 | " training using `--resume_from_checkpoint`." 255 | ), 256 | ) 257 | parser.add_argument( 258 | "--checkpoints_total_limit", 259 | type=int, 260 | default=None, 261 | help=("Max number of checkpoints to store."), 262 | ) 263 | parser.add_argument( 264 | "--resume_from_checkpoint", 265 | type=str, 266 | default=None, 267 | help=( 268 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 269 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 
270 | ), 271 | ) 272 | parser.add_argument( 273 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 274 | ) 275 | 276 | args = parser.parse_args() 277 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 278 | if env_local_rank != -1 and env_local_rank != args.local_rank: 279 | args.local_rank = env_local_rank 280 | 281 | if args.dataset_name is None and args.train_data_dir is None: 282 | raise ValueError( 283 | "You must specify either a dataset name from the hub or a train data directory.") 284 | 285 | return args 286 | 287 | 288 | def main(args): 289 | logging_dir = os.path.join(args.output_dir, args.logging_dir) 290 | accelerator_project_config = ProjectConfiguration( 291 | project_dir=args.output_dir, logging_dir=logging_dir) 292 | 293 | # a big number for high resolution or big dataset 294 | kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=7200)) 295 | accelerator = Accelerator( 296 | gradient_accumulation_steps=args.gradient_accumulation_steps, 297 | mixed_precision=args.mixed_precision, 298 | log_with=args.logger, 299 | project_config=accelerator_project_config, 300 | kwargs_handlers=[kwargs], 301 | ) 302 | 303 | if args.logger == "tensorboard": 304 | if not is_tensorboard_available(): 305 | raise ImportError( 306 | "Make sure to install tensorboard if you want to use it for logging during training.") 307 | 308 | elif args.logger == "wandb": 309 | if not is_wandb_available(): 310 | raise ImportError( 311 | "Make sure to install wandb if you want to use it for logging during training.") 312 | import wandb 313 | 314 | # `accelerate` 0.16.0 will have better support for customized saving 315 | if version.parse(accelerate.__version__) >= version.parse("0.16.0"): 316 | # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format 317 | def save_model_hook(models, weights, output_dir): 318 | if accelerator.is_main_process: 319 | if args.use_ema: 320 | ema_model.save_pretrained( 321 | os.path.join(output_dir, "unet_ema")) 322 | 323 | for i, model in enumerate(models): 324 | model.save_pretrained(os.path.join(output_dir, "unet")) 325 | 326 | # make sure to pop weight so that corresponding model is not saved again 327 | weights.pop() 328 | 329 | def load_model_hook(models, input_dir): 330 | if args.use_ema: 331 | load_model = EMAModel.from_pretrained( 332 | os.path.join(input_dir, "unet_ema"), UNet2DModel) 333 | ema_model.load_state_dict(load_model.state_dict()) 334 | ema_model.to(accelerator.device) 335 | del load_model 336 | 337 | for i in range(len(models)): 338 | # pop models so that they are not loaded again 339 | model = models.pop() 340 | 341 | # load diffusers style into model 342 | load_model = UNet2DModel.from_pretrained( 343 | input_dir, subfolder="unet") 344 | model.register_to_config(**load_model.config) 345 | 346 | model.load_state_dict(load_model.state_dict()) 347 | del load_model 348 | 349 | accelerator.register_save_state_pre_hook(save_model_hook) 350 | accelerator.register_load_state_pre_hook(load_model_hook) 351 | 352 | # Make one log on every process with the configuration for debugging. 
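# Usage sketch (flag values below are illustrative, not prescribed defaults):
#   accelerate launch train_official.py --train_data_dir ./data \
#       --resolution 64 --train_batch_size 16 --num_epochs 100 \
#       --output_dir ddpm-model-64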
353 | logging.basicConfig( 354 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 355 | datefmt="%m/%d/%Y %H:%M:%S", 356 | level=logging.INFO, 357 | ) 358 | logger.info(accelerator.state, main_process_only=False) 359 | if accelerator.is_local_main_process: 360 | datasets.utils.logging.set_verbosity_warning() 361 | diffusers.utils.logging.set_verbosity_info() 362 | else: 363 | datasets.utils.logging.set_verbosity_error() 364 | diffusers.utils.logging.set_verbosity_error() 365 | 366 | # Handle the repository creation 367 | if accelerator.is_main_process: 368 | if args.output_dir is not None: 369 | os.makedirs(args.output_dir, exist_ok=True) 370 | 371 | if args.push_to_hub: 372 | repo_id = create_repo( 373 | repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token 374 | ).repo_id 375 | 376 | # Initialize the model 377 | if args.model_config_name_or_path is None: 378 | model = UNet2DModel( 379 | sample_size=args.resolution, 380 | in_channels=3, 381 | out_channels=3, 382 | layers_per_block=2, 383 | block_out_channels=(128, 128, 256, 256, 512, 512), 384 | down_block_types=( 385 | "DownBlock2D", 386 | "DownBlock2D", 387 | "DownBlock2D", 388 | "DownBlock2D", 389 | "AttnDownBlock2D", 390 | "DownBlock2D", 391 | ), 392 | up_block_types=( 393 | "UpBlock2D", 394 | "AttnUpBlock2D", 395 | "UpBlock2D", 396 | "UpBlock2D", 397 | "UpBlock2D", 398 | "UpBlock2D", 399 | ), 400 | ) 401 | else: 402 | config = UNet2DModel.load_config(args.model_config_name_or_path) 403 | model = UNet2DModel.from_config(config) 404 | 405 | # Create EMA for the model. 406 | if args.use_ema: 407 | ema_model = EMAModel( 408 | model.parameters(), 409 | decay=args.ema_max_decay, 410 | use_ema_warmup=True, 411 | inv_gamma=args.ema_inv_gamma, 412 | power=args.ema_power, 413 | model_cls=UNet2DModel, 414 | model_config=model.config, 415 | ) 416 | 417 | weight_dtype = torch.float32 418 | if accelerator.mixed_precision == "fp16": 419 | weight_dtype = torch.float16 420 | args.mixed_precision = accelerator.mixed_precision 421 | elif accelerator.mixed_precision == "bf16": 422 | weight_dtype = torch.bfloat16 423 | args.mixed_precision = accelerator.mixed_precision 424 | 425 | if args.enable_xformers_memory_efficient_attention: 426 | if is_xformers_available(): 427 | import xformers 428 | 429 | xformers_version = version.parse(xformers.__version__) 430 | if xformers_version == version.parse("0.0.16"): 431 | logger.warning( 432 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 433 | ) 434 | model.enable_xformers_memory_efficient_attention() 435 | else: 436 | raise ValueError( 437 | "xformers is not available. 
Make sure it is installed correctly") 438 | 439 | # Initialize the scheduler 440 | accepts_prediction_type = "prediction_type" in set( 441 | inspect.signature(DDPMScheduler.__init__).parameters.keys()) 442 | if accepts_prediction_type: 443 | noise_scheduler = DDPMScheduler( 444 | num_train_timesteps=args.ddpm_num_steps, 445 | beta_schedule=args.ddpm_beta_schedule, 446 | prediction_type=args.prediction_type, 447 | ) 448 | else: 449 | noise_scheduler = DDPMScheduler( 450 | num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule) 451 | 452 | # Initialize the optimizer 453 | optimizer = torch.optim.AdamW( 454 | model.parameters(), 455 | lr=args.learning_rate, 456 | betas=(args.adam_beta1, args.adam_beta2), 457 | weight_decay=args.adam_weight_decay, 458 | eps=args.adam_epsilon, 459 | ) 460 | 461 | # Get the datasets: you can either provide your own training and evaluation files (see below) 462 | # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). 463 | 464 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently 465 | # download the dataset. 466 | if args.dataset_name is not None: 467 | dataset = load_dataset( 468 | args.dataset_name, 469 | args.dataset_config_name, 470 | cache_dir=args.cache_dir, 471 | split="train", 472 | ) 473 | else: 474 | dataset = load_dataset( 475 | "imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train") 476 | # See more about loading custom images at 477 | # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder 478 | 479 | # Preprocessing the datasets and DataLoaders creation. 480 | augmentations = transforms.Compose( 481 | [ 482 | transforms.Resize( 483 | args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), 484 | transforms.CenterCrop( 485 | args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), 486 | transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), 487 | transforms.ToTensor(), 488 | transforms.Normalize([0.5], [0.5]), 489 | ] 490 | ) 491 | 492 | def transform_images(examples): 493 | images = [augmentations(image.convert("RGB")) 494 | for image in examples["image"]] 495 | return {"input": images} 496 | 497 | logger.info(f"Dataset size: {len(dataset)}") 498 | 499 | dataset.set_transform(transform_images) 500 | train_dataloader = torch.utils.data.DataLoader( 501 | dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers 502 | ) 503 | 504 | # Initialize the learning rate scheduler 505 | lr_scheduler = get_scheduler( 506 | args.lr_scheduler, 507 | optimizer=optimizer, 508 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 509 | num_training_steps=(len(train_dataloader) * args.num_epochs), 510 | ) 511 | 512 | # Prepare everything with our `accelerator`. 513 | model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 514 | model, optimizer, train_dataloader, lr_scheduler 515 | ) 516 | 517 | if args.use_ema: 518 | ema_model.to(accelerator.device) 519 | 520 | # We need to initialize the trackers we use, and also store our configuration. 521 | # The trackers initializes automatically on the main process. 
522 | if accelerator.is_main_process: 523 | run = os.path.split(__file__)[-1].split(".")[0] 524 | accelerator.init_trackers(run) 525 | 526 | total_batch_size = args.train_batch_size * \ 527 | accelerator.num_processes * args.gradient_accumulation_steps 528 | num_update_steps_per_epoch = math.ceil( 529 | len(train_dataloader) / args.gradient_accumulation_steps) 530 | max_train_steps = args.num_epochs * num_update_steps_per_epoch 531 | 532 | logger.info("***** Running training *****") 533 | logger.info(f" Num examples = {len(dataset)}") 534 | logger.info(f" Num Epochs = {args.num_epochs}") 535 | logger.info( 536 | f" Instantaneous batch size per device = {args.train_batch_size}") 537 | logger.info( 538 | f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 539 | logger.info( 540 | f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 541 | logger.info(f" Total optimization steps = {max_train_steps}") 542 | 543 | global_step = 0 544 | first_epoch = 0 545 | 546 | # Potentially load in the weights and states from a previous save 547 | if args.resume_from_checkpoint: 548 | if args.resume_from_checkpoint != "latest": 549 | path = os.path.basename(args.resume_from_checkpoint) 550 | else: 551 | # Get the most recent checkpoint 552 | dirs = os.listdir(args.output_dir) 553 | dirs = [d for d in dirs if d.startswith("checkpoint")] 554 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 555 | path = dirs[-1] if len(dirs) > 0 else None 556 | 557 | if path is None: 558 | accelerator.print( 559 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 560 | ) 561 | args.resume_from_checkpoint = None 562 | else: 563 | accelerator.print(f"Resuming from checkpoint {path}") 564 | accelerator.load_state(os.path.join(args.output_dir, path)) 565 | global_step = int(path.split("-")[1]) 566 | 567 | resume_global_step = global_step * args.gradient_accumulation_steps 568 | first_epoch = global_step // num_update_steps_per_epoch 569 | resume_step = resume_global_step % ( 570 | num_update_steps_per_epoch * args.gradient_accumulation_steps) 571 | 572 | # Train! 
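# This loop is essentially the same as the one in train_0.py; only the source
# of the hyperparameters differs (argparse flags here, a config file there).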
573 | for epoch in range(first_epoch, args.num_epochs): 574 | model.train() 575 | progress_bar = tqdm(total=num_update_steps_per_epoch, 576 | disable=not accelerator.is_local_main_process) 577 | progress_bar.set_description(f"Epoch {epoch}") 578 | for step, batch in enumerate(train_dataloader): 579 | # Skip steps until we reach the resumed step 580 | if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: 581 | if step % args.gradient_accumulation_steps == 0: 582 | progress_bar.update(1) 583 | continue 584 | 585 | clean_images = batch["input"].to(weight_dtype) 586 | # Sample noise that we'll add to the images 587 | noise = torch.randn(clean_images.shape, 588 | dtype=weight_dtype, device=clean_images.device) 589 | bsz = clean_images.shape[0] 590 | # Sample a random timestep for each image 591 | timesteps = torch.randint( 592 | 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=clean_images.device 593 | ).long() 594 | 595 | # Add noise to the clean images according to the noise magnitude at each timestep 596 | # (this is the forward diffusion process) 597 | noisy_images = noise_scheduler.add_noise( 598 | clean_images, noise, timesteps) 599 | 600 | with accelerator.accumulate(model): 601 | # Predict the noise residual 602 | model_output = model(noisy_images, timesteps).sample 603 | 604 | if args.prediction_type == "epsilon": 605 | # this could have different weights! 606 | loss = F.mse_loss(model_output.float(), noise.float()) 607 | elif args.prediction_type == "sample": 608 | alpha_t = _extract_into_tensor( 609 | noise_scheduler.alphas_cumprod, timesteps, ( 610 | clean_images.shape[0], 1, 1, 1) 611 | ) 612 | snr_weights = alpha_t / (1 - alpha_t) 613 | # use SNR weighting from distillation paper 614 | loss = snr_weights * \ 615 | F.mse_loss(model_output.float(), 616 | clean_images.float(), reduction="none") 617 | loss = loss.mean() 618 | else: 619 | raise ValueError( 620 | f"Unsupported prediction type: {args.prediction_type}") 621 | 622 | accelerator.backward(loss) 623 | 624 | if accelerator.sync_gradients: 625 | accelerator.clip_grad_norm_(model.parameters(), 1.0) 626 | optimizer.step() 627 | lr_scheduler.step() 628 | optimizer.zero_grad() 629 | 630 | # Checks if the accelerator has performed an optimization step behind the scenes 631 | if accelerator.sync_gradients: 632 | if args.use_ema: 633 | ema_model.step(model.parameters()) 634 | progress_bar.update(1) 635 | global_step += 1 636 | 637 | if accelerator.is_main_process: 638 | if global_step % args.checkpointing_steps == 0: 639 | # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` 640 | if args.checkpoints_total_limit is not None: 641 | checkpoints = os.listdir(args.output_dir) 642 | checkpoints = [ 643 | d for d in checkpoints if d.startswith("checkpoint")] 644 | checkpoints = sorted( 645 | checkpoints, key=lambda x: int(x.split("-")[1])) 646 | 647 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints 648 | if len(checkpoints) >= args.checkpoints_total_limit: 649 | num_to_remove = len( 650 | checkpoints) - args.checkpoints_total_limit + 1 651 | removing_checkpoints = checkpoints[0:num_to_remove] 652 | 653 | logger.info( 654 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" 655 | ) 656 | logger.info( 657 | f"removing checkpoints: {', '.join(removing_checkpoints)}") 658 | 659 | for removing_checkpoint in removing_checkpoints: 660 | removing_checkpoint = 
os.path.join( 661 | args.output_dir, removing_checkpoint) 662 | shutil.rmtree(removing_checkpoint) 663 | 664 | save_path = os.path.join( 665 | args.output_dir, f"checkpoint-{global_step}") 666 | accelerator.save_state(save_path) 667 | logger.info(f"Saved state to {save_path}") 668 | 669 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[ 670 | 0], "step": global_step} 671 | if args.use_ema: 672 | logs["ema_decay"] = ema_model.cur_decay_value 673 | progress_bar.set_postfix(**logs) 674 | accelerator.log(logs, step=global_step) 675 | progress_bar.close() 676 | 677 | accelerator.wait_for_everyone() 678 | 679 | # Generate sample images for visual inspection 680 | if accelerator.is_main_process: 681 | if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1: 682 | unet = accelerator.unwrap_model(model) 683 | 684 | if args.use_ema: 685 | ema_model.store(unet.parameters()) 686 | ema_model.copy_to(unet.parameters()) 687 | 688 | pipeline = DDPMPipeline( 689 | unet=unet, 690 | scheduler=noise_scheduler, 691 | ) 692 | 693 | generator = torch.Generator( 694 | device=pipeline.device).manual_seed(0) 695 | # run pipeline in inference (sample random noise and denoise) 696 | images = pipeline( 697 | generator=generator, 698 | batch_size=args.eval_batch_size, 699 | num_inference_steps=args.ddpm_num_inference_steps, 700 | output_type="np", 701 | ).images 702 | 703 | if args.use_ema: 704 | ema_model.restore(unet.parameters()) 705 | 706 | # denormalize the images and save to tensorboard 707 | images_processed = (images * 255).round().astype("uint8") 708 | 709 | if args.logger == "tensorboard": 710 | if is_accelerate_version(">=", "0.17.0.dev0"): 711 | tracker = accelerator.get_tracker( 712 | "tensorboard", unwrap=True) 713 | else: 714 | tracker = accelerator.get_tracker("tensorboard") 715 | tracker.add_images( 716 | "test_samples", images_processed.transpose(0, 3, 1, 2), epoch) 717 | elif args.logger == "wandb": 718 | # Upcoming `log_images` helper coming in https://github.com/huggingface/accelerate/pull/962/files 719 | accelerator.get_tracker("wandb").log( 720 | {"test_samples": [wandb.Image( 721 | img) for img in images_processed], "epoch": epoch}, 722 | step=global_step, 723 | ) 724 | 725 | if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1: 726 | # save the model 727 | unet = accelerator.unwrap_model(model) 728 | 729 | if args.use_ema: 730 | ema_model.store(unet.parameters()) 731 | ema_model.copy_to(unet.parameters()) 732 | 733 | pipeline = DDPMPipeline( 734 | unet=unet, 735 | scheduler=noise_scheduler, 736 | ) 737 | 738 | pipeline.save_pretrained(args.output_dir) 739 | 740 | if args.use_ema: 741 | ema_model.restore(unet.parameters()) 742 | 743 | if args.push_to_hub: 744 | upload_folder( 745 | repo_id=repo_id, 746 | folder_path=args.output_dir, 747 | commit_message=f"Epoch {epoch}", 748 | ignore_patterns=["step_*", "epoch_*"], 749 | ) 750 | 751 | accelerator.end_training() 752 | 753 | 754 | if __name__ == "__main__": 755 | args = parse_args() 756 | main(args) 757 | -------------------------------------------------------------------------------- /TrainingScript/trainer.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | 4 | class Trainer(metaclass=ABCMeta): 5 | def __init__(self, weight_dtype, accelerator, logger, cfg): 6 | self.weight_dtype = weight_dtype 7 | self.accelerator = accelerator 8 | self.logger = logger 9 | self.cfg = cfg 10 | 11 | @abstractmethod 12 | def 
init_modules(self,
13 |                      enable_xformer: bool = False,
14 |                      gradient_checkpointing: bool = False):
15 |         pass
16 |
17 |     @abstractmethod
18 |     def init_optimizers(self, train_batch_size):
19 |         pass
20 |
21 |     @abstractmethod
22 |     def init_lr_schedulers(self, gradient_accumulation_steps, num_epochs):
23 |         pass
24 |
25 |     def set_dataset(self, dataset, train_dataloader):
26 |         self.dataset = dataset
27 |         self.train_dataloader = train_dataloader
28 |
29 |     @abstractmethod
30 |     def prepare_modules(self):
31 |         pass
32 |
33 |     @abstractmethod
34 |     def models_to_train(self):
35 |         pass
36 |
37 |     @abstractmethod
38 |     def training_step(self, global_step, batch) -> dict:
39 |         pass
40 |
41 |     @abstractmethod
42 |     def validate(self, epoch, global_step):
43 |         pass
44 |
45 |     @abstractmethod
46 |     def save_pipeline(self):
47 |         pass
48 |
49 |     @abstractmethod
50 |     def save_model_hook(self, models, weights, output_dir):
51 |         pass
52 |
53 |     @abstractmethod
54 |     def load_model_hook(self, models, input_dir):
55 |         pass
56 |
57 |
58 | def create_trainer(type, weight_dtype, accelerator, logger, cfg_dict) -> Trainer:
59 |     from ddpm_trainer import DDPMTrainer
60 |     from sd_lora_trainer import LoraTrainer
61 |
62 |     __TYPE_CLS_DICT = {
63 |         'ddpm': DDPMTrainer,
64 |         'lora': LoraTrainer
65 |     }
66 |
67 |     return __TYPE_CLS_DICT[type](weight_dtype, accelerator, logger, cfg_dict)
68 |
--------------------------------------------------------------------------------
/TrainingScript/training_cfg_0.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 |
4 | @dataclass
5 | class BaseTrainingConfig:
6 |     # Dir
7 |     logging_dir: str
8 |     output_dir: str
9 |
10 |     # Logger and checkpoint
11 |     logger: str = 'tensorboard'
12 |     checkpointing_steps: int = 500
13 |     checkpoints_total_limit: int = 20
14 |     valid_epochs: int = 100
15 |     valid_batch_size: int = 1
16 |     save_model_epochs: int = 100
17 |     resume_from_checkpoint: str = None
18 |
19 |     # Diffusion Models
20 |     model_config: str = None
21 |     ddpm_num_steps: int = 1000
22 |     ddpm_beta_schedule: str = 'linear'
23 |     prediction_type: str = 'epsilon'
24 |     ddpm_num_inference_steps: int = 100
25 |
26 |     # Training
27 |     seed: int = None
28 |     num_epochs: int = 200
29 |     train_batch_size: int = 1
30 |     dataloader_num_workers: int = 1
31 |     gradient_accumulation_steps: int = 1
32 |     mixed_precision: str = None
33 |     enable_xformers_memory_efficient_attention: bool = True
34 |     gradient_checkpointing: bool = False
35 |
36 |     # Dataset
37 |     dataset_name: str = None
38 |     dataset_config_name: str = None
39 |     train_data_dir: str = None
40 |     cache_dir: str = None
41 |     resolution: int = 512
42 |     center_crop: bool = False
43 |     random_flip: bool = False
44 |
45 |     # LR Scheduler
46 |     lr_scheduler: str = 'constant'
47 |     lr_warmup_steps: int = 500
48 |
49 |     # AdamW
50 |     scale_lr: bool = False
51 |     learning_rate: float = 1e-4
52 |     adam_beta1: float = 0.9
53 |     adam_beta2: float = 0.999
54 |     adam_weight_decay: float = 1e-2
55 |     adam_epsilon: float = 1e-08
56 |
57 |     # EMA
58 |     use_ema: bool = False
59 |     ema_max_decay: float = 0.9999
60 |     ema_inv_gamma: float = 1.0
61 |     ema_power: float = 3 / 4
62 |
63 |     # Hub
64 |     push_to_hub: bool = False
65 |     hub_model_id: str = ''
66 |
--------------------------------------------------------------------------------
/TrainingScript/training_cfg_1.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | from dataclasses import dataclass
3 | from omegaconf import OmegaConf
4 |
5 | from ddpm_trainer import DDPMTrainingConfig
6 | from sd_lora_trainer import LoraTrainingConfig
7 |
8 |
9 | @dataclass
10 | class BaseTrainingConfig:
11 |     # Dir
12 |     logging_dir: str
13 |     output_dir: str
14 |
15 |     # Logger and checkpoint
16 |     logger: str = 'tensorboard'
17 |     checkpointing_steps: int = 500
18 |     checkpoints_total_limit: int = 20
19 |     valid_epochs: int = 100
20 |     save_model_epochs: int = 100
21 |     resume_from_checkpoint: str = None
22 |
23 |     # Training
24 |     seed: int = None
25 |     num_epochs: int = 200
26 |     train_batch_size: int = 1
27 |     dataloader_num_workers: int = 1
28 |     gradient_accumulation_steps: int = 1
29 |     mixed_precision: str = None
30 |     enable_xformers_memory_efficient_attention: bool = True
31 |     gradient_checkpointing: bool = False
32 |
33 |     # Dataset
34 |     dataset_name: str = None
35 |     dataset_config_name: str = None
36 |     train_data_dir: str = None
37 |     cache_dir: str = None
38 |     resolution: int = 512
39 |     center_crop: bool = True
40 |     random_flip: bool = False
41 |
42 |     # Hub
43 |     push_to_hub: bool = False
44 |     hub_model_id: str = ''
45 |
46 |
47 | __TYPE_CLS_DICT = {
48 |     'base': BaseTrainingConfig,
49 |     'ddpm': DDPMTrainingConfig,
50 |     'lora': LoraTrainingConfig
51 | }
52 |
53 |
54 | def load_training_config(config_path: str) -> Dict[str, BaseTrainingConfig]:
55 |     data_dict = OmegaConf.load(config_path)
56 |
57 |     # The config must have a "base" key
58 |     base_cfg_dict = data_dict.pop('base')
59 |
60 |     # Besides "base", the config must contain exactly one model-specific section
61 |     assert len(data_dict) == 1
62 |     model_key = next(iter(data_dict))
63 |     model_cfg_dict = data_dict[model_key]
64 |     model_cfg_cls = __TYPE_CLS_DICT[model_key]
65 |
66 |     return {'base': BaseTrainingConfig(**base_cfg_dict),
67 |             model_key: model_cfg_cls(**model_cfg_dict)}
68 |
--------------------------------------------------------------------------------
/TrainingScript/unet_cfg/config.json:
--------------------------------------------------------------------------------
1 | {
2 |   "_class_name": "UNet2DModel",
3 |
4 |   "block_out_channels": [
5 |     64,
6 |     128,
7 |     256,
8 |     256
9 |   ],
10 |   "down_block_types": [
11 |     "DownBlock2D",
12 |     "DownBlock2D",
13 |     "DownBlock2D",
14 |     "DownBlock2D"
15 |   ],
16 |   "in_channels": 3,
17 |   "out_channels": 3,
18 |   "layers_per_block": 2,
19 |   "sample_size": 32,
20 |   "up_block_types": [
21 |     "UpBlock2D",
22 |     "UpBlock2D",
23 |     "UpBlock2D",
24 |     "UpBlock2D"
25 |   ]
26 | }
27 |
--------------------------------------------------------------------------------
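
Usage sketch (not one of the repository files above): unet_cfg/config.json describes a small UNet2DModel with 3-channel, 32x32 samples and four plain down/up blocks. Assuming the working directory is TrainingScript, such a config could be turned into a model with diffusers' ConfigMixin helpers roughly like this:

from diffusers import UNet2DModel

# load_config reads unet_cfg/config.json into a dict; from_config builds the model from it.
unet_config = UNet2DModel.load_config("unet_cfg")
unet = UNet2DModel.from_config(unet_config)

# Matches the JSON above: 32x32 sample size, 3 input channels.
print(unet.config.sample_size, unet.config.in_channels)

The model_config field in BaseTrainingConfig appears intended to point at a directory like unet_cfg, but the exact wiring lives in the training scripts themselves.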
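
A second sketch (also not a repository file): load_training_config in training_cfg_1.py expects a config with a 'base' section plus exactly one model-specific section ('ddpm' or 'lora'), and create_trainer in trainer.py maps that key to the matching Trainer subclass. A hypothetical driver, assuming cfg_1.json follows that layout and that the concrete trainers accept the parsed config object directly:

import torch
from accelerate import Accelerator
from accelerate.logging import get_logger

from training_cfg_1 import load_training_config
from trainer import create_trainer

accelerator = Accelerator()
logger = get_logger(__name__)

# cfg_1.json is assumed to hold a 'base' section plus exactly one of 'ddpm' / 'lora'.
cfgs = load_training_config("cfg_1.json")
base_cfg = cfgs.pop('base')
model_key, model_cfg = next(iter(cfgs.items()))

# Build the trainer for whichever model section the config declared.
trainer = create_trainer(model_key, torch.float32, accelerator, logger, model_cfg)
trainer.init_modules(
    enable_xformer=base_cfg.enable_xformers_memory_efficient_attention,
    gradient_checkpointing=base_cfg.gradient_checkpointing)

The actual entry point is presumably train_1.py; this sketch only illustrates the contract between the config loader and the trainer factory.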