├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── T-shirt
│   │   ├── .DS_Store
│   │   ├── gap-2697-8350236-1.jpg
│   │   ├── gap-2698-8350236-2.jpg
│   │   ├── gap-2699-8350236-3.jpg
│   │   └── gap-2699-8350236-4.jpg
│   ├── T-shirt_test
│   │   ├── 5e46f9dd2dae5c4d4770c214.jpg
│   │   └── T-shirt.jpg
│   ├── sofa
│   │   ├── 53bb82effa16547cea05419fd84edf90_1800x1800.jpg
│   │   ├── 94e823f11f4f841adfddf2fe2dc87b02_1800x1800.jpg
│   │   ├── d157e324510ef44ede37a9dbea53bf30_1800x1800.jpg
│   │   └── f6a0ea9a009f47a59d00aeec56378921_1800x1800.jpg
│   └── sofa_test
│       ├── H-5168-12_8c8cc93e-163e-47a6-b8d0-4d3253f0b86b_900x.jpg
│       └── XE3KtWoY2TgU7r2u2AJkBJ.jpg
├── docs
│   ├── image_composition-1.png
│   └── image_composition-2.png
├── inference.py
├── inference_lora.py
├── preprocess.py
├── requirements.txt
├── run.sh
├── run_lora.sh
├── train.py
└── train_lora.py

/.gitignore:
--------------------------------------------------------------------------------
__pycache__/
logs
data/*/*.png
data/*/*.txt
out
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 RealityEdior

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Image Composition via Stable Diffusion

### We achieve image composition via the Stable Diffusion model. Applications include Virtual Clothes / Furniture Try-on.

Demo 1: Virtual Clothes Try-on (demo figure: docs/image_composition-1.png)

---

Demo 2: Virtual Furniture Try-on (demo figure: docs/image_composition-2.png)

### Installation
* Requirements
```bash
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
pip install -r requirements.txt
```

* Initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with
```bash
accelerate config default
```

* Run the following command to authenticate your Hugging Face token

```bash
huggingface-cli login
```

### 0. Prepare Images
Please provide at least one image in .jpg format and an instance prompt.
For example, see the images in ./data/sofa.

### 1. Set Environment
```bash
export MODEL_NAME="runwayml/stable-diffusion-inpainting"
export INSTANCE_DIR="data/sofa"
export Test_DIR="data/sofa_test"
export MODEL_DIR="logs/sofa"
export OUT_DIR="out/sofa"
export INSTANCE_PROMPT="sofa"
```

### 2. Preprocess Images
Given the instance images and the instance prompt, the preprocess.py script generates captions and instance masks; each mask is saved as a .png file next to the corresponding .jpg image.

```bash
python preprocess.py --instance_data_dir $INSTANCE_DIR \
                     --instance_prompt $INSTANCE_PROMPT
```

### 3. Finetune
We then embed the instance images and the instance prompt into the Stable Diffusion inpainting model.

```bash
accelerate launch --num_processes 1 train.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$MODEL_DIR \
  --instance_prompt=$INSTANCE_PROMPT \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --learning_rate=5e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=1000
```

### 4. Image Composition
Finally, provide new images to perform image composition.

```bash
python inference.py --image_path $Test_DIR \
                    --model_path $MODEL_DIR \
                    --out_path $OUT_DIR \
                    --instance_prompt $INSTANCE_PROMPT
```

### Or else
Use the end-to-end run.sh script.

```bash
bash run.sh
```

### GPU Memory
We tested the code on an RTX 3090 GPU.
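If you hit an out-of-memory (OOM) error, note that train.py already exposes several memory-saving options (--gradient_checkpointing, --use_8bit_adam, and --mixed_precision). One possible lower-memory invocation of the fine-tuning step (a sketch only; it assumes bitsandbytes is installed for the 8-bit optimizer):

```bash
# Same command as step 3, with memory-saving flags enabled.
accelerate launch --num_processes 1 train.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$MODEL_DIR \
  --instance_prompt=$INSTANCE_PROMPT \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --learning_rate=5e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=1000 \
  --gradient_checkpointing \
  --use_8bit_adam \
  --mixed_precision=fp16
```

The LoRA variant (train_lora.py, driven by run_lora.sh) fine-tunes only lightweight attention adapters and is another way to reduce memory usage.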
For further low-memory training tips, please refer to the official DreamBooth example:

* [dreambooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth)

### Authors:
* [Tao Hu](https://tau-yihouxiang.github.io)
* [RealityEditor](https://realityeditor.com.cn)

--------------------------------------------------------------------------------
/data/T-shirt/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Reality-Editor/Composition-Stable-Diffusion/d3b0a6b4b44c987891354a3e5f53e68a2880388a/data/T-shirt/.DS_Store
--------------------------------------------------------------------------------
/data/T-shirt/gap-2697-8350236-1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Reality-Editor/Composition-Stable-Diffusion/d3b0a6b4b44c987891354a3e5f53e68a2880388a/data/T-shirt/gap-2697-8350236-1.jpg
--------------------------------------------------------------------------------
/data/T-shirt/gap-2698-8350236-2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Reality-Editor/Composition-Stable-Diffusion/d3b0a6b4b44c987891354a3e5f53e68a2880388a/data/T-shirt/gap-2698-8350236-2.jpg
--------------------------------------------------------------------------------
/data/T-shirt/gap-2699-8350236-3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Reality-Editor/Composition-Stable-Diffusion/d3b0a6b4b44c987891354a3e5f53e68a2880388a/data/T-shirt/gap-2699-8350236-3.jpg
--------------------------------------------------------------------------------
/data/T-shirt/gap-2699-8350236-4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Reality-Editor/Composition-Stable-Diffusion/d3b0a6b4b44c987891354a3e5f53e68a2880388a/data/T-shirt/gap-2699-8350236-4.jpg
--------------------------------------------------------------------------------
/data/T-shirt_test/5e46f9dd2dae5c4d4770c214.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Reality-Editor/Composition-Stable-Diffusion/d3b0a6b4b44c987891354a3e5f53e68a2880388a/data/T-shirt_test/5e46f9dd2dae5c4d4770c214.jpg
--------------------------------------------------------------------------------
/data/T-shirt_test/T-shirt.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Reality-Editor/Composition-Stable-Diffusion/d3b0a6b4b44c987891354a3e5f53e68a2880388a/data/T-shirt_test/T-shirt.jpg
--------------------------------------------------------------------------------
/data/sofa/53bb82effa16547cea05419fd84edf90_1800x1800.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Reality-Editor/Composition-Stable-Diffusion/d3b0a6b4b44c987891354a3e5f53e68a2880388a/data/sofa/53bb82effa16547cea05419fd84edf90_1800x1800.jpg
--------------------------------------------------------------------------------
/data/sofa/94e823f11f4f841adfddf2fe2dc87b02_1800x1800.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Reality-Editor/Composition-Stable-Diffusion/d3b0a6b4b44c987891354a3e5f53e68a2880388a/data/sofa/94e823f11f4f841adfddf2fe2dc87b02_1800x1800.jpg
--------------------------------------------------------------------------------
/data/sofa/d157e324510ef44ede37a9dbea53bf30_1800x1800.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Reality-Editor/Composition-Stable-Diffusion/d3b0a6b4b44c987891354a3e5f53e68a2880388a/data/sofa/d157e324510ef44ede37a9dbea53bf30_1800x1800.jpg
--------------------------------------------------------------------------------
/data/sofa/f6a0ea9a009f47a59d00aeec56378921_1800x1800.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Reality-Editor/Composition-Stable-Diffusion/d3b0a6b4b44c987891354a3e5f53e68a2880388a/data/sofa/f6a0ea9a009f47a59d00aeec56378921_1800x1800.jpg
--------------------------------------------------------------------------------
/data/sofa_test/H-5168-12_8c8cc93e-163e-47a6-b8d0-4d3253f0b86b_900x.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Reality-Editor/Composition-Stable-Diffusion/d3b0a6b4b44c987891354a3e5f53e68a2880388a/data/sofa_test/H-5168-12_8c8cc93e-163e-47a6-b8d0-4d3253f0b86b_900x.jpg
--------------------------------------------------------------------------------
/data/sofa_test/XE3KtWoY2TgU7r2u2AJkBJ.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Reality-Editor/Composition-Stable-Diffusion/d3b0a6b4b44c987891354a3e5f53e68a2880388a/data/sofa_test/XE3KtWoY2TgU7r2u2AJkBJ.jpg
--------------------------------------------------------------------------------
/docs/image_composition-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Reality-Editor/Composition-Stable-Diffusion/d3b0a6b4b44c987891354a3e5f53e68a2880388a/docs/image_composition-1.png
--------------------------------------------------------------------------------
/docs/image_composition-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Reality-Editor/Composition-Stable-Diffusion/d3b0a6b4b44c987891354a3e5f53e68a2880388a/docs/image_composition-2.png
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | from PIL import Image, ImageDraw, ImageFilter 2 | import requests 3 | import numpy as np 4 | import glob, os 5 | import torch 6 | import argparse 7 | from torchvision import transforms 8 | from diffusers import StableDiffusionInpaintPipeline, DDPMScheduler 9 | from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser(description="Simple example of inference.") 13 | parser.add_argument( 14 | "--image_path", 15 | type=str, 16 | required=True, 17 | help="Path to source directory.", 18 | ) 19 | 20 | parser.add_argument( 21 | "--instance_prompt", 22 | type=str, 23 | required=True, 24 | help="Target object to be composed.", 25 | ) 26 | 27 | parser.add_argument( 28 | "--out_path", 29 | type=str, 30 | required=True, 31 | help="Path to output directory.", 32 | ) 33 | 34 | parser.add_argument( 35 | "--model_path", 36 | type=str, 37 |
required=True, 38 | help="Path to destinate directory.", 39 | ) 40 | 41 | args = parser.parse_args() 42 | 43 | return args 44 | 45 | 46 | if __name__ == '__main__': 47 | args = parse_args() 48 | device = "cuda" if torch.cuda.is_available() else "cpu" 49 | 50 | pipe = StableDiffusionInpaintPipeline.from_pretrained(args.model_path) 51 | pipe = pipe.to(device) 52 | 53 | os.makedirs(f'{args.out_path}', exist_ok=True) 54 | 55 | if os.path.isdir(args.image_path): 56 | img_paths = glob.glob(os.path.join(args.image_path, '*.jpg')) 57 | img_paths.sort() 58 | else: 59 | img_paths = [args.image_path] 60 | 61 | # clipseg for image segmentation 62 | processor_clipseg = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined") 63 | model_clipseg = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined") 64 | model_clipseg.to(device) 65 | 66 | for img_path in img_paths: 67 | init_image = Image.open(img_path).convert("RGB") 68 | init_size = init_image.size 69 | init_image = init_image.resize((512, 512)) 70 | 71 | inputs_clipseg = processor_clipseg(text=[args.instance_prompt], images=[init_image], padding="max_length", return_tensors="pt").to(device) 72 | outputs = model_clipseg(**inputs_clipseg) 73 | preds = outputs.logits.unsqueeze(0)[0].detach().cpu() 74 | mask_image = transforms.ToPILImage()(torch.sigmoid(preds)).convert("L").resize((512, 512)) 75 | mask_image = mask_image.filter(ImageFilter.MaxFilter(21)) 76 | 77 | prompt = f"a photo with sks {args.instance_prompt}" 78 | image = pipe(prompt=prompt, image=init_image, 79 | mask_image=mask_image 80 | ).images[0] 81 | 82 | cat_image = Image.new('RGB', (512 * 2, 512)) 83 | 84 | masked_image = Image.composite(mask_image, init_image, mask_image) 85 | cat_image.paste(init_image, (512*0, 0)) 86 | cat_image.paste(image, (512*1, 0)) 87 | 88 | cat_image.save(f"{args.out_path}/{os.path.basename(img_path)}") 89 | -------------------------------------------------------------------------------- /inference_lora.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageDraw, ImageFilter 2 | import requests 3 | import numpy as np 4 | import glob, os 5 | import torch 6 | import argparse 7 | from torchvision import transforms 8 | from diffusers import StableDiffusionInpaintPipeline, DDPMScheduler 9 | from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser(description="Simple example of preprocessing daa.") 13 | parser.add_argument( 14 | "--image_path", 15 | type=str, 16 | required=True, 17 | help="Path to source directory.", 18 | ) 19 | 20 | parser.add_argument( 21 | "--instance_prompt", 22 | type=str, 23 | required=True, 24 | help="Path to output directory.", 25 | ) 26 | 27 | parser.add_argument( 28 | "--out_path", 29 | type=str, 30 | required=True, 31 | help="Path to output directory.", 32 | ) 33 | 34 | parser.add_argument( 35 | "--model_path", 36 | type=str, 37 | required=True, 38 | help="Path to destinate directory.", 39 | ) 40 | 41 | args = parser.parse_args() 42 | 43 | return args 44 | 45 | 46 | if __name__ == '__main__': 47 | args = parse_args() 48 | device = "cuda" if torch.cuda.is_available() else "cpu" 49 | 50 | model_name = "runwayml/stable-diffusion-inpainting" 51 | pipe = StableDiffusionInpaintPipeline.from_pretrained(model_name) 52 | pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config) 53 | pipe.unet.load_attn_procs(args.model_path) 54 | pipe = pipe.to(device) 55 | 56 | 
os.makedirs(f'{args.out_path}', exist_ok=True) 57 | 58 | if os.path.isdir(args.image_path): 59 | img_paths = glob.glob(os.path.join(args.image_path, '*.jpg')) 60 | img_paths.sort() 61 | else: 62 | img_paths = [args.image_path] 63 | 64 | # clipseg for image segmentation 65 | processor_clipseg = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined") 66 | model_clipseg = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined") 67 | model_clipseg.to(device) 68 | 69 | for img_path in img_paths: 70 | init_image = Image.open(img_path).convert("RGB") 71 | init_size = init_image.size 72 | init_image = init_image.resize((512, 512)) 73 | 74 | inputs_clipseg = processor_clipseg(text=[args.instance_prompt], images=[init_image], padding="max_length", return_tensors="pt").to(device) 75 | outputs = model_clipseg(**inputs_clipseg) 76 | preds = outputs.logits.unsqueeze(0)[0].detach().cpu() 77 | mask_image = transforms.ToPILImage()(torch.sigmoid(preds)).convert("L").resize((512, 512)) 78 | mask_image = mask_image.filter(ImageFilter.MaxFilter(11)) 79 | 80 | prompt = f"a photo with sks {args.instance_prompt}" 81 | image = pipe(prompt=prompt, image=init_image, 82 | mask_image=mask_image 83 | ).images[0] 84 | 85 | cat_image = Image.new('RGB', (512 * 2, 512)) 86 | 87 | masked_image = Image.composite(mask_image, init_image, mask_image) 88 | cat_image.paste(init_image, (512*0, 0)) 89 | cat_image.paste(image, (512*1, 0)) 90 | 91 | cat_image.save(f"{args.out_path}/{os.path.basename(img_path)}") 92 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation 3 | # from transformers import Blip2Processor, Blip2ForConditionalGeneration 4 | from torchvision import transforms 5 | import torch 6 | import glob, os, tqdm 7 | import argparse 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser(description="Simple example of preprocessing daa.") 11 | parser.add_argument( 12 | "--instance_data_dir", 13 | type=str, 14 | required=True, 15 | help="Path to source directory.", 16 | ) 17 | 18 | parser.add_argument( 19 | "--instance_prompt", 20 | type=str, 21 | required=True, 22 | help="target object to be composed.", 23 | ) 24 | 25 | args = parser.parse_args() 26 | 27 | return args 28 | 29 | if __name__ == '__main__': 30 | args = parse_args() 31 | 32 | device = "cuda" if torch.cuda.is_available() else "cpu" 33 | 34 | # # blip2 for image caption 35 | # processor_blip2 = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b") 36 | # model_blip2 = Blip2ForConditionalGeneration.from_pretrained( 37 | # "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16 38 | # ) 39 | # model_blip2.to(device) 40 | 41 | # clipseg for image segmentation 42 | processor_clipseg = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined") 43 | model_clipseg = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined") 44 | model_clipseg.to(device) 45 | 46 | img_files = glob.glob(os.path.join(args.instance_data_dir, "*.jpg")) 47 | 48 | for img_file in tqdm.tqdm(img_files): 49 | prompt_path = img_file[:-4] + '.txt' 50 | 51 | image = Image.open(img_file).convert("RGB") 52 | 53 | # # blip2 54 | # inputs_blip2 = processor_blip2(images=image, return_tensors="pt").to(device, torch.float16) 55 | # generated_ids = model_blip2.generate(**inputs_blip2) 56 | # generated_text = 
processor_blip2.batch_decode(generated_ids, skip_special_tokens=True)[0].strip() 57 | # with open(prompt_path, 'w') as f: 58 | # f.write(generated_text) 59 | 60 | # clipseg 61 | inputs_clipseg = processor_clipseg(text=[args.instance_prompt], images=[image], padding="max_length", return_tensors="pt").to('cuda') 62 | outputs = model_clipseg(**inputs_clipseg) 63 | preds = outputs.logits.unsqueeze(0)[0].detach().cpu() 64 | mask = transforms.ToPILImage()(torch.sigmoid(preds)).convert("L") 65 | mask.save(img_file[:-4] + '.png') 66 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers==0.14.0 2 | accelerate==0.16.0 3 | transformers 4 | ftfy 5 | tensorboard==1.15.0 6 | Jinja2==3.1.2 7 | chardet 8 | protobuf>=3.20.2 9 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | export MODEL_NAME="runwayml/stable-diffusion-inpainting" 2 | export INSTANCE_DIR="data/sofa" 3 | export Test_DIR="data/sofa_test" 4 | export OUT_DIR="out/sofa" 5 | export INSTANCE_PROMPT="sofa" 6 | export MODEL_DIR="logs/sofa" 7 | 8 | # preprocess data 9 | python preprocess.py --instance_data_dir $INSTANCE_DIR \ 10 | --instance_prompt $INSTANCE_PROMPT 11 | 12 | # CUDA_VISIBLE_DEVICES=0 13 | accelerate launch --num_processes 1 train.py \ 14 | --pretrained_model_name_or_path=$MODEL_NAME \ 15 | --instance_data_dir=$INSTANCE_DIR \ 16 | --output_dir=$MODEL_DIR \ 17 | --instance_prompt=$INSTANCE_PROMPT \ 18 | --resolution=512 \ 19 | --train_batch_size=1 \ 20 | --gradient_accumulation_steps=1 \ 21 | --learning_rate=5e-6 \ 22 | --lr_scheduler="constant" \ 23 | --lr_warmup_steps=0 \ 24 | --max_train_steps=1000 25 | 26 | python inference.py --image_path $Test_DIR \ 27 | --model_path $MODEL_DIR \ 28 | --out_path $OUT_DIR \ 29 | --instance_prompt $INSTANCE_PROMPT -------------------------------------------------------------------------------- /run_lora.sh: -------------------------------------------------------------------------------- 1 | export MODEL_NAME="runwayml/stable-diffusion-inpainting" 2 | export INSTANCE_DIR="data/sofa" 3 | export Test_DIR="data/T-shirt_test" 4 | export OUT_DIR="out/sofa" 5 | export INSTANCE_PROMPT="sofa" 6 | export MODEL_DIR="logs/sofa/checkpoint-1000" 7 | 8 | # preprocess data 9 | python preprocess.py --instance_data_dir $INSTANCE_DIR \ 10 | --instance_prompt $INSTANCE_PROMPT 11 | 12 | # CUDA_VISIBLE_DEVICES=0 13 | accelerate launch --num_processes 1 train_lora.py \ 14 | --pretrained_model_name_or_path=$MODEL_NAME \ 15 | --instance_data_dir=$INSTANCE_DIR \ 16 | --output_dir=$MODEL_DIR \ 17 | --instance_prompt=$INSTANCE_PROMPT \ 18 | --resolution=512 \ 19 | --train_batch_size=1 \ 20 | --learning_rate=1e-4 \ 21 | --max_train_steps=2000 \ 22 | --checkpointing_steps 1000 23 | 24 | python inference.py --image_path $Test_DIR \ 25 | --model_path $MODEL_DIR \ 26 | --out_path $OUT_DIR \ 27 | --instance_prompt $INSTANCE_PROMPT -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import hashlib 3 | import itertools 4 | import math 5 | import os 6 | import random 7 | from pathlib import Path 8 | from typing import Optional 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn.functional as F 13 | import 
torch.utils.checkpoint 14 | from accelerate import Accelerator 15 | from accelerate.logging import get_logger 16 | from accelerate.utils import set_seed 17 | from huggingface_hub import HfFolder, Repository, create_repo, whoami 18 | from PIL import Image, ImageDraw, ImageFilter 19 | from torch.utils.data import Dataset 20 | from torchvision import transforms 21 | from tqdm.auto import tqdm 22 | from transformers import CLIPTextModel, CLIPTokenizer 23 | 24 | from transformers import ( 25 | CLIPSegForImageSegmentation, 26 | CLIPSegProcessor, 27 | AutoProcessor, 28 | CLIPVisionModelWithProjection 29 | ) 30 | 31 | from diffusers import ( 32 | AutoencoderKL, 33 | DDPMScheduler, 34 | StableDiffusionInpaintPipeline, 35 | StableDiffusionPipeline, 36 | UNet2DConditionModel, 37 | ) 38 | from diffusers.optimization import get_scheduler 39 | from diffusers.utils import check_min_version 40 | 41 | 42 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 43 | check_min_version("0.13.0.dev0") 44 | 45 | logger = get_logger(__name__) 46 | 47 | 48 | def prepare_mask_and_masked_image(image, mask): 49 | image = np.array(image.convert("RGB")) 50 | image = image[None].transpose(0, 3, 1, 2) 51 | image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 52 | 53 | mask = np.array(mask.convert("L")) 54 | mask = mask.astype(np.float32) / 255.0 55 | mask = mask[None, None] 56 | mask[mask < 0.5] = 0 57 | mask[mask >= 0.5] = 1 58 | mask = torch.from_numpy(mask) 59 | 60 | masked_image = image * (mask < 0.5) 61 | 62 | return mask, masked_image 63 | 64 | 65 | # generate random masks 66 | def random_mask(im_shape, ratio=1, mask_full_image=False): 67 | mask = Image.new("L", im_shape, 0) 68 | draw = ImageDraw.Draw(mask) 69 | size = (random.randint(0, int(im_shape[0] * ratio)), random.randint(0, int(im_shape[1] * ratio))) 70 | # use this to always mask the whole image 71 | if mask_full_image: 72 | size = (int(im_shape[0] * ratio), int(im_shape[1] * ratio)) 73 | limits = (im_shape[0] - size[0] // 2, im_shape[1] - size[1] // 2) 74 | center = (random.randint(size[0] // 2, limits[0]), random.randint(size[1] // 2, limits[1])) 75 | draw_type = random.randint(0, 1) 76 | if draw_type == 0 or mask_full_image: 77 | draw.rectangle( 78 | (center[0] - size[0] // 2, center[1] - size[1] // 2, center[0] + size[0] // 2, center[1] + size[1] // 2), 79 | fill=255, 80 | ) 81 | else: 82 | draw.ellipse( 83 | (center[0] - size[0] // 2, center[1] - size[1] // 2, center[0] + size[0] // 2, center[1] + size[1] // 2), 84 | fill=255, 85 | ) 86 | 87 | return mask 88 | 89 | 90 | def parse_args(): 91 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 92 | parser.add_argument( 93 | "--pretrained_model_name_or_path", 94 | type=str, 95 | default=None, 96 | required=True, 97 | help="Path to pretrained model or model identifier from huggingface.co/models.", 98 | ) 99 | parser.add_argument( 100 | "--tokenizer_name", 101 | type=str, 102 | default=None, 103 | help="Pretrained tokenizer name or path if not the same as model_name", 104 | ) 105 | parser.add_argument( 106 | "--instance_data_dir", 107 | type=str, 108 | default=None, 109 | required=True, 110 | help="A folder containing the training data of instance images.", 111 | ) 112 | parser.add_argument( 113 | "--instance_prompt", 114 | type=str, 115 | default=None, 116 | help="The prompt with identifier specifying the instance", 117 | ) 118 | parser.add_argument( 119 | "--with_prior_preservation", 120 | default=False, 121 | 
action="store_true", 122 | help="Flag to add prior preservation loss.", 123 | ) 124 | 125 | parser.add_argument( 126 | "--output_dir", 127 | type=str, 128 | default="text-inversion-model", 129 | help="The output directory where the model predictions and checkpoints will be written.", 130 | ) 131 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 132 | parser.add_argument( 133 | "--resolution", 134 | type=int, 135 | default=512, 136 | help=( 137 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 138 | " resolution" 139 | ), 140 | ) 141 | parser.add_argument( 142 | "--center_crop", 143 | default=False, 144 | action="store_true", 145 | help=( 146 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly" 147 | " cropped. The images will be resized to the resolution first before cropping." 148 | ), 149 | ) 150 | parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder") 151 | parser.add_argument( 152 | "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." 153 | ) 154 | parser.add_argument( 155 | "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." 156 | ) 157 | parser.add_argument("--num_train_epochs", type=int, default=1) 158 | parser.add_argument( 159 | "--max_train_steps", 160 | type=int, 161 | default=None, 162 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 163 | ) 164 | parser.add_argument( 165 | "--gradient_accumulation_steps", 166 | type=int, 167 | default=1, 168 | help="Number of updates steps to accumulate before performing a backward/update pass.", 169 | ) 170 | parser.add_argument( 171 | "--gradient_checkpointing", 172 | action="store_true", 173 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 174 | ) 175 | parser.add_argument( 176 | "--learning_rate", 177 | type=float, 178 | default=5e-6, 179 | help="Initial learning rate (after the potential warmup period) to use.", 180 | ) 181 | parser.add_argument( 182 | "--scale_lr", 183 | action="store_true", 184 | default=False, 185 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 186 | ) 187 | parser.add_argument( 188 | "--lr_scheduler", 189 | type=str, 190 | default="constant", 191 | help=( 192 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 193 | ' "constant", "constant_with_warmup"]' 194 | ), 195 | ) 196 | parser.add_argument( 197 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 198 | ) 199 | parser.add_argument( 200 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 
201 | ) 202 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 203 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 204 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 205 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 206 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 207 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 208 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 209 | parser.add_argument( 210 | "--hub_model_id", 211 | type=str, 212 | default=None, 213 | help="The name of the repository to keep in sync with the local `output_dir`.", 214 | ) 215 | parser.add_argument( 216 | "--logging_dir", 217 | type=str, 218 | default="logs", 219 | help=( 220 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 221 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 222 | ), 223 | ) 224 | parser.add_argument( 225 | "--mixed_precision", 226 | type=str, 227 | default="no", 228 | choices=["no", "fp16", "bf16"], 229 | help=( 230 | "Whether to use mixed precision. Choose" 231 | "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." 232 | "and an Nvidia Ampere GPU." 233 | ), 234 | ) 235 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 236 | parser.add_argument( 237 | "--checkpointing_steps", 238 | type=int, 239 | default=2000, 240 | help=( 241 | "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" 242 | " checkpoints in case they are better than the last checkpoint and are suitable for resuming training" 243 | " using `--resume_from_checkpoint`." 244 | ), 245 | ) 246 | parser.add_argument( 247 | "--resume_from_checkpoint", 248 | type=str, 249 | default=None, 250 | help=( 251 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 252 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 253 | ), 254 | ) 255 | 256 | args = parser.parse_args() 257 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 258 | if env_local_rank != -1 and env_local_rank != args.local_rank: 259 | args.local_rank = env_local_rank 260 | 261 | if args.instance_data_dir is None: 262 | raise ValueError("You must specify a train data directory.") 263 | 264 | return args 265 | 266 | 267 | class DreamBoothDataset(Dataset): 268 | """ 269 | A dataset to prepare the instance and class images with the prompts for fine-tuning the model. 270 | It pre-processes the images and the tokenizes prompts. 
271 | """ 272 | 273 | def __init__( 274 | self, 275 | instance_data_root, 276 | instance_prompt, 277 | tokenizer, 278 | size=512, 279 | center_crop=False, 280 | ): 281 | self.size = size 282 | self.center_crop = center_crop 283 | self.tokenizer = tokenizer 284 | 285 | self.instance_data_root = Path(instance_data_root) 286 | if not self.instance_data_root.exists(): 287 | raise ValueError("Instance images root doesn't exists.") 288 | 289 | self.instance_images_path = list(Path(instance_data_root).glob("*.jpg")) 290 | self.num_instance_images = len(self.instance_images_path) 291 | self.instance_prompt = instance_prompt 292 | self._length = self.num_instance_images 293 | 294 | self.image_transforms_resize_and_crop = transforms.Compose( 295 | [ 296 | transforms.Resize((int(size), int(size)), interpolation=transforms.InterpolationMode.BILINEAR), 297 | transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), 298 | ] 299 | ) 300 | 301 | self.image_transforms = transforms.Compose( 302 | [ 303 | transforms.ToTensor(), 304 | transforms.Normalize([0.5], [0.5]), 305 | ] 306 | ) 307 | 308 | def __len__(self): 309 | return self._length 310 | 311 | def __getitem__(self, index): 312 | example = {} 313 | instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) 314 | if not instance_image.mode == "RGB": 315 | instance_image = instance_image.convert("RGB") 316 | instance_image = self.image_transforms_resize_and_crop(instance_image) 317 | 318 | example["PIL_images"] = instance_image 319 | example["instance_images"] = self.image_transforms(instance_image) 320 | 321 | # mask 322 | instance_mask = Image.open(self.instance_images_path[index % self.num_instance_images].with_suffix('.png')).resize(instance_image.size) 323 | instance_mask = instance_mask.filter(ImageFilter.MaxFilter(21)) 324 | if not instance_mask.mode == "RGB": 325 | instance_mask = instance_mask.convert("RGB") 326 | instance_mask = self.image_transforms_resize_and_crop(instance_mask) 327 | example["PIL_masks"] = instance_mask 328 | example["instance_masks"] = self.image_transforms(instance_mask) 329 | 330 | instance_prompt = 'a photo with sks ' + self.instance_prompt 331 | 332 | example["instance_prompt_ids"] = self.tokenizer( 333 | instance_prompt, 334 | padding="do_not_pad", 335 | truncation=True, 336 | max_length=self.tokenizer.model_max_length, 337 | ).input_ids 338 | 339 | return example 340 | 341 | 342 | class PromptDataset(Dataset): 343 | "A simple dataset to prepare the prompts to generate class images on multiple GPUs." 
344 | 345 | def __init__(self, prompt, num_samples): 346 | self.prompt = prompt 347 | self.num_samples = num_samples 348 | 349 | def __len__(self): 350 | return self.num_samples 351 | 352 | def __getitem__(self, index): 353 | example = {} 354 | example["prompt"] = self.prompt 355 | example["index"] = index 356 | return example 357 | 358 | 359 | def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): 360 | if token is None: 361 | token = HfFolder.get_token() 362 | if organization is None: 363 | username = whoami(token)["name"] 364 | return f"{username}/{model_id}" 365 | else: 366 | return f"{organization}/{model_id}" 367 | 368 | 369 | def main(): 370 | args = parse_args() 371 | logging_dir = Path(args.output_dir, args.logging_dir) 372 | 373 | accelerator = Accelerator( 374 | gradient_accumulation_steps=args.gradient_accumulation_steps, 375 | mixed_precision=args.mixed_precision, 376 | log_with="tensorboard", 377 | logging_dir=logging_dir, 378 | ) 379 | 380 | # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate 381 | # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. 382 | # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. 383 | if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: 384 | raise ValueError( 385 | "Gradient accumulation is not supported when training the text encoder in distributed training. " 386 | "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." 387 | ) 388 | 389 | if args.seed is not None: 390 | set_seed(args.seed) 391 | 392 | # Handle the repository creation 393 | if accelerator.is_main_process: 394 | if args.push_to_hub: 395 | if args.hub_model_id is None: 396 | repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) 397 | else: 398 | repo_name = args.hub_model_id 399 | create_repo(repo_name, exist_ok=True, token=args.hub_token) 400 | repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) 401 | 402 | with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: 403 | if "step_*" not in gitignore: 404 | gitignore.write("step_*\n") 405 | if "epoch_*" not in gitignore: 406 | gitignore.write("epoch_*\n") 407 | elif args.output_dir is not None: 408 | os.makedirs(args.output_dir, exist_ok=True) 409 | 410 | # Load the tokenizer 411 | if args.tokenizer_name: 412 | tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) 413 | elif args.pretrained_model_name_or_path: 414 | tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") 415 | 416 | # Load models and create wrapper for stable diffusion 417 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder") 418 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae") 419 | unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") 420 | 421 | vae.requires_grad_(False) 422 | if not args.train_text_encoder: 423 | text_encoder.requires_grad_(False) 424 | 425 | if args.gradient_checkpointing: 426 | unet.enable_gradient_checkpointing() 427 | if args.train_text_encoder: 428 | text_encoder.gradient_checkpointing_enable() 429 | 430 | if args.scale_lr: 431 | args.learning_rate = ( 432 | 
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 433 | ) 434 | 435 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs 436 | if args.use_8bit_adam: 437 | try: 438 | import bitsandbytes as bnb 439 | except ImportError: 440 | raise ImportError( 441 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 442 | ) 443 | 444 | optimizer_class = bnb.optim.AdamW8bit 445 | else: 446 | optimizer_class = torch.optim.AdamW 447 | 448 | params_to_optimize = ( 449 | itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters() 450 | ) 451 | optimizer = optimizer_class( 452 | params_to_optimize, 453 | lr=args.learning_rate, 454 | betas=(args.adam_beta1, args.adam_beta2), 455 | weight_decay=args.adam_weight_decay, 456 | eps=args.adam_epsilon, 457 | ) 458 | 459 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") 460 | 461 | train_dataset = DreamBoothDataset( 462 | instance_data_root=args.instance_data_dir, 463 | instance_prompt=args.instance_prompt, 464 | tokenizer=tokenizer, 465 | size=args.resolution, 466 | center_crop=args.center_crop, 467 | ) 468 | 469 | def collate_fn(examples): 470 | input_ids = [example["instance_prompt_ids"] for example in examples] 471 | pixel_values = [example["instance_images"] for example in examples] 472 | 473 | # Concat class and instance examples for prior preservation. 474 | # We do this to avoid doing two forward passes. 475 | 476 | masks = [] 477 | masked_images = [] 478 | for example in examples: 479 | pil_image = example["PIL_images"] 480 | # generate a random mask 481 | # mask = random_mask(pil_image.size, 1, False) 482 | mask = example["PIL_masks"] 483 | # prepare mask and masked image 484 | mask, masked_image = prepare_mask_and_masked_image(pil_image, mask) 485 | 486 | masks.append(mask) 487 | masked_images.append(masked_image) 488 | 489 | if args.with_prior_preservation: 490 | for pil_image, mask in zip(pior_pil, pior_pil_mask): 491 | # generate a random mask 492 | # mask = random_mask(pil_image.size, 1, False) 493 | # prepare mask and masked image 494 | mask, masked_image = prepare_mask_and_masked_image(pil_image, mask) 495 | 496 | masks.append(mask) 497 | masked_images.append(masked_image) 498 | 499 | pixel_values = torch.stack(pixel_values) 500 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 501 | 502 | input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids 503 | masks = torch.stack(masks) 504 | masked_images = torch.stack(masked_images) 505 | batch = {"input_ids": input_ids, "pixel_values": pixel_values, "masks": masks, "masked_images": masked_images} 506 | return batch 507 | 508 | train_dataloader = torch.utils.data.DataLoader( 509 | train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn 510 | ) 511 | 512 | # Scheduler and math around the number of training steps. 
513 | overrode_max_train_steps = False 514 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 515 | if args.max_train_steps is None: 516 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 517 | overrode_max_train_steps = True 518 | 519 | lr_scheduler = get_scheduler( 520 | args.lr_scheduler, 521 | optimizer=optimizer, 522 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 523 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 524 | ) 525 | 526 | if args.train_text_encoder: 527 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 528 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler 529 | ) 530 | else: 531 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 532 | unet, optimizer, train_dataloader, lr_scheduler 533 | ) 534 | accelerator.register_for_checkpointing(lr_scheduler) 535 | 536 | weight_dtype = torch.float32 537 | if args.mixed_precision == "fp16": 538 | weight_dtype = torch.float16 539 | elif args.mixed_precision == "bf16": 540 | weight_dtype = torch.bfloat16 541 | 542 | # Move text_encode and vae to gpu. 543 | # For mixed precision training we cast the text_encoder and vae weights to half-precision 544 | # as these models are only used for inference, keeping weights in full precision is not required. 545 | vae.to(accelerator.device, dtype=weight_dtype) 546 | 547 | if not args.train_text_encoder: 548 | text_encoder.to(accelerator.device, dtype=weight_dtype) 549 | 550 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 551 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 552 | if overrode_max_train_steps: 553 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 554 | # Afterwards we recalculate our number of training epochs 555 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 556 | 557 | # We need to initialize the trackers we use, and also store our configuration. 558 | # The trackers initializes automatically on the main process. 559 | if accelerator.is_main_process: 560 | accelerator.init_trackers("dreambooth", config=vars(args)) 561 | 562 | # Train! 563 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 564 | 565 | logger.info("***** Running training *****") 566 | logger.info(f" Num examples = {len(train_dataset)}") 567 | logger.info(f" Num batches each epoch = {len(train_dataloader)}") 568 | logger.info(f" Num Epochs = {args.num_train_epochs}") 569 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 570 | logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") 571 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 572 | logger.info(f" Total optimization steps = {args.max_train_steps}") 573 | global_step = 0 574 | first_epoch = 0 575 | 576 | if args.resume_from_checkpoint: 577 | if args.resume_from_checkpoint != "latest": 578 | path = os.path.basename(args.resume_from_checkpoint) 579 | else: 580 | # Get the most recent checkpoint 581 | dirs = os.listdir(args.output_dir) 582 | dirs = [d for d in dirs if d.startswith("checkpoint")] 583 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 584 | path = dirs[-1] if len(dirs) > 0 else None 585 | 586 | if path is None: 587 | accelerator.print( 588 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 589 | ) 590 | args.resume_from_checkpoint = None 591 | else: 592 | accelerator.print(f"Resuming from checkpoint {path}") 593 | accelerator.load_state(os.path.join(args.output_dir, path)) 594 | global_step = int(path.split("-")[1]) 595 | 596 | resume_global_step = global_step * args.gradient_accumulation_steps 597 | first_epoch = global_step // num_update_steps_per_epoch 598 | resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) 599 | 600 | # Only show the progress bar once on each machine. 601 | progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) 602 | progress_bar.set_description("Steps") 603 | 604 | for epoch in range(first_epoch, args.num_train_epochs): 605 | unet.train() 606 | for step, batch in enumerate(train_dataloader): 607 | # Skip steps until we reach the resumed step 608 | if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: 609 | if step % args.gradient_accumulation_steps == 0: 610 | progress_bar.update(1) 611 | continue 612 | 613 | with accelerator.accumulate(unet): 614 | # Convert images to latent space 615 | latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() 616 | latents = latents * vae.config.scaling_factor 617 | 618 | # Convert masked images to latent space 619 | masked_latents = vae.encode( 620 | batch["masked_images"].reshape(batch["pixel_values"].shape).to(dtype=weight_dtype) 621 | ).latent_dist.sample() 622 | masked_latents = masked_latents * vae.config.scaling_factor 623 | 624 | masks = batch["masks"] 625 | # resize the mask to latents shape as we concatenate the mask to the latents 626 | mask = torch.stack( 627 | [ 628 | torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8)) 629 | for mask in masks 630 | ] 631 | ) 632 | mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8) 633 | 634 | # Sample noise that we'll add to the latents 635 | noise = torch.randn_like(latents) 636 | bsz = latents.shape[0] 637 | # Sample a random timestep for each image 638 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) 639 | timesteps = timesteps.long() 640 | 641 | # Add noise to the latents according to the noise magnitude at each timestep 642 | # (this is the forward diffusion process) 643 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 644 | 645 | # concatenate the noised latents with the mask and the masked latents 646 | latent_model_input = torch.cat([noisy_latents, mask, masked_latents], dim=1) 647 | 648 | # Get the text embedding for conditioning 649 | 
encoder_hidden_states = text_encoder(batch["input_ids"])[0] 650 | 651 | # Predict the noise residual 652 | noise_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample 653 | 654 | # Get the target for loss depending on the prediction type 655 | if noise_scheduler.config.prediction_type == "epsilon": 656 | target = noise 657 | elif noise_scheduler.config.prediction_type == "v_prediction": 658 | target = noise_scheduler.get_velocity(latents, noise, timesteps) 659 | else: 660 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 661 | 662 | if args.with_prior_preservation: 663 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. 664 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) 665 | target, target_prior = torch.chunk(target, 2, dim=0) 666 | 667 | # Compute instance loss 668 | loss = F.mse_loss(noise_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() 669 | 670 | # Compute prior loss 671 | prior_loss = F.mse_loss(noise_pred_prior.float(), target_prior.float(), reduction="mean") 672 | 673 | # Add the prior loss to the instance loss. 674 | loss = loss + args.prior_loss_weight * prior_loss 675 | else: 676 | loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean") 677 | 678 | accelerator.backward(loss) 679 | if accelerator.sync_gradients: 680 | params_to_clip = ( 681 | itertools.chain(unet.parameters(), text_encoder.parameters()) 682 | if args.train_text_encoder 683 | else unet.parameters() 684 | ) 685 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) 686 | optimizer.step() 687 | lr_scheduler.step() 688 | optimizer.zero_grad() 689 | 690 | # Checks if the accelerator has performed an optimization step behind the scenes 691 | if accelerator.sync_gradients: 692 | progress_bar.update(1) 693 | global_step += 1 694 | 695 | if global_step % args.checkpointing_steps == 0: 696 | if accelerator.is_main_process: 697 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 698 | accelerator.save_state(save_path) 699 | logger.info(f"Saved state to {save_path}") 700 | 701 | pipeline = StableDiffusionPipeline.from_pretrained( 702 | args.pretrained_model_name_or_path, 703 | unet=accelerator.unwrap_model(unet), 704 | text_encoder=accelerator.unwrap_model(text_encoder), 705 | ) 706 | pipeline.save_pretrained(save_path) 707 | 708 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 709 | progress_bar.set_postfix(**logs) 710 | accelerator.log(logs, step=global_step) 711 | 712 | if global_step >= args.max_train_steps: 713 | break 714 | 715 | accelerator.wait_for_everyone() 716 | 717 | # Create the pipeline using using the trained modules and save it. 
718 | if accelerator.is_main_process: 719 | pipeline = StableDiffusionPipeline.from_pretrained( 720 | args.pretrained_model_name_or_path, 721 | unet=accelerator.unwrap_model(unet), 722 | text_encoder=accelerator.unwrap_model(text_encoder), 723 | ) 724 | pipeline.save_pretrained(args.output_dir) 725 | 726 | if args.push_to_hub: 727 | repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) 728 | 729 | accelerator.end_training() 730 | 731 | 732 | if __name__ == "__main__": 733 | main() 734 | -------------------------------------------------------------------------------- /train_lora.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import hashlib 3 | import math 4 | import os 5 | import random 6 | from pathlib import Path 7 | from typing import Optional 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn.functional as F 12 | import torch.utils.checkpoint 13 | from accelerate import Accelerator 14 | from accelerate.logging import get_logger 15 | from accelerate.utils import set_seed 16 | from huggingface_hub import HfFolder, Repository, create_repo, whoami 17 | from PIL import Image, ImageDraw 18 | from torch.utils.data import Dataset 19 | from torchvision import transforms 20 | from tqdm.auto import tqdm 21 | from transformers import CLIPTextModel, CLIPTokenizer 22 | 23 | from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel 24 | from diffusers.loaders import AttnProcsLayers 25 | from diffusers.models.cross_attention import LoRACrossAttnProcessor 26 | from diffusers.optimization import get_scheduler 27 | from diffusers.utils import check_min_version 28 | from diffusers.utils.import_utils import is_xformers_available 29 | 30 | from PIL import Image, ImageDraw, ImageFilter 31 | 32 | 33 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
34 | check_min_version("0.13.0.dev0") 35 | 36 | logger = get_logger(__name__) 37 | 38 | 39 | def prepare_mask_and_masked_image(image, mask): 40 | image = np.array(image.convert("RGB")) 41 | image = image[None].transpose(0, 3, 1, 2) 42 | image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 43 | 44 | mask = np.array(mask.convert("L")) 45 | mask = mask.astype(np.float32) / 255.0 46 | mask = mask[None, None] 47 | mask[mask < 0.5] = 0 48 | mask[mask >= 0.5] = 1 49 | mask = torch.from_numpy(mask) 50 | 51 | masked_image = image * (mask < 0.5) 52 | 53 | return mask, masked_image 54 | 55 | 56 | # generate random masks 57 | def random_mask(im_shape, ratio=1, mask_full_image=False): 58 | mask = Image.new("L", im_shape, 0) 59 | draw = ImageDraw.Draw(mask) 60 | size = (random.randint(0, int(im_shape[0] * ratio)), random.randint(0, int(im_shape[1] * ratio))) 61 | # use this to always mask the whole image 62 | if mask_full_image: 63 | size = (int(im_shape[0] * ratio), int(im_shape[1] * ratio)) 64 | limits = (im_shape[0] - size[0] // 2, im_shape[1] - size[1] // 2) 65 | center = (random.randint(size[0] // 2, limits[0]), random.randint(size[1] // 2, limits[1])) 66 | draw_type = random.randint(0, 1) 67 | if draw_type == 0 or mask_full_image: 68 | draw.rectangle( 69 | (center[0] - size[0] // 2, center[1] - size[1] // 2, center[0] + size[0] // 2, center[1] + size[1] // 2), 70 | fill=255, 71 | ) 72 | else: 73 | draw.ellipse( 74 | (center[0] - size[0] // 2, center[1] - size[1] // 2, center[0] + size[0] // 2, center[1] + size[1] // 2), 75 | fill=255, 76 | ) 77 | 78 | return mask 79 | 80 | 81 | def parse_args(): 82 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 83 | parser.add_argument( 84 | "--pretrained_model_name_or_path", 85 | type=str, 86 | default=None, 87 | required=True, 88 | help="Path to pretrained model or model identifier from huggingface.co/models.", 89 | ) 90 | parser.add_argument( 91 | "--tokenizer_name", 92 | type=str, 93 | default=None, 94 | help="Pretrained tokenizer name or path if not the same as model_name", 95 | ) 96 | parser.add_argument( 97 | "--instance_data_dir", 98 | type=str, 99 | default=None, 100 | required=True, 101 | help="A folder containing the training data of instance images.", 102 | ) 103 | parser.add_argument( 104 | "--instance_prompt", 105 | type=str, 106 | default=None, 107 | help="The prompt with identifier specifying the instance", 108 | ) 109 | parser.add_argument( 110 | "--output_dir", 111 | type=str, 112 | default="tensorboard", 113 | help="The output directory where the model predictions and checkpoints will be written.", 114 | ) 115 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 116 | parser.add_argument( 117 | "--resolution", 118 | type=int, 119 | default=512, 120 | help=( 121 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 122 | " resolution" 123 | ), 124 | ) 125 | parser.add_argument( 126 | "--center_crop", 127 | default=False, 128 | action="store_true", 129 | help=( 130 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly" 131 | " cropped. The images will be resized to the resolution first before cropping." 
132 | ), 133 | ) 134 | parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder") 135 | parser.add_argument( 136 | "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." 137 | ) 138 | parser.add_argument( 139 | "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." 140 | ) 141 | parser.add_argument("--num_train_epochs", type=int, default=1) 142 | parser.add_argument( 143 | "--max_train_steps", 144 | type=int, 145 | default=None, 146 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 147 | ) 148 | parser.add_argument( 149 | "--gradient_accumulation_steps", 150 | type=int, 151 | default=1, 152 | help="Number of updates steps to accumulate before performing a backward/update pass.", 153 | ) 154 | parser.add_argument( 155 | "--gradient_checkpointing", 156 | action="store_true", 157 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 158 | ) 159 | parser.add_argument( 160 | "--learning_rate", 161 | type=float, 162 | default=5e-6, 163 | help="Initial learning rate (after the potential warmup period) to use.", 164 | ) 165 | parser.add_argument( 166 | "--scale_lr", 167 | action="store_true", 168 | default=False, 169 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 170 | ) 171 | parser.add_argument( 172 | "--lr_scheduler", 173 | type=str, 174 | default="constant", 175 | help=( 176 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 177 | ' "constant", "constant_with_warmup"]' 178 | ), 179 | ) 180 | parser.add_argument( 181 | "--lr_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." 182 | ) 183 | parser.add_argument( 184 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 185 | ) 186 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 187 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 188 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 189 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 190 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 191 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 192 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 193 | parser.add_argument( 194 | "--hub_model_id", 195 | type=str, 196 | default=None, 197 | help="The name of the repository to keep in sync with the local `output_dir`.", 198 | ) 199 | parser.add_argument( 200 | "--logging_dir", 201 | type=str, 202 | default="logs", 203 | help=( 204 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 205 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 206 | ), 207 | ) 208 | parser.add_argument( 209 | "--mixed_precision", 210 | type=str, 211 | default="no", 212 | choices=["no", "fp16", "bf16"], 213 | help=( 214 | "Whether to use mixed precision. Choose" 215 | "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." 
216 |             " and an Nvidia Ampere GPU."
217 |         ),
218 |     )
219 |     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
220 |     parser.add_argument(
221 |         "--checkpointing_steps",
222 |         type=int,
223 |         default=2000,
224 |         help=(
225 |             "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
226 |             " checkpoints in case they are better than the last checkpoint and are suitable for resuming training"
227 |             " using `--resume_from_checkpoint`."
228 |         ),
229 |     )
230 |     parser.add_argument(
231 |         "--resume_from_checkpoint",
232 |         type=str,
233 |         default=None,
234 |         help=(
235 |             "Whether training should be resumed from a previous checkpoint. Use a path saved by"
236 |             ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
237 |         ),
238 |     )
239 |     parser.add_argument(
240 |         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
241 |     )
242 | 
243 |     args = parser.parse_args()
244 |     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
245 |     if env_local_rank != -1 and env_local_rank != args.local_rank:
246 |         args.local_rank = env_local_rank
247 | 
248 |     if args.instance_data_dir is None:
249 |         raise ValueError("You must specify a train data directory.")
250 | 
251 |     return args
252 | 
253 | 
254 | def merge_rgb_mask_to_rgba(rgb, mask):
255 |     rgba_image = rgb.copy().convert('RGBA')
256 |     alpha = mask.convert('L')
257 |     rgba_image.putalpha(alpha)
258 |     return rgba_image
259 | 
260 | class DreamBoothDataset(Dataset):
261 |     """
262 |     A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
263 |     It pre-processes the images and tokenizes the prompts.
264 |     """
265 | 
266 |     def __init__(
267 |         self,
268 |         instance_data_root,
269 |         instance_prompt,
270 |         tokenizer,
271 |         size=512,
272 |         center_crop=False,
273 |     ):
274 |         self.size = size
275 |         self.center_crop = center_crop
276 |         self.tokenizer = tokenizer
277 | 
278 |         self.instance_data_root = Path(instance_data_root)
279 |         if not self.instance_data_root.exists():
280 |             raise ValueError("Instance images root doesn't exist.")
281 | 
282 |         self.instance_images_path = list(Path(instance_data_root).glob("*.jpg"))
283 |         self.num_instance_images = len(self.instance_images_path)
284 |         self.instance_prompt = instance_prompt
285 |         self._length = self.num_instance_images
286 | 
287 |         self.image_transforms_resize_and_crop = transforms.Compose(
288 |             [
289 |                 transforms.Resize((int(size * 1.2), int(size * 1.2)), interpolation=transforms.InterpolationMode.BILINEAR),
290 |                 transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
291 |             ]
292 |         )
293 | 
294 |         self.image_transforms = transforms.Compose(
295 |             [
296 |                 transforms.ToTensor(),
297 |                 transforms.Normalize([0.5], [0.5]),
298 |             ]
299 |         )
300 | 
301 |     def __len__(self):
302 |         return self._length
303 | 
304 |     def __getitem__(self, index):
305 |         example = {}
306 |         instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
307 |         if not instance_image.mode == "RGB":
308 |             instance_image = instance_image.convert("RGB")
309 | 
310 |         # load the instance mask generated by preprocess.py, dilate it slightly, and match the image size
311 |         instance_mask = Image.open(self.instance_images_path[index % self.num_instance_images].with_suffix('.png'))
312 |         instance_mask = instance_mask.filter(ImageFilter.MaxFilter(11)).resize(instance_image.size)
313 | 
314 |         rgba = merge_rgb_mask_to_rgba(instance_image, instance_mask)
315 |         rgba = self.image_transforms_resize_and_crop(rgba)
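        # Splitting the transformed RGBA image back into RGB + alpha guarantees that the
        # instance image and its mask received exactly the same resize and crop.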
316 | r, g, b, a = rgba.split() 317 | instance_image = Image.merge('RGB', (r, g, b)) 318 | instance_mask = a.convert('RGB') 319 | 320 | example["PIL_images"] = instance_image 321 | example["instance_images"] = self.image_transforms(instance_image) 322 | example["PIL_masks"] = instance_mask 323 | 324 | # prompt_path = self.instance_images_path[index % self.num_instance_images].with_suffix('.txt') 325 | # if os.path.exists(prompt_path): 326 | # with open(prompt_path) as f: 327 | # instance_prompt = f.readline().strip() 328 | # instance_prompt = instance_prompt + ', sks ' + self.instance_prompt 329 | # else: 330 | instance_prompt = 'a photo with sks ' + self.instance_prompt 331 | 332 | example["instance_prompt_ids"] = self.tokenizer( 333 | instance_prompt, 334 | padding="do_not_pad", 335 | truncation=True, 336 | max_length=self.tokenizer.model_max_length, 337 | ).input_ids 338 | 339 | return example 340 | 341 | 342 | class PromptDataset(Dataset): 343 | "A simple dataset to prepare the prompts to generate class images on multiple GPUs." 344 | 345 | def __init__(self, prompt, num_samples): 346 | self.prompt = prompt 347 | self.num_samples = num_samples 348 | 349 | def __len__(self): 350 | return self.num_samples 351 | 352 | def __getitem__(self, index): 353 | example = {} 354 | example["prompt"] = self.prompt 355 | example["index"] = index 356 | return example 357 | 358 | 359 | def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): 360 | if token is None: 361 | token = HfFolder.get_token() 362 | if organization is None: 363 | username = whoami(token)["name"] 364 | return f"{username}/{model_id}" 365 | else: 366 | return f"{organization}/{model_id}" 367 | 368 | 369 | def main(): 370 | args = parse_args() 371 | logging_dir = Path(args.output_dir, args.logging_dir) 372 | 373 | accelerator = Accelerator( 374 | gradient_accumulation_steps=args.gradient_accumulation_steps, 375 | mixed_precision=args.mixed_precision, 376 | log_with="tensorboard", 377 | logging_dir=logging_dir, 378 | ) 379 | 380 | # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate 381 | # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. 382 | # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. 383 | if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: 384 | raise ValueError( 385 | "Gradient accumulation is not supported when training the text encoder in distributed training. " 386 | "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." 
387 |         )
388 | 
389 |     if args.seed is not None:
390 |         set_seed(args.seed)
391 | 
392 |     # Handle the repository creation
393 |     if accelerator.is_main_process:
394 |         if args.push_to_hub:
395 |             if args.hub_model_id is None:
396 |                 repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
397 |             else:
398 |                 repo_name = args.hub_model_id
399 |             create_repo(repo_name, exist_ok=True, token=args.hub_token)
400 |             repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
401 | 
402 |             with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
403 |                 if "step_*" not in gitignore:
404 |                     gitignore.write("step_*\n")
405 |                 if "epoch_*" not in gitignore:
406 |                     gitignore.write("epoch_*\n")
407 |         elif args.output_dir is not None:
408 |             os.makedirs(args.output_dir, exist_ok=True)
409 | 
410 |     # Load the tokenizer
411 |     if args.tokenizer_name:
412 |         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
413 |     elif args.pretrained_model_name_or_path:
414 |         tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
415 | 
416 |     # Load models and create wrapper for stable diffusion
417 |     text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
418 |     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
419 |     unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
420 | 
421 |     # We only train the additional adapter LoRA layers, so freeze all base weights
422 |     vae.requires_grad_(False)
423 |     text_encoder.requires_grad_(False)
424 |     unet.requires_grad_(False)
425 | 
426 |     weight_dtype = torch.float32
427 |     if args.mixed_precision == "fp16":
428 |         weight_dtype = torch.float16
429 |     elif args.mixed_precision == "bf16":
430 |         weight_dtype = torch.bfloat16
431 | 
432 |     # Move the frozen unet, text_encoder and vae to the GPU.
433 |     # For mixed precision training we cast their weights to half-precision; since only the
434 |     # LoRA adapter weights are trained, keeping the frozen base weights in full precision is not required.
435 |     unet.to(accelerator.device, dtype=weight_dtype)
436 |     vae.to(accelerator.device, dtype=weight_dtype)
437 |     text_encoder.to(accelerator.device, dtype=weight_dtype)
438 | 
439 |     if args.enable_xformers_memory_efficient_attention:
440 |         if is_xformers_available():
441 |             unet.enable_xformers_memory_efficient_attention()
442 |         else:
443 |             raise ValueError("xformers is not available. Make sure it is installed correctly")
444 | 
445 |     # now we will add new LoRA weights to the attention layers
446 |     # It's important to realize here how many attention weights will be added and of which sizes
447 |     # The sizes of the attention layers consist only of two different variables:
448 |     # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
449 |     # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
450 | 
451 |     # Let's first see how many attention processors we will have to set.
452 |     # For Stable Diffusion, it should be equal to:
453 |     # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
454 |     # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
455 |     # - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18
456 |     # => 32 layers
457 | 
458 |     # Set correct lora layers
459 |     lora_attn_procs = {}
460 |     for name in unet.attn_processors.keys():
461 |         cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
462 |         if name.startswith("mid_block"):
463 |             hidden_size = unet.config.block_out_channels[-1]
464 |         elif name.startswith("up_blocks"):
465 |             block_id = int(name[len("up_blocks.")])
466 |             hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
467 |         elif name.startswith("down_blocks"):
468 |             block_id = int(name[len("down_blocks.")])
469 |             hidden_size = unet.config.block_out_channels[block_id]
470 | 
471 |         lora_attn_procs[name] = LoRACrossAttnProcessor(
472 |             hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
473 |         )
474 | 
475 |     unet.set_attn_processor(lora_attn_procs)
476 |     lora_layers = AttnProcsLayers(unet.attn_processors)
477 | 
478 |     accelerator.register_for_checkpointing(lora_layers)
479 | 
480 |     if args.scale_lr:
481 |         args.learning_rate = (
482 |             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
483 |         )
484 | 
485 |     # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
486 |     if args.use_8bit_adam:
487 |         try:
488 |             import bitsandbytes as bnb
489 |         except ImportError:
490 |             raise ImportError(
491 |                 "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
492 |             )
493 | 
494 |         optimizer_class = bnb.optim.AdamW8bit
495 |     else:
496 |         optimizer_class = torch.optim.AdamW
497 | 
498 |     optimizer = optimizer_class(
499 |         lora_layers.parameters(),
500 |         lr=args.learning_rate,
501 |         betas=(args.adam_beta1, args.adam_beta2),
502 |         weight_decay=args.adam_weight_decay,
503 |         eps=args.adam_epsilon,
504 |     )
505 | 
506 |     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
507 | 
508 |     train_dataset = DreamBoothDataset(
509 |         instance_data_root=args.instance_data_dir,
510 |         instance_prompt=args.instance_prompt,
511 |         tokenizer=tokenizer,
512 |         size=args.resolution,
513 |         center_crop=args.center_crop,
514 |     )
515 | 
516 |     def collate_fn(examples):
517 |         input_ids = [example["instance_prompt_ids"] for example in examples]
518 |         pixel_values = [example["instance_images"] for example in examples]
519 | 
520 |         masks = []
521 |         masked_images = []
522 |         for example in examples:
523 |             pil_image = example["PIL_images"]
524 |             # use the precomputed instance mask (from preprocess.py) rather than a random mask
525 |             mask = example["PIL_masks"]
526 |             # prepare mask and masked image
527 |             mask, masked_image = prepare_mask_and_masked_image(pil_image, mask)
528 | 
529 |             masks.append(mask)
530 |             masked_images.append(masked_image)
531 | 
532 |         pixel_values = torch.stack(pixel_values)
533 |         pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
534 | 
535 |         input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
536 |         masks = torch.stack(masks)
537 |         masked_images = torch.stack(masked_images)
538 |         batch = {"input_ids": input_ids, "pixel_values": pixel_values, "masks": masks, "masked_images": masked_images}
539 |         return batch
540 | 
541 |     train_dataloader = torch.utils.data.DataLoader(
542 |         train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn
543 |     )
544 | 
545 |     # Scheduler and math around the number of training steps.
546 |     overrode_max_train_steps = False
547 |     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
548 |     if args.max_train_steps is None:
549 |         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
550 |         overrode_max_train_steps = True
551 | 
552 |     lr_scheduler = get_scheduler(
553 |         args.lr_scheduler,
554 |         optimizer=optimizer,
555 |         num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
556 |         num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
557 |     )
558 | 
559 |     # Prepare everything with our `accelerator`.
560 |     lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
561 |         lora_layers, optimizer, train_dataloader, lr_scheduler
562 |     )
563 |     # accelerator.register_for_checkpointing(lr_scheduler)
564 | 
565 |     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
566 |     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
567 |     if overrode_max_train_steps:
568 |         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
569 |     # Afterwards we recalculate our number of training epochs
570 |     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
571 | 
572 |     # We need to initialize the trackers we use, and also store our configuration.
573 |     # The trackers are initialized automatically on the main process.
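    # Note: `config=vars(args)` below stores all parsed CLI arguments with the tracker,
    # so the run's hyperparameters appear alongside the TensorBoard logs.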
574 | if accelerator.is_main_process: 575 | accelerator.init_trackers("tensorboard", config=vars(args)) 576 | 577 | # Train! 578 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 579 | 580 | logger.info("***** Running training *****") 581 | logger.info(f" Num examples = {len(train_dataset)}") 582 | logger.info(f" Num batches each epoch = {len(train_dataloader)}") 583 | logger.info(f" Num Epochs = {args.num_train_epochs}") 584 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 585 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 586 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 587 | logger.info(f" Total optimization steps = {args.max_train_steps}") 588 | global_step = 0 589 | first_epoch = 0 590 | 591 | if args.resume_from_checkpoint: 592 | if args.resume_from_checkpoint != "latest": 593 | path = os.path.basename(args.resume_from_checkpoint) 594 | else: 595 | # Get the most recent checkpoint 596 | dirs = os.listdir(args.output_dir) 597 | dirs = [d for d in dirs if d.startswith("checkpoint")] 598 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 599 | path = dirs[-1] if len(dirs) > 0 else None 600 | 601 | if path is None: 602 | accelerator.print( 603 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 604 | ) 605 | args.resume_from_checkpoint = None 606 | else: 607 | accelerator.print(f"Resuming from checkpoint {path}") 608 | accelerator.load_state(os.path.join(args.output_dir, path)) 609 | global_step = int(path.split("-")[1]) 610 | 611 | resume_global_step = global_step * args.gradient_accumulation_steps 612 | first_epoch = global_step // num_update_steps_per_epoch 613 | resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) 614 | 615 | # Only show the progress bar once on each machine. 
616 | progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) 617 | progress_bar.set_description("Steps") 618 | 619 | for epoch in range(first_epoch, args.num_train_epochs): 620 | unet.train() 621 | for step, batch in enumerate(train_dataloader): 622 | # Skip steps until we reach the resumed step 623 | if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: 624 | if step % args.gradient_accumulation_steps == 0: 625 | progress_bar.update(1) 626 | continue 627 | 628 | with accelerator.accumulate(unet): 629 | # Convert images to latent space 630 | 631 | latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() 632 | latents = latents * vae.config.scaling_factor 633 | 634 | # Convert masked images to latent space 635 | masked_latents = vae.encode( 636 | batch["masked_images"].reshape(batch["pixel_values"].shape).to(dtype=weight_dtype) 637 | ).latent_dist.sample() 638 | masked_latents = masked_latents * vae.config.scaling_factor 639 | 640 | masks = batch["masks"] 641 | # resize the mask to latents shape as we concatenate the mask to the latents 642 | mask = torch.stack( 643 | [ 644 | torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8)) 645 | for mask in masks 646 | ] 647 | ) 648 | mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8) 649 | 650 | # Sample noise that we'll add to the latents 651 | noise = torch.randn_like(latents) 652 | bsz = latents.shape[0] 653 | # Sample a random timestep for each image 654 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) 655 | timesteps = timesteps.long() 656 | 657 | # Add noise to the latents according to the noise magnitude at each timestep 658 | # (this is the forward diffusion process) 659 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 660 | 661 | # concatenate the noised latents with the mask and the masked latents 662 | latent_model_input = torch.cat([noisy_latents, mask, masked_latents], dim=1) 663 | 664 | # Get the text embedding for conditioning 665 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] 666 | 667 | # Predict the noise residual 668 | noise_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample 669 | 670 | # Get the target for loss depending on the prediction type 671 | if noise_scheduler.config.prediction_type == "epsilon": 672 | target = noise 673 | elif noise_scheduler.config.prediction_type == "v_prediction": 674 | target = noise_scheduler.get_velocity(latents, noise, timesteps) 675 | else: 676 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 677 | 678 | loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean") 679 | 680 | accelerator.backward(loss) 681 | if accelerator.sync_gradients: 682 | params_to_clip = lora_layers.parameters() 683 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) 684 | optimizer.step() 685 | lr_scheduler.step() 686 | optimizer.zero_grad() 687 | 688 | # Checks if the accelerator has performed an optimization step behind the scenes 689 | if accelerator.sync_gradients: 690 | progress_bar.update(1) 691 | global_step += 1 692 | 693 | if global_step % args.checkpointing_steps == 0: 694 | if accelerator.is_main_process: 695 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 696 | accelerator.save_state(save_path) 697 | logger.info(f"Saved state to {save_path}") 698 
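                            # Also export the LoRA attention-processor weights for this checkpoint, so it can be
                            # used for inference directly without loading the full accelerate training state.
                            # Note that `unet.to(torch.float32)` casts the frozen base UNet in place; if training
                            # continues past this checkpoint under fp16/bf16, the UNet is left in full precision.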
| 699 | unet.to(torch.float32).save_attn_procs(save_path) 700 | 701 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 702 | progress_bar.set_postfix(**logs) 703 | accelerator.log(logs, step=global_step) 704 | 705 | if global_step >= args.max_train_steps: 706 | break 707 | 708 | accelerator.wait_for_everyone() 709 | 710 | # Save the lora layers 711 | if accelerator.is_main_process: 712 | unet = unet.to(torch.float32) 713 | unet.save_attn_procs(args.output_dir) 714 | 715 | if args.push_to_hub: 716 | repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) 717 | 718 | accelerator.end_training() 719 | 720 | 721 | if __name__ == "__main__": 722 | main() 723 | --------------------------------------------------------------------------------
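For reference, the sketch below shows one way the attention-processor weights saved by this script could be loaded for inference with the base inpainting pipeline. It is an illustrative example only, not the repository's `inference_lora.py`; the mask path and output name are placeholder assumptions, and the paths/prompt mirror the README's sofa example.

```python
import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image

# Load the frozen base inpainting model and attach the trained LoRA attention processors.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
).to("cuda")
pipe.unet.load_attn_procs("logs/sofa")  # directory written by `save_attn_procs` above

# Test image from the repo; the mask file is a placeholder (white = region to repaint).
image = Image.open("data/sofa_test/XE3KtWoY2TgU7r2u2AJkBJ.jpg").convert("RGB").resize((512, 512))
mask = Image.open("sofa_mask.png").convert("L").resize((512, 512))

# The prompt mirrors the training-time construction: 'a photo with sks ' + instance prompt.
result = pipe(prompt="a photo with sks sofa", image=image, mask_image=mask).images[0]
result.save("composited_sofa.png")
```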