├── 000000000285.jpg ├── 000000000724.jpg ├── 000000007991.jpg ├── 000000018837.jpg ├── 000000122962.jpg ├── 000000295478.jpg ├── README.md ├── eval_controlnet.py ├── eval_controlnet.sh ├── eval_controlnet_sdxl_light.py ├── eval_controlnet_sdxl_light.sh ├── eval_controlnet_sdxl_light_single.py ├── eval_controlnet_sdxl_light_single.sh ├── example ├── UUColor_results │ └── Hollywood-Sign.jpeg └── legacy_images │ ├── Big-Ben-vintage.jpg │ ├── Central-Park.jpg │ ├── Hollywood-Sign.jpg │ ├── Little-Mermaid.jpg │ ├── Migrant-Mother.jpg │ ├── Mount-Everest.jpg │ ├── Tower-of-Pisa.jpg │ └── Wasatch-Mountains-Summit-County-Utah.jpg ├── gradio_ui.py ├── images ├── 000000022935_gray.jpg ├── 000000022935_green_shirt_on_right_girl.jpeg ├── 000000022935_purple_shirt_on_right_girl.jpeg ├── 000000022935_red_shirt_on_right_girl.jpeg ├── 000000025560_color.jpg ├── 000000025560_gray.jpg ├── 000000025560_gt.jpg ├── 000000041633_black_car.jpeg ├── 000000041633_bright_red_car.jpeg ├── 000000041633_dark_blue_car.jpeg ├── 000000041633_gray.jpg ├── 000000065736_color.jpg ├── 000000065736_gray.jpg ├── 000000065736_gt.jpg ├── 000000091779_color.jpg ├── 000000091779_gray.jpg ├── 000000091779_gt.jpg ├── 000000092177_color.jpg ├── 000000092177_gray.jpg ├── 000000092177_gt.jpg ├── 000000166426_color.jpg ├── 000000166426_gray.jpg ├── 000000166426_gt.jpg ├── 000000286708_gray.jpg ├── 000000286708_orange_hat.jpeg ├── 000000286708_pink_hat.jpeg ├── 000000286708_yellow_hat.jpeg ├── framework.jpg └── gradio_ui.png ├── requirements.txt ├── train_controlnet.py ├── train_controlnet.sh ├── train_controlnet_sdxl.py ├── train_controlnet_sdxl.sh ├── train_controlnet_sdxl_light.py └── train_controlnet_sdxl_light.sh /000000000285.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/000000000285.jpg -------------------------------------------------------------------------------- /000000000724.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/000000000724.jpg -------------------------------------------------------------------------------- /000000007991.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/000000007991.jpg -------------------------------------------------------------------------------- /000000018837.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/000000018837.jpg -------------------------------------------------------------------------------- /000000122962.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/000000122962.jpg -------------------------------------------------------------------------------- /000000295478.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/000000295478.jpg 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Text-Guided-Image-Colorization 2 | 3 | This project utilizes the power of **Stable Diffusion (SDXL/SDXL-Light)** and the **BLIP (Bootstrapping Language-Image Pre-training)** captioning model to provide an interactive image colorization experience. Users can influence the generated colors of objects within images, making the colorization process more personalized and creative. 4 | 5 | ![framework.jpg](images/framework.jpg) 6 | 7 | ## Table of Contents 8 | - [Features](#features) 9 | - [Installation](#installation) 10 | - [Quick Start](#quick-start) 11 | - [Dataset Usage](#dataset-usage) 12 | - [Training](#training) 13 | - [Evaluation](#evaluation) 14 | - [Results](#results) 15 | - [License](#license) 16 | 17 | ## News 18 | - **(2024/11/23)** The project is now available on [Hugging Face Spaces](https://huggingface.co/spaces/fffiloni/text-guided-image-colorization) 🎉 Big thanks to @fffiloni! 19 | 20 | 21 | ## Features 22 | 23 | - **Interactive Colorization**: Users can specify desired colors for different objects in the image. 24 | - **ControlNet Approach**: Enhanced colorization capabilities through retraining with ControlNet, allowing SDXL to better adapt to the image colorization task. 25 | - **High-Quality Outputs**: Leverage the latest advancements in diffusion models to generate vibrant and realistic colorizations. 26 | 27 | ## Installation 28 | 29 | To set up the project locally, follow these steps: 30 | 31 | 1. **Clone the Repository**: 32 | 33 | ```bash 34 | git clone https://github.com/nick8592/text-guided-image-colorization.git 35 | cd text-guided-image-colorization 36 | ``` 37 | 38 | 2. **Install Dependencies**: 39 | Make sure you have Python 3.7 or higher installed. Then, install the required packages: 40 | 41 | ```bash 42 | pip install -r requirements.txt 43 | ``` 44 | Install `torch` and `torchvision` matching your CUDA version: 45 | ```bash 46 | pip install torch torchvision --index-url https://download.pytorch.org/whl/cuXXX 47 | ``` 48 | Replace `XXX` with your CUDA version (e.g., `118` for CUDA 11.8). For more info, see [PyTorch Get Started](https://pytorch.org/get-started/locally/). 49 | 50 | 51 | 3. **Download Pre-trained Models**: 52 | | Models | Hugging Face | 53 | |:---:|:---:| 54 |SDXL-Lightning Caption|[link](https://huggingface.co/nickpai/sdxl_light_caption_output)| 55 |SDXL-Lightning Custom Caption (Recommended)|[link](https://huggingface.co/nickpai/sdxl_light_custom_caption_output)| 56 | 57 | 58 | ```bash 59 | text-guided-image-colorization/sdxl_light_caption_output 60 | └── checkpoint-30000 61 | ├── controlnet 62 | │ ├── diffusion_pytorch_model.safetensors 63 | │ └── config.json 64 | ├── optimizer.bin 65 | ├── random_states_0.pkl 66 | ├── scaler.pt 67 | └── scheduler.bin 68 | ``` 69 | 70 | ## Quick Start 71 | 72 | 1. Run the `gradio_ui.py` script: 73 | 74 | ```bash 75 | python gradio_ui.py 76 | ``` 77 | 78 | 2. Open the provided URL in your web browser to access the Gradio-based user interface. 79 | 80 | 3. Upload an image and use the interface to control the colors of specific objects in the image. The model can also colorize the image without a specific prompt. 81 | 82 | 4. The model will generate a colorized version of the image based on your input (or automatically, if no prompt is given). See the [demo video](https://x.com/weichenpai/status/1829513077588631987). A command-line alternative is sketched below.
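For a quick test without the web UI, `eval_controlnet_sdxl_light_single.sh` runs the same pipeline on a single image from the command line. The sketch below mirrors that script; it assumes you have downloaded the `checkpoint-30000` folder shown in the installation step, and the positive prompt is only an example.

```bash
# Colorize one image with the SDXL-Lightning ControlNet (8-step UNet)
accelerate launch eval_controlnet_sdxl_light_single.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
  --repo="ByteDance/SDXL-Lightning" \
  --ckpt="sdxl_lightning_8step_unet.safetensors" \
  --num_inference_steps=8 \
  --controlnet_model_name_or_path="sdxl_light_caption_output/checkpoint-30000/controlnet" \
  --caption_model_name="blip-image-captioning-large" \
  --mixed_precision="fp16" \
  --image_path="example/legacy_images/Hollywood-Sign.jpg" \
  --positive_prompt="blue sky"
```

Note that the "N-step" checkpoint must match `--num_inference_steps`, as the provided shell scripts also caution.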
83 | ![Gradio UI](images/gradio_ui.png) 84 | 85 | 86 | ## Dataset Usage 87 | 88 | You can find more details about the dataset usage in the [Dataset-for-Image-Colorization](https://github.com/nick8592/Dataset-for-Image-Colorization) repository. 89 | 90 | ## Training 91 | 92 | For training, you can use one of the following scripts: 93 | 94 | - `train_controlnet.sh`: Trains a model using [Stable Diffusion v2](https://huggingface.co/stabilityai/stable-diffusion-2-1) 95 | - `train_controlnet_sdxl.sh`: Trains a model using [SDXL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) 96 | - `train_controlnet_sdxl_light.sh`: Trains a model using [SDXL-Lightning](https://huggingface.co/ByteDance/SDXL-Lightning) 97 | 98 | Although the training code for SDXL is provided, I was not able to train that model myself due to a lack of GPU resources, so you may encounter errors when you try to train it. 99 | 100 | ## Evaluation 101 | 102 | For evaluation, you can use one of the following scripts: 103 | 104 | - `eval_controlnet.sh`: Evaluates the model using [Stable Diffusion v2](https://huggingface.co/stabilityai/stable-diffusion-2-1) for a folder of images. 105 | - `eval_controlnet_sdxl_light.sh`: Evaluates the model using [SDXL-Lightning](https://huggingface.co/ByteDance/SDXL-Lightning) for a folder of images. 106 | - `eval_controlnet_sdxl_light_single.sh`: Evaluates the model using [SDXL-Lightning](https://huggingface.co/ByteDance/SDXL-Lightning) for a single image. 107 | 108 | ## Results 109 | ### Prompt-Guided 110 | | Caption | Condition 1 | Condition 2 | Condition 3 | 111 | |:---:|:---:|:---:|:---:| 112 | | ![000000022935_gray.jpg](images/000000022935_gray.jpg) | ![000000022935_green_shirt_on_right_girl.jpeg](images/000000022935_green_shirt_on_right_girl.jpeg) | ![000000022935_purple_shirt_on_right_girl.jpeg](images/000000022935_purple_shirt_on_right_girl.jpeg) |![000000022935_red_shirt_on_right_girl.jpeg](images/000000022935_red_shirt_on_right_girl.jpeg) | 113 | | a photography of a woman in a soccer uniform kicking a soccer ball | + "green shirt"| + "purple shirt" | + "red shirt" | 114 | | ![000000041633_gray.jpg](images/000000041633_gray.jpg) | ![000000041633_bright_red_car.jpeg](images/000000041633_bright_red_car.jpeg) | ![000000041633_dark_blue_car.jpeg](images/000000041633_dark_blue_car.jpeg) |![000000041633_black_car.jpeg](images/000000041633_black_car.jpeg) | 115 | | a photography of a photo of a truck | + "bright red car"| + "dark blue car" | + "black car" | 116 | | ![000000286708_gray.jpg](images/000000286708_gray.jpg) | ![000000286708_orange_hat.jpeg](images/000000286708_orange_hat.jpeg) | ![000000286708_pink_hat.jpeg](images/000000286708_pink_hat.jpeg) |![000000286708_yellow_hat.jpeg](images/000000286708_yellow_hat.jpeg) | 117 | | a photography of a cat wearing a hat on his head | + "orange hat"| + "pink hat" | + "yellow hat" | 118 | 119 | ### Prompt-Free 120 | Ground truth images are provided solely for reference purposes in the image colorization task.
121 | | Grayscale Image | Colorized Result | Ground Truth | 122 | |:---:|:---:|:---:| 123 | | ![000000025560_gray.jpg](images/000000025560_gray.jpg) | ![000000025560_color.jpg](images/000000025560_color.jpg) | ![000000025560_gt.jpg](images/000000025560_gt.jpg) | 124 | | ![000000065736_gray.jpg](images/000000065736_gray.jpg) | ![000000065736_color.jpg](images/000000065736_color.jpg) | ![000000065736_gt.jpg](images/000000065736_gt.jpg) | 125 | | ![000000091779_gray.jpg](images/000000091779_gray.jpg) | ![000000091779_color.jpg](images/000000091779_color.jpg) | ![000000091779_gt.jpg](images/000000091779_gt.jpg) | 126 | | ![000000092177_gray.jpg](images/000000092177_gray.jpg) | ![000000092177_color.jpg](images/000000092177_color.jpg) | ![000000092177_gt.jpg](images/000000092177_gt.jpg) | 127 | | ![000000166426_gray.jpg](images/000000166426_gray.jpg) | ![000000166426_color.jpg](images/000000166426_color.jpg) | ![000000166426_gt.jpg](images/000000166426_gt.jpg) | 128 | 129 | ## Read More 130 | 131 | Here are some related articles you might find interesting: 132 | 133 | - [Image Colorization: Bringing Black and White to Life](https://medium.com/generative-ai/image-colorization-bringing-black-and-white-to-life-b14d3e0db763) 134 | - [Understanding RGB, YCbCr, and Lab Color Spaces](https://medium.com/@weichenpai/understanding-rgb-ycbcr-and-lab-color-spaces-f9c4a5fe485a) 135 | - [Comparison Between CLIP and BLIP Models](https://medium.com/generative-ai/comparison-between-clip-and-blip-models-42f8a6ff4b1e) 136 | - [A Step-by-Step Guide to Interactive Machine Learning with Gradio](https://medium.com/generative-ai/a-step-by-step-guide-to-interactive-machine-learning-with-gradio-3fde7541da52) 137 | 138 | ## License 139 | 140 | This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for more details.
141 | -------------------------------------------------------------------------------- /eval_controlnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import shutil 5 | import argparse 6 | import numpy as np 7 | 8 | from tqdm import tqdm 9 | from PIL import Image 10 | from datasets import load_dataset 11 | from diffusers.utils import load_image 12 | from diffusers import StableDiffusionControlNetPipeline, ControlNetModel 13 | 14 | # Define the function to parse arguments 15 | def parse_args(input_args=None): 16 | parser = argparse.ArgumentParser(description="Simple example of a ControlNet evaluation script.") 17 | 18 | parser.add_argument("--model_dir", type=str, default="sd_v2_caption_free_output/checkpoint-22500", 19 | help="Directory of the model checkpoint") 20 | parser.add_argument("--model_id", type=str, default="stabilityai/stable-diffusion-2-base", 21 | help="ID of the model (Tested with runwayml/stable-diffusion-v1-5 and stabilityai/stable-diffusion-2-base)") 22 | parser.add_argument("--dataset", type=str, default="nickpai/coco2017-colorization", 23 | help="Dataset used") 24 | parser.add_argument("--revision", type=str, default="caption-free", 25 | choices=["main", "caption-free"], 26 | help="Revision option (main/caption-free)") 27 | 28 | if input_args is not None: 29 | args = parser.parse_args(input_args) 30 | else: 31 | args = parser.parse_args() 32 | 33 | return args 34 | 35 | def apply_color(image, color_map): 36 | # Convert input images to LAB color space 37 | image_lab = image.convert('LAB') 38 | color_map_lab = color_map.convert('LAB') 39 | 40 | # Split LAB channels 41 | l, a, b = image_lab.split() 42 | _, a_map, b_map = color_map_lab.split() 43 | 44 | # Merge LAB channels with color map 45 | merged_lab = Image.merge('LAB', (l, a_map, b_map)) 46 | 47 | # Convert merged LAB image back to RGB color space 48 | result_rgb = merged_lab.convert('RGB') 49 | 50 | return result_rgb 51 | 52 | def main(args): 53 | generator = torch.manual_seed(0) 54 | 55 | # MODEL_DIR = "sd_v2_caption_free_output/checkpoint-22500" 56 | # # MODEL_ID="runwayml/stable-diffusion-v1-5" 57 | # MODEL_ID="stabilityai/stable-diffusion-2-base" 58 | # DATASET = "nickpai/coco2017-colorization" 59 | # REVISION = "caption-free" # option: main/caption-free 60 | 61 | # Path to the eval_results folder 62 | eval_results_folder = os.path.join(args.model_dir, "results") 63 | 64 | # Remove eval_results folder if it exists 65 | if os.path.exists(eval_results_folder): 66 | shutil.rmtree(eval_results_folder) 67 | 68 | # Create directory for eval_results 69 | os.makedirs(eval_results_folder) 70 | 71 | # Create subfolders for compare and colorized images 72 | compare_folder = os.path.join(eval_results_folder, "compare") 73 | colorized_folder = os.path.join(eval_results_folder, "colorized") 74 | os.makedirs(compare_folder) 75 | os.makedirs(colorized_folder) 76 | 77 | # Load the validation split of the colorization dataset 78 | val_dataset = load_dataset(args.dataset, split="validation", revision=args.revision) 79 | 80 | controlnet = ControlNetModel.from_pretrained(f"{args.model_dir}/controlnet", torch_dtype=torch.float16) 81 | pipe = StableDiffusionControlNetPipeline.from_pretrained( 82 | args.model_id, controlnet=controlnet, torch_dtype=torch.float16 83 | ).to("cuda") 84 | 85 | pipe.safety_checker = None 86 | 87 | # Counter for processed images 88 | processed_images = 0 89 | 90 | # Record start time 91 | start_time = time.time() 92 | 93 | 
# Iterate through the validation dataset 94 | for example in tqdm(val_dataset, desc="Processing Images"): 95 | image_path = example["file_name"] 96 | 97 | prompt = [] 98 | for caption in example["captions"]: 99 | if isinstance(caption, str): 100 | prompt.append(caption) 101 | elif isinstance(caption, (list, np.ndarray)): 102 | # take a random caption if there are multiple 103 | prompt.append(caption[0]) 104 | else: 105 | raise ValueError( 106 | f"Caption column `captions` should contain either strings or lists of strings." 107 | ) 108 | 109 | # Generate image 110 | ground_truth_image = load_image(image_path).resize((512, 512)) 111 | control_image = load_image(image_path).convert("L").convert("RGB").resize((512, 512)) 112 | image = pipe(prompt, num_inference_steps=20, generator=generator, image=control_image).images[0] 113 | 114 | # Apply color mapping 115 | image = apply_color(ground_truth_image, image) 116 | 117 | # Concatenate images into a row 118 | row_image = np.hstack((np.array(control_image), np.array(image), np.array(ground_truth_image))) 119 | row_image = Image.fromarray(row_image) 120 | 121 | # Save row image in the compare folder 122 | compare_output_path = os.path.join(compare_folder, f"{image_path.split('/')[-1]}") 123 | row_image.save(compare_output_path) 124 | 125 | # Save colorized image in the colorized folder 126 | colorized_output_path = os.path.join(colorized_folder, f"{image_path.split('/')[-1]}") 127 | image.save(colorized_output_path) 128 | 129 | # Increment processed images counter 130 | processed_images += 1 131 | 132 | # Record end time 133 | end_time = time.time() 134 | 135 | # Calculate total time taken 136 | total_time = end_time - start_time 137 | 138 | # Calculate FPS 139 | fps = processed_images / total_time 140 | 141 | print("All images processed.") 142 | print(f"Total time taken: {total_time:.2f} seconds") 143 | print(f"FPS: {fps:.2f}") 144 | 145 | # Entry point of the script 146 | if __name__ == "__main__": 147 | args = parse_args() 148 | main(args) -------------------------------------------------------------------------------- /eval_controlnet.sh: -------------------------------------------------------------------------------- 1 | # Define default values for parameters 2 | 3 | # # sdv2 with BCE loss 4 | # MODEL_DIR="sd_v2_caption_bce_output/checkpoint-22500" 5 | # MODEL_ID="stabilityai/stable-diffusion-2-base" 6 | # DATASET="nickpai/coco2017-colorization" 7 | # REVISION="main" 8 | 9 | # sdv2 with kl loss 10 | MODEL_DIR="sd_v2_caption_kl_output/checkpoint-22500" 11 | MODEL_ID="stabilityai/stable-diffusion-2-base" 12 | DATASET="nickpai/coco2017-colorization" 13 | REVISION="main" 14 | 15 | accelerate launch eval_controlnet.py \ 16 | --model_dir=$MODEL_DIR \ 17 | --model_id=$MODEL_ID \ 18 | --dataset=$DATASET \ 19 | --revision=$REVISION -------------------------------------------------------------------------------- /eval_controlnet_sdxl_light.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import shutil 5 | import argparse 6 | import numpy as np 7 | 8 | from tqdm import tqdm 9 | from PIL import Image 10 | from datasets import load_dataset 11 | from accelerate import Accelerator 12 | from diffusers.utils import load_image 13 | from diffusers import ( 14 | AutoencoderKL, 15 | StableDiffusionXLControlNetPipeline, 16 | ControlNetModel, 17 | UNet2DConditionModel, 18 | ) 19 | from huggingface_hub import hf_hub_download 20 | from safetensors.torch import load_file 21 | 22 | # 
Define the function to parse arguments 23 | def parse_args(input_args=None): 24 | parser = argparse.ArgumentParser(description="Simple example of a ControlNet evaluation script.") 25 | 26 | parser.add_argument( 27 | "--pretrained_model_name_or_path", 28 | type=str, 29 | default=None, 30 | required=True, 31 | help="Path to pretrained model or model identifier from huggingface.co/models.", 32 | ) 33 | parser.add_argument( 34 | "--pretrained_vae_model_name_or_path", 35 | type=str, 36 | default=None, 37 | help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.", 38 | ) 39 | parser.add_argument( 40 | "--controlnet_model_name_or_path", 41 | type=str, 42 | default=None, 43 | required=True, 44 | help="Path to pretrained controlnet model.", 45 | ) 46 | parser.add_argument( 47 | "--output_dir", 48 | type=str, 49 | default=None, 50 | required=True, 51 | help="Path to output results.", 52 | ) 53 | parser.add_argument( 54 | "--dataset", 55 | type=str, 56 | default="nickpai/coco2017-colorization", 57 | help="Dataset used" 58 | ) 59 | parser.add_argument( 60 | "--dataset_revision", 61 | type=str, 62 | default="caption-free", 63 | choices=["main", "caption-free", "custom-caption"], 64 | help="Revision option (main/caption-free/custom-caption)" 65 | ) 66 | parser.add_argument( 67 | "--mixed_precision", 68 | type=str, 69 | default=None, 70 | choices=["no", "fp16", "bf16"], 71 | help=( 72 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 73 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 74 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 75 | ), 76 | ) 77 | parser.add_argument( 78 | "--variant", 79 | type=str, 80 | default=None, 81 | help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' 
fp16", 82 | ) 83 | parser.add_argument( 84 | "--revision", 85 | type=str, 86 | default=None, 87 | required=False, 88 | help="Revision of pretrained model identifier from huggingface.co/models.", 89 | ) 90 | parser.add_argument( 91 | "--num_inference_steps", 92 | type=int, 93 | default=8, 94 | help="1-step, 2-step, 4-step, or 8-step distilled models" 95 | ) 96 | parser.add_argument( 97 | "--repo", 98 | type=str, 99 | default="ByteDance/SDXL-Lightning", 100 | required=True, 101 | help="Repository from huggingface.co", 102 | ) 103 | parser.add_argument( 104 | "--ckpt", 105 | type=str, 106 | default="sdxl_lightning_4step_unet.safetensors", 107 | required=True, 108 | help="Available checkpoints from the repository", 109 | ) 110 | parser.add_argument( 111 | "--negative_prompt", 112 | action="store_true", 113 | help="The prompt or prompts not to guide the image generation", 114 | ) 115 | 116 | if input_args is not None: 117 | args = parser.parse_args(input_args) 118 | else: 119 | args = parser.parse_args() 120 | 121 | return args 122 | 123 | def apply_color(image, color_map): 124 | # Convert input images to LAB color space 125 | image_lab = image.convert('LAB') 126 | color_map_lab = color_map.convert('LAB') 127 | 128 | # Split LAB channels 129 | l, a, b = image_lab.split() 130 | _, a_map, b_map = color_map_lab.split() 131 | 132 | # Merge LAB channels with color map 133 | merged_lab = Image.merge('LAB', (l, a_map, b_map)) 134 | 135 | # Convert merged LAB image back to RGB color space 136 | result_rgb = merged_lab.convert('RGB') 137 | 138 | return result_rgb 139 | 140 | def main(args): 141 | generator = torch.manual_seed(0) 142 | 143 | # Path to the eval_results folder 144 | eval_results_folder = os.path.join(args.output_dir, "results") 145 | 146 | # Remove eval_results folder if it exists 147 | if os.path.exists(eval_results_folder): 148 | shutil.rmtree(eval_results_folder) 149 | 150 | # Create directory for eval_results 151 | os.makedirs(eval_results_folder) 152 | 153 | # Create subfolders for compare and colorized images 154 | compare_folder = os.path.join(eval_results_folder, "compare") 155 | colorized_folder = os.path.join(eval_results_folder, "colorized") 156 | os.makedirs(compare_folder) 157 | os.makedirs(colorized_folder) 158 | 159 | # Load the validation split of the colorization dataset 160 | val_dataset = load_dataset(args.dataset, split="validation", revision=args.dataset_revision) 161 | 162 | accelerator = Accelerator( 163 | mixed_precision=args.mixed_precision, 164 | ) 165 | 166 | weight_dtype = torch.float32 167 | if accelerator.mixed_precision == "fp16": 168 | weight_dtype = torch.float16 169 | elif accelerator.mixed_precision == "bf16": 170 | weight_dtype = torch.bfloat16 171 | 172 | vae_path = ( 173 | args.pretrained_model_name_or_path 174 | if args.pretrained_vae_model_name_or_path is None 175 | else args.pretrained_vae_model_name_or_path 176 | ) 177 | vae = AutoencoderKL.from_pretrained( 178 | vae_path, 179 | subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, 180 | revision=args.revision, 181 | variant=args.variant, 182 | ) 183 | unet = UNet2DConditionModel.from_config( 184 | args.pretrained_model_name_or_path, 185 | subfolder="unet", 186 | revision=args.revision, 187 | variant=args.variant, 188 | ) 189 | unet.load_state_dict(load_file(hf_hub_download(args.repo, args.ckpt))) 190 | 191 | # Move vae, unet and text_encoder to device and cast to weight_dtype 192 | # The VAE is in float32 to avoid NaN losses. 
193 | if args.pretrained_vae_model_name_or_path is not None: 194 | vae.to(accelerator.device, dtype=weight_dtype) 195 | else: 196 | vae.to(accelerator.device, dtype=torch.float32) 197 | unet.to(accelerator.device, dtype=weight_dtype) 198 | 199 | controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path, torch_dtype=weight_dtype) 200 | pipe = StableDiffusionXLControlNetPipeline.from_pretrained( 201 | args.pretrained_model_name_or_path, 202 | vae=vae, 203 | unet=unet, 204 | controlnet=controlnet, 205 | ) 206 | pipe.to(accelerator.device, dtype=weight_dtype) 207 | 208 | # Prepare everything with our `accelerator`. 209 | pipe, val_dataset = accelerator.prepare(pipe, val_dataset) 210 | 211 | pipe.safety_checker = None 212 | 213 | # Counter for processed images 214 | processed_images = 0 215 | 216 | # Record start time 217 | start_time = time.time() 218 | 219 | # Iterate through the validation dataset 220 | for example in tqdm(val_dataset, desc="Processing Images"): 221 | image_path = example["file_name"] 222 | 223 | prompt = [] 224 | for caption in example["captions"]: 225 | if isinstance(caption, str): 226 | prompt.append(caption) 227 | elif isinstance(caption, (list, np.ndarray)): 228 | # take a random caption if there are multiple 229 | prompt.append(caption[0]) 230 | else: 231 | raise ValueError( 232 | f"Caption column `captions` should contain either strings or lists of strings." 233 | ) 234 | 235 | negative_prompt = None 236 | if args.negative_prompt: 237 | negative_prompt = [ 238 | "low quality, bad quality, low contrast, black and white, bw, monochrome, grainy, blurry, historical, restored, desaturate" 239 | ] 240 | 241 | # Generate image 242 | ground_truth_image = load_image(image_path).resize((512, 512)) 243 | control_image = load_image(image_path).convert("L").convert("RGB").resize((512, 512)) 244 | image = pipe(prompt=prompt, 245 | negative_prompt=negative_prompt, 246 | num_inference_steps=args.num_inference_steps, 247 | generator=generator, 248 | image=control_image).images[0] 249 | 250 | # Apply color mapping 251 | image = apply_color(ground_truth_image, image) 252 | 253 | # Concatenate images into a row 254 | row_image = np.hstack((np.array(control_image), np.array(image), np.array(ground_truth_image))) 255 | row_image = Image.fromarray(row_image) 256 | 257 | # Save row image in the compare folder 258 | compare_output_path = os.path.join(compare_folder, f"{image_path.split('/')[-1]}") 259 | row_image.save(compare_output_path) 260 | 261 | # Save colorized image in the colorized folder 262 | colorized_output_path = os.path.join(colorized_folder, f"{image_path.split('/')[-1]}") 263 | image.save(colorized_output_path) 264 | 265 | # Increment processed images counter 266 | processed_images += 1 267 | 268 | # Record end time 269 | end_time = time.time() 270 | 271 | # Calculate total time taken 272 | total_time = end_time - start_time 273 | 274 | # Calculate FPS 275 | fps = processed_images / total_time 276 | 277 | print("All images processed.") 278 | print(f"Total time taken: {total_time:.2f} seconds") 279 | print(f"FPS: {fps:.2f}") 280 | 281 | # Entry point of the script 282 | if __name__ == "__main__": 283 | args = parse_args() 284 | main(args) -------------------------------------------------------------------------------- /eval_controlnet_sdxl_light.sh: -------------------------------------------------------------------------------- 1 | # Define default values for parameters 2 | 3 | # # sdxl light without negative prompt 4 | # export 
BASE_MODEL="stabilityai/stable-diffusion-xl-base-1.0" 5 | # export REPO="ByteDance/SDXL-Lightning" 6 | # export INFERENCE_STEP=8 7 | # export CKPT="sdxl_lightning_8step_unet.safetensors" # caution!!! ckpt's "N"step must match with inference_step 8 | # export CONTROLNET_MODEL="sdxl_light_custom_caption_output/checkpoint-12500/controlnet" 9 | # export DATASET="nickpai/coco2017-colorization" 10 | # export DATSET_REVISION="custom-caption" 11 | # export OUTPUT_DIR="sdxl_light_custom_caption_output/checkpoint-12500" 12 | 13 | # accelerate launch eval_controlnet_sdxl_light.py \ 14 | # --pretrained_model_name_or_path=$BASE_MODEL \ 15 | # --repo=$REPO \ 16 | # --ckpt=$CKPT \ 17 | # --num_inference_steps=$INFERENCE_STEP \ 18 | # --controlnet_model_name_or_path=$CONTROLNET_MODEL \ 19 | # --dataset=$DATASET \ 20 | # --dataset_revision=$DATSET_REVISION \ 21 | # --mixed_precision="fp16" \ 22 | # --output_dir=$OUTPUT_DIR 23 | 24 | # sdxl light with negative prompt 25 | export BASE_MODEL="stabilityai/stable-diffusion-xl-base-1.0" 26 | export REPO="ByteDance/SDXL-Lightning" 27 | export INFERENCE_STEP=8 28 | export CKPT="sdxl_lightning_8step_unet.safetensors" # caution!!! ckpt's "N"step must match with inference_step 29 | export CONTROLNET_MODEL="sdxl_light_caption_output/checkpoint-22500/controlnet" 30 | export DATASET="nickpai/coco2017-colorization" 31 | export DATSET_REVISION="custom-caption" 32 | export OUTPUT_DIR="sdxl_light_caption_output/checkpoint-22500" 33 | 34 | accelerate launch eval_controlnet_sdxl_light.py \ 35 | --pretrained_model_name_or_path=$BASE_MODEL \ 36 | --repo=$REPO \ 37 | --ckpt=$CKPT \ 38 | --num_inference_steps=$INFERENCE_STEP \ 39 | --controlnet_model_name_or_path=$CONTROLNET_MODEL \ 40 | --dataset=$DATASET \ 41 | --dataset_revision=$DATSET_REVISION \ 42 | --mixed_precision="fp16" \ 43 | --output_dir=$OUTPUT_DIR \ 44 | --negative_prompt -------------------------------------------------------------------------------- /eval_controlnet_sdxl_light_single.py: -------------------------------------------------------------------------------- 1 | import os 2 | import PIL 3 | import time 4 | import torch 5 | import argparse 6 | 7 | from typing import Optional, Union 8 | from accelerate import Accelerator 9 | from diffusers import ( 10 | AutoencoderKL, 11 | StableDiffusionXLControlNetPipeline, 12 | ControlNetModel, 13 | UNet2DConditionModel, 14 | ) 15 | from transformers import ( 16 | BlipProcessor, BlipForConditionalGeneration, 17 | VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer 18 | ) 19 | from huggingface_hub import hf_hub_download 20 | from safetensors.torch import load_file 21 | 22 | # Define the function to parse arguments 23 | def parse_args(input_args=None): 24 | parser = argparse.ArgumentParser(description="Simple example of a ControlNet evaluation script.") 25 | parser.add_argument( 26 | "--image_path", 27 | type=str, 28 | default="example/legacy_images/Hollywood-Sign.jpg", 29 | required=True, 30 | help="Path to the image", 31 | ) 32 | parser.add_argument( 33 | "--pretrained_model_name_or_path", 34 | type=str, 35 | default=None, 36 | required=True, 37 | help="Path to pretrained model or model identifier from huggingface.co/models.", 38 | ) 39 | parser.add_argument( 40 | "--pretrained_vae_model_name_or_path", 41 | type=str, 42 | default=None, 43 | help="Path to an improved VAE to stabilize training. 
For more details check out: https://github.com/huggingface/diffusers/pull/4038.", 44 | ) 45 | parser.add_argument( 46 | "--controlnet_model_name_or_path", 47 | type=str, 48 | default=None, 49 | required=True, 50 | help="Path to pretrained controlnet model.", 51 | ) 52 | parser.add_argument( 53 | "--caption_model_name", 54 | type=str, 55 | default="blip-image-captioning-large", 56 | choices=["blip-image-captioning-large", "blip-image-captioning-base"], 57 | help="Path to pretrained controlnet model.", 58 | ) 59 | parser.add_argument( 60 | "--mixed_precision", 61 | type=str, 62 | default=None, 63 | choices=["no", "fp16", "bf16"], 64 | help=( 65 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 66 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 67 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 68 | ), 69 | ) 70 | parser.add_argument( 71 | "--variant", 72 | type=str, 73 | default=None, 74 | help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", 75 | ) 76 | parser.add_argument( 77 | "--revision", 78 | type=str, 79 | default=None, 80 | required=False, 81 | help="Revision of pretrained model identifier from huggingface.co/models.", 82 | ) 83 | parser.add_argument( 84 | "--num_inference_steps", 85 | type=int, 86 | default=8, 87 | help="1-step, 2-step, 4-step, or 8-step distilled models" 88 | ) 89 | parser.add_argument( 90 | "--repo", 91 | type=str, 92 | default="ByteDance/SDXL-Lightning", 93 | required=True, 94 | help="Repository from huggingface.co", 95 | ) 96 | parser.add_argument( 97 | "--ckpt", 98 | type=str, 99 | default="sdxl_lightning_4step_unet.safetensors", 100 | required=True, 101 | help="Available checkpoints from the repository", 102 | ) 103 | parser.add_argument( 104 | "--seed", 105 | type=int, 106 | default=123, 107 | help="Random seeds" 108 | ) 109 | parser.add_argument( 110 | "--positive_prompt", 111 | type=str, 112 | help="Text for positive prompt", 113 | ) 114 | parser.add_argument( 115 | "--negative_prompt", 116 | type=str, 117 | default="low quality, bad quality, low contrast, black and white, bw, monochrome, grainy, blurry, historical, restored, desaturate", 118 | help="Text for negative prompt", 119 | ) 120 | 121 | if input_args is not None: 122 | args = parser.parse_args(input_args) 123 | else: 124 | args = parser.parse_args() 125 | 126 | return args 127 | 128 | def apply_color(image, color_map): 129 | # Convert input images to LAB color space 130 | image_lab = image.convert('LAB') 131 | color_map_lab = color_map.convert('LAB') 132 | 133 | # Split LAB channels 134 | l, a, b = image_lab.split() 135 | _, a_map, b_map = color_map_lab.split() 136 | 137 | # Merge LAB channels with color map 138 | merged_lab = PIL.Image.merge('LAB', (l, a_map, b_map)) 139 | 140 | # Convert merged LAB image back to RGB color space 141 | result_rgb = merged_lab.convert('RGB') 142 | 143 | return result_rgb 144 | 145 | def remove_unlikely_words(prompt: str) -> str: 146 | """ 147 | Removes unlikely words from a prompt. 148 | 149 | Args: 150 | prompt: The text prompt to be cleaned. 151 | 152 | Returns: 153 | The cleaned prompt with unlikely words removed. 
154 | """ 155 | unlikely_words = [] 156 | 157 | a1_list = [f'{i}s' for i in range(1900, 2000)] 158 | a2_list = [f'{i}' for i in range(1900, 2000)] 159 | a3_list = [f'year {i}' for i in range(1900, 2000)] 160 | a4_list = [f'circa {i}' for i in range(1900, 2000)] 161 | b1_list = [f"{year[0]} {year[1]} {year[2]} {year[3]} s" for year in a1_list] 162 | b2_list = [f"{year[0]} {year[1]} {year[2]} {year[3]}" for year in a1_list] 163 | b3_list = [f"year {year[0]} {year[1]} {year[2]} {year[3]}" for year in a1_list] 164 | b4_list = [f"circa {year[0]} {year[1]} {year[2]} {year[3]}" for year in a1_list] 165 | 166 | words_list = [ 167 | "black and white,", "black and white", "black & white,", "black & white", "circa", 168 | "balck and white,", "monochrome,", "black-and-white,", "black-and-white photography,", 169 | "black - and - white photography,", "monochrome bw,", "black white,", "black an white,", 170 | "grainy footage,", "grainy footage", "grainy photo,", "grainy photo", "b&w photo", 171 | "back and white", "back and white,", "monochrome contrast", "monochrome", "grainy", 172 | "grainy photograph,", "grainy photograph", "low contrast,", "low contrast", "b & w", 173 | "grainy black-and-white photo,", "bw", "bw,", "grainy black-and-white photo", 174 | "b & w,", "b&w,", "b&w!,", "b&w", "black - and - white,", "bw photo,", "grainy photo,", 175 | "black-and-white photo,", "black-and-white photo", "black - and - white photography", 176 | "b&w photo,", "monochromatic photo,", "grainy monochrome photo,", "monochromatic", 177 | "blurry photo,", "blurry,", "blurry photography,", "monochromatic photo", 178 | "black - and - white photograph,", "black - and - white photograph", "black on white,", 179 | "black on white", "black-and-white", "historical image,", "historical picture,", 180 | "historical photo,", "historical photograph,", "archival photo,", "taken in the early", 181 | "taken in the late", "taken in the", "historic photograph,", "restored,", "restored", 182 | "historical photo", "historical setting,", 183 | "historic photo,", "historic", "desaturated!!,", "desaturated!,", "desaturated,", "desaturated", 184 | "taken in", "shot on leica", "shot on leica sl2", "sl2", 185 | "taken with a leica camera", "taken with a leica camera", "leica sl2", "leica", "setting", 186 | "overcast day", "overcast weather", "slight overcast", "overcast", 187 | "picture taken in", "photo taken in", 188 | ", photo", ", photo", ", photo", ", photo", ", photograph", 189 | ",,", ",,,", ",,,,", " ,", " ,", " ,", " ,", 190 | ] 191 | 192 | unlikely_words.extend(a1_list) 193 | unlikely_words.extend(a2_list) 194 | unlikely_words.extend(a3_list) 195 | unlikely_words.extend(a4_list) 196 | unlikely_words.extend(b1_list) 197 | unlikely_words.extend(b2_list) 198 | unlikely_words.extend(b3_list) 199 | unlikely_words.extend(b4_list) 200 | unlikely_words.extend(words_list) 201 | 202 | for word in unlikely_words: 203 | prompt = prompt.replace(word, "") 204 | return prompt 205 | 206 | def blip_image_captioning(image: PIL.Image.Image, 207 | model_backbone: str, 208 | weight_dtype: type, 209 | device: str, 210 | conditional: bool) -> str: 211 | # https://huggingface.co/Salesforce/blip-image-captioning-large 212 | # https://huggingface.co/Salesforce/blip-image-captioning-base 213 | if weight_dtype == torch.bfloat16: # in case model might not accept bfloat16 data type 214 | weight_dtype = torch.float16 215 | 216 | processor = BlipProcessor.from_pretrained(f"Salesforce/{model_backbone}") 217 | model = BlipForConditionalGeneration.from_pretrained( 
218 | f"Salesforce/{model_backbone}", torch_dtype=weight_dtype).to(device) 219 | 220 | valid_backbones = ["blip-image-captioning-large", "blip-image-captioning-base"] 221 | if model_backbone not in valid_backbones: 222 | raise ValueError(f"Invalid model backbone '{model_backbone}'. \ 223 | Valid options are: {', '.join(valid_backbones)}") 224 | 225 | if conditional: 226 | text = "a photography of" 227 | inputs = processor(image, text, return_tensors="pt").to(device, weight_dtype) 228 | else: 229 | inputs = processor(image, return_tensors="pt").to(device) 230 | out = model.generate(**inputs) 231 | caption = processor.decode(out[0], skip_special_tokens=True) 232 | return caption 233 | 234 | import matplotlib.pyplot as plt 235 | 236 | def display_images(input_image, output_image, ground_truth): 237 | """ 238 | Displays a grid of input, output, ground truth images with a caption at the bottom. 239 | 240 | Args: 241 | input_image: A grayscale image as a NumPy array. 242 | output_image: A grayscale image (result) as a NumPy array. 243 | ground_truth: A grayscale image (ground truth) as a NumPy array. 244 | """ 245 | fig, axes = plt.subplots(1, 3, figsize=(20, 8)) 246 | 247 | axes[0].imshow(input_image, cmap='gray') 248 | axes[0].set_title('Input') 249 | axes[0].axis('off') 250 | 251 | axes[1].imshow(output_image) 252 | axes[1].set_title('Output') 253 | axes[1].axis('off') 254 | 255 | axes[2].imshow(ground_truth) 256 | axes[2].set_title('Ground Truth') 257 | axes[2].axis('off') 258 | 259 | plt.tight_layout() 260 | plt.show() 261 | 262 | # Define a function to process the image with the loaded model 263 | def process_image(image_path: str, 264 | controlnet_model_name_or_path: str, 265 | caption_model_name: str, 266 | positive_prompt: Optional[str], 267 | negative_prompt: Optional[str], 268 | seed: int, 269 | num_inference_steps: int, 270 | mixed_precision: str, 271 | pretrained_model_name_or_path: str, 272 | pretrained_vae_model_name_or_path: Optional[str], 273 | revision: Optional[str], 274 | variant: Optional[str], 275 | repo: str, 276 | ckpt: str,) -> PIL.Image.Image: 277 | # Seed 278 | generator = torch.manual_seed(seed) 279 | 280 | # Accelerator Setting 281 | accelerator = Accelerator( 282 | mixed_precision=mixed_precision, 283 | ) 284 | 285 | weight_dtype = torch.float32 286 | if accelerator.mixed_precision == "fp16": 287 | weight_dtype = torch.float16 288 | elif accelerator.mixed_precision == "bf16": 289 | weight_dtype = torch.bfloat16 290 | 291 | vae_path = ( 292 | pretrained_model_name_or_path 293 | if pretrained_vae_model_name_or_path is None 294 | else pretrained_vae_model_name_or_path 295 | ) 296 | vae = AutoencoderKL.from_pretrained( 297 | vae_path, 298 | subfolder="vae" if pretrained_vae_model_name_or_path is None else None, 299 | revision=revision, 300 | variant=variant, 301 | ) 302 | unet = UNet2DConditionModel.from_config( 303 | pretrained_model_name_or_path, 304 | subfolder="unet", 305 | revision=revision, 306 | variant=variant, 307 | ) 308 | unet.load_state_dict(load_file(hf_hub_download(repo, ckpt))) 309 | 310 | # Move vae, unet and text_encoder to device and cast to weight_dtype 311 | # The VAE is in float32 to avoid NaN losses. 
312 | if pretrained_vae_model_name_or_path is not None: 313 | vae.to(accelerator.device, dtype=weight_dtype) 314 | else: 315 | vae.to(accelerator.device, dtype=torch.float32) 316 | unet.to(accelerator.device, dtype=weight_dtype) 317 | 318 | controlnet = ControlNetModel.from_pretrained(controlnet_model_name_or_path, torch_dtype=weight_dtype) 319 | pipe = StableDiffusionXLControlNetPipeline.from_pretrained( 320 | pretrained_model_name_or_path, 321 | vae=vae, 322 | unet=unet, 323 | controlnet=controlnet, 324 | ) 325 | pipe.to(accelerator.device, dtype=weight_dtype) 326 | 327 | image = PIL.Image.open(image_path) 328 | 329 | # Prepare everything with our `accelerator`. 330 | pipe, image = accelerator.prepare(pipe, image) 331 | pipe.safety_checker = None 332 | 333 | # Convert image into grayscale 334 | original_size = image.size 335 | control_image = image.convert("L").convert("RGB").resize((512, 512)) 336 | 337 | # Image captioning 338 | if caption_model_name == "blip-image-captioning-large" or "blip-image-captioning-base": 339 | caption = blip_image_captioning(control_image, caption_model_name, 340 | weight_dtype, accelerator.device, conditional=True) 341 | # elif caption_model_name == "ViT-L-14/openai" or "ViT-H-14/laion2b_s32b_b79k": 342 | # caption = clip_image_captioning(control_image, caption_model_name, accelerator.device) 343 | # elif caption_model_name == "vit-gpt2-image-captioning": 344 | # caption = vit_gpt2_image_captioning(control_image, accelerator.device) 345 | caption = remove_unlikely_words(caption) 346 | 347 | print("================================================================") 348 | print(f"Positive prompt: \n>>> {positive_prompt}") 349 | print(f"Negative prompt: \n>>> {negative_prompt}") 350 | print(f"Caption results: \n>>> {caption}") 351 | print("================================================================") 352 | 353 | # Combine positive prompt and captioning result 354 | prompt = [positive_prompt + ", " + caption] 355 | 356 | # Image colorization 357 | image = pipe(prompt=prompt, 358 | negative_prompt=negative_prompt, 359 | num_inference_steps=num_inference_steps, 360 | generator=generator, 361 | image=control_image).images[0] 362 | 363 | # Apply color mapping 364 | result_image = apply_color(control_image, image) 365 | result_image = result_image.resize(original_size) 366 | return result_image, caption 367 | 368 | def main(args): 369 | output_image, output_caption = process_image(image_path=args.image_path, 370 | controlnet_model_name_or_path=args.controlnet_model_name_or_path, 371 | caption_model_name=args.caption_model_name, 372 | positive_prompt=args.positive_prompt, 373 | negative_prompt=args.negative_prompt, 374 | seed=args.seed, 375 | num_inference_steps=args.num_inference_steps, 376 | mixed_precision=args.mixed_precision, 377 | pretrained_model_name_or_path=args.pretrained_model_name_or_path, 378 | pretrained_vae_model_name_or_path=args.pretrained_vae_model_name_or_path, 379 | revision=args.revision, 380 | variant=args.variant, 381 | repo=args.repo, 382 | ckpt=args.ckpt,) 383 | input_image = PIL.Image.open(args.image_path) 384 | display_images(input_image.convert("L"), output_image, input_image) 385 | return output_image, output_caption 386 | 387 | # Entry point of the script 388 | if __name__ == "__main__": 389 | args = parse_args() 390 | main(args) -------------------------------------------------------------------------------- /eval_controlnet_sdxl_light_single.sh: -------------------------------------------------------------------------------- 1 | # 
sdxl light for single image 2 | export BASE_MODEL="stabilityai/stable-diffusion-xl-base-1.0" 3 | export REPO="ByteDance/SDXL-Lightning" 4 | export INFERENCE_STEP=8 5 | export CKPT="sdxl_lightning_8step_unet.safetensors" # caution!!! ckpt's "N"step must match with inference_step 6 | export CONTROLNET_MODEL="sdxl_light_caption_output/checkpoint-30000/controlnet" 7 | export CAPTION_MODEL="blip-image-captioning-large" 8 | export IMAGE_PATH="example/legacy_images/Hollywood-Sign.jpg" 9 | # export POSITIVE_PROMPT="blue shirt" 10 | 11 | accelerate launch eval_controlnet_sdxl_light_single.py \ 12 | --pretrained_model_name_or_path=$BASE_MODEL \ 13 | --repo=$REPO \ 14 | --ckpt=$CKPT \ 15 | --num_inference_steps=$INFERENCE_STEP \ 16 | --controlnet_model_name_or_path=$CONTROLNET_MODEL \ 17 | --caption_model_name=$CAPTION_MODEL \ 18 | --mixed_precision="fp16" \ 19 | --image_path=$IMAGE_PATH \ 20 | --positive_prompt="red car" -------------------------------------------------------------------------------- /example/UUColor_results/Hollywood-Sign.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/example/UUColor_results/Hollywood-Sign.jpeg -------------------------------------------------------------------------------- /example/legacy_images/Big-Ben-vintage.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/example/legacy_images/Big-Ben-vintage.jpg -------------------------------------------------------------------------------- /example/legacy_images/Central-Park.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/example/legacy_images/Central-Park.jpg -------------------------------------------------------------------------------- /example/legacy_images/Hollywood-Sign.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/example/legacy_images/Hollywood-Sign.jpg -------------------------------------------------------------------------------- /example/legacy_images/Little-Mermaid.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/example/legacy_images/Little-Mermaid.jpg -------------------------------------------------------------------------------- /example/legacy_images/Migrant-Mother.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/example/legacy_images/Migrant-Mother.jpg -------------------------------------------------------------------------------- /example/legacy_images/Mount-Everest.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/example/legacy_images/Mount-Everest.jpg -------------------------------------------------------------------------------- 
/example/legacy_images/Tower-of-Pisa.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/example/legacy_images/Tower-of-Pisa.jpg -------------------------------------------------------------------------------- /example/legacy_images/Wasatch-Mountains-Summit-County-Utah.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/example/legacy_images/Wasatch-Mountains-Summit-County-Utah.jpg -------------------------------------------------------------------------------- /gradio_ui.py: -------------------------------------------------------------------------------- 1 | import PIL 2 | import torch 3 | import subprocess 4 | import gradio as gr 5 | 6 | from typing import Optional 7 | from accelerate import Accelerator 8 | from diffusers import ( 9 | AutoencoderKL, 10 | StableDiffusionXLControlNetPipeline, 11 | ControlNetModel, 12 | UNet2DConditionModel, 13 | ) 14 | from transformers import ( 15 | BlipProcessor, BlipForConditionalGeneration, 16 | VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer 17 | ) 18 | from huggingface_hub import hf_hub_download 19 | from safetensors.torch import load_file 20 | from clip_interrogator import Interrogator, Config, list_clip_models 21 | 22 | def apply_color(image: PIL.Image.Image, color_map: PIL.Image.Image) -> PIL.Image.Image: 23 | # Convert input images to LAB color space 24 | image_lab = image.convert('LAB') 25 | color_map_lab = color_map.convert('LAB') 26 | 27 | # Split LAB channels 28 | l, a , b = image_lab.split() 29 | _, a_map, b_map = color_map_lab.split() 30 | 31 | # Merge LAB channels with color map 32 | merged_lab = PIL.Image.merge('LAB', (l, a_map, b_map)) 33 | 34 | # Convert merged LAB image back to RGB color space 35 | result_rgb = merged_lab.convert('RGB') 36 | return result_rgb 37 | 38 | def remove_unlikely_words(prompt: str) -> str: 39 | """ 40 | Removes unlikely words from a prompt. 41 | 42 | Args: 43 | prompt: The text prompt to be cleaned. 44 | 45 | Returns: 46 | The cleaned prompt with unlikely words removed. 
47 | """ 48 | unlikely_words = [] 49 | 50 | a1_list = [f'{i}s' for i in range(1900, 2000)] 51 | a2_list = [f'{i}' for i in range(1900, 2000)] 52 | a3_list = [f'year {i}' for i in range(1900, 2000)] 53 | a4_list = [f'circa {i}' for i in range(1900, 2000)] 54 | b1_list = [f"{year[0]} {year[1]} {year[2]} {year[3]} s" for year in a1_list] 55 | b2_list = [f"{year[0]} {year[1]} {year[2]} {year[3]}" for year in a1_list] 56 | b3_list = [f"year {year[0]} {year[1]} {year[2]} {year[3]}" for year in a1_list] 57 | b4_list = [f"circa {year[0]} {year[1]} {year[2]} {year[3]}" for year in a1_list] 58 | 59 | words_list = [ 60 | "black and white,", "black and white", "black & white,", "black & white", "circa", 61 | "balck and white,", "monochrome,", "black-and-white,", "black-and-white photography,", 62 | "black - and - white photography,", "monochrome bw,", "black white,", "black an white,", 63 | "grainy footage,", "grainy footage", "grainy photo,", "grainy photo", "b&w photo", 64 | "back and white", "back and white,", "monochrome contrast", "monochrome", "grainy", 65 | "grainy photograph,", "grainy photograph", "low contrast,", "low contrast", "b & w", 66 | "grainy black-and-white photo,", "bw", "bw,", "grainy black-and-white photo", 67 | "b & w,", "b&w,", "b&w!,", "b&w", "black - and - white,", "bw photo,", "grainy photo,", 68 | "black-and-white photo,", "black-and-white photo", "black - and - white photography", 69 | "b&w photo,", "monochromatic photo,", "grainy monochrome photo,", "monochromatic", 70 | "blurry photo,", "blurry,", "blurry photography,", "monochromatic photo", 71 | "black - and - white photograph,", "black - and - white photograph", "black on white,", 72 | "black on white", "black-and-white", "historical image,", "historical picture,", 73 | "historical photo,", "historical photograph,", "archival photo,", "taken in the early", 74 | "taken in the late", "taken in the", "historic photograph,", "restored,", "restored", 75 | "historical photo", "historical setting,", 76 | "historic photo,", "historic", "desaturated!!,", "desaturated!,", "desaturated,", "desaturated", 77 | "taken in", "shot on leica", "shot on leica sl2", "sl2", 78 | "taken with a leica camera", "taken with a leica camera", "leica sl2", "leica", "setting", 79 | "overcast day", "overcast weather", "slight overcast", "overcast", 80 | "picture taken in", "photo taken in", 81 | ", photo", ", photo", ", photo", ", photo", ", photograph", 82 | ",,", ",,,", ",,,,", " ,", " ,", " ,", " ,", 83 | ] 84 | 85 | unlikely_words.extend(a1_list) 86 | unlikely_words.extend(a2_list) 87 | unlikely_words.extend(a3_list) 88 | unlikely_words.extend(a4_list) 89 | unlikely_words.extend(b1_list) 90 | unlikely_words.extend(b2_list) 91 | unlikely_words.extend(b3_list) 92 | unlikely_words.extend(b4_list) 93 | unlikely_words.extend(words_list) 94 | 95 | for word in unlikely_words: 96 | prompt = prompt.replace(word, "") 97 | return prompt 98 | 99 | def blip_image_captioning(image: PIL.Image.Image, 100 | model_backbone: str, 101 | weight_dtype: type, 102 | device: str, 103 | conditional: bool) -> str: 104 | # https://huggingface.co/Salesforce/blip-image-captioning-large 105 | # https://huggingface.co/Salesforce/blip-image-captioning-base 106 | if weight_dtype == torch.bfloat16: # in case model might not accept bfloat16 data type 107 | weight_dtype = torch.float16 108 | 109 | processor = BlipProcessor.from_pretrained(f"Salesforce/{model_backbone}") 110 | model = BlipForConditionalGeneration.from_pretrained( 111 | f"Salesforce/{model_backbone}", 
torch_dtype=weight_dtype).to(device) 112 | 113 | valid_backbones = ["blip-image-captioning-large", "blip-image-captioning-base"] 114 | if model_backbone not in valid_backbones: 115 | raise ValueError(f"Invalid model backbone '{model_backbone}'. \ 116 | Valid options are: {', '.join(valid_backbones)}") 117 | 118 | if conditional: 119 | text = "a photography of" 120 | inputs = processor(image, text, return_tensors="pt").to(device, weight_dtype) 121 | else: 122 | inputs = processor(image, return_tensors="pt").to(device) 123 | out = model.generate(**inputs) 124 | caption = processor.decode(out[0], skip_special_tokens=True) 125 | return caption 126 | 127 | # def vit_gpt2_image_captioning(image: PIL.Image.Image, device: str) -> str: 128 | # # https://huggingface.co/nlpconnect/vit-gpt2-image-captioning 129 | # model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning").to(device) 130 | # feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning") 131 | # tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning") 132 | 133 | # max_length = 16 134 | # num_beams = 4 135 | # gen_kwargs = {"max_length": max_length, "num_beams": num_beams} 136 | 137 | # pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values 138 | # pixel_values = pixel_values.to(device) 139 | 140 | # output_ids = model.generate(pixel_values, **gen_kwargs) 141 | 142 | # preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True) 143 | # caption = [pred.strip() for pred in preds] 144 | 145 | # return caption[0] 146 | 147 | # def clip_image_captioning(image: PIL.Image.Image, 148 | # clip_model_name: str, 149 | # device: str) -> str: 150 | # # validate clip model name 151 | # models = list_clip_models() 152 | # if clip_model_name not in models: 153 | # raise ValueError(f"Could not find CLIP model {clip_model_name}! 
\ 154 | # Available models: {models}") 155 | # config = Config(device=device, clip_model_name=clip_model_name) 156 | # config.apply_low_vram_defaults() 157 | # ci = Interrogator(config) 158 | # caption = ci.interrogate(image) 159 | # return caption 160 | 161 | # Define a function to process the image with the loaded model 162 | def process_image(image_path: str, 163 | controlnet_model_name_or_path: str, 164 | caption_model_name: str, 165 | positive_prompt: Optional[str], 166 | negative_prompt: Optional[str], 167 | seed: int, 168 | num_inference_steps: int, 169 | mixed_precision: str, 170 | pretrained_model_name_or_path: str, 171 | pretrained_vae_model_name_or_path: Optional[str], 172 | revision: Optional[str], 173 | variant: Optional[str], 174 | repo: str, 175 | ckpt: str,) -> PIL.Image.Image: 176 | # Seed 177 | generator = torch.manual_seed(seed) 178 | 179 | # Accelerator Setting 180 | accelerator = Accelerator( 181 | mixed_precision=mixed_precision, 182 | ) 183 | 184 | weight_dtype = torch.float32 185 | if accelerator.mixed_precision == "fp16": 186 | weight_dtype = torch.float16 187 | elif accelerator.mixed_precision == "bf16": 188 | weight_dtype = torch.bfloat16 189 | 190 | vae_path = ( 191 | pretrained_model_name_or_path 192 | if pretrained_vae_model_name_or_path is None 193 | else pretrained_vae_model_name_or_path 194 | ) 195 | vae = AutoencoderKL.from_pretrained( 196 | vae_path, 197 | subfolder="vae" if pretrained_vae_model_name_or_path is None else None, 198 | revision=revision, 199 | variant=variant, 200 | ) 201 | unet = UNet2DConditionModel.from_config( 202 | pretrained_model_name_or_path, 203 | subfolder="unet", 204 | revision=revision, 205 | variant=variant, 206 | ) 207 | unet.load_state_dict(load_file(hf_hub_download(repo, ckpt))) 208 | 209 | # Move vae, unet and text_encoder to device and cast to weight_dtype 210 | # The VAE is in float32 to avoid NaN losses. 211 | if pretrained_vae_model_name_or_path is not None: 212 | vae.to(accelerator.device, dtype=weight_dtype) 213 | else: 214 | vae.to(accelerator.device, dtype=torch.float32) 215 | unet.to(accelerator.device, dtype=weight_dtype) 216 | 217 | controlnet = ControlNetModel.from_pretrained(controlnet_model_name_or_path, torch_dtype=weight_dtype) 218 | pipe = StableDiffusionXLControlNetPipeline.from_pretrained( 219 | pretrained_model_name_or_path, 220 | vae=vae, 221 | unet=unet, 222 | controlnet=controlnet, 223 | ) 224 | pipe.to(accelerator.device, dtype=weight_dtype) 225 | 226 | image = PIL.Image.open(image_path) 227 | 228 | # Prepare everything with our `accelerator`. 
229 |     pipe, image = accelerator.prepare(pipe, image)
230 |     pipe.safety_checker = None
231 | 
232 |     # Convert image into grayscale
233 |     original_size = image.size
234 |     control_image = image.convert("L").convert("RGB").resize((512, 512))
235 | 
236 |     # Image captioning
237 |     if caption_model_name in ("blip-image-captioning-large", "blip-image-captioning-base"):
238 |         caption = blip_image_captioning(control_image, caption_model_name,
239 |                                         weight_dtype, accelerator.device, conditional=True)
240 |     # elif caption_model_name in ("ViT-L-14/openai", "ViT-H-14/laion2b_s32b_b79k"):
241 |     #     caption = clip_image_captioning(control_image, caption_model_name, accelerator.device)
242 |     # elif caption_model_name == "vit-gpt2-image-captioning":
243 |     #     caption = vit_gpt2_image_captioning(control_image, accelerator.device)
244 |     caption = remove_unlikely_words(caption)
245 | 
246 |     # Combine positive prompt and captioning result (skip the separator if no positive prompt is given)
247 |     prompt = [positive_prompt + ", " + caption if positive_prompt else caption]
248 | 
249 |     # Image colorization
250 |     image = pipe(prompt=prompt,
251 |                  negative_prompt=negative_prompt,
252 |                  num_inference_steps=num_inference_steps,
253 |                  generator=generator,
254 |                  image=control_image).images[0]
255 | 
256 |     # Apply color mapping
257 |     result_image = apply_color(control_image, image)
258 |     result_image = result_image.resize(original_size)
259 |     return result_image, caption
260 | 
261 | # Define the image gallery based on folder path
262 | def get_image_paths(folder_path):
263 |     import os
264 |     image_paths = []
265 |     for filename in os.listdir(folder_path):
266 |         if filename.endswith(".jpg") or filename.endswith(".png"):
267 |             image_paths.append([os.path.join(folder_path, filename)])
268 |     return image_paths
269 | 
270 | # Create the Gradio interface
271 | def create_interface():
272 |     controlnet_model_dict = {
273 |         "sdxl-light-caption-30000": "sdxl_light_caption_output/checkpoint-30000/controlnet",
274 |         "sdxl-light-custom-caption-30000": "sdxl_light_custom_caption_output/checkpoint-30000/controlnet",
275 |     }
276 |     images = get_image_paths("example/legacy_images") # Replace with your folder path
277 | 
278 |     interface = gr.Interface(
279 |         fn=process_image,
280 |         inputs=[
281 |             gr.Image(label="Upload image",
282 |                      value="example/legacy_images/Hollywood-Sign.jpg",
283 |                      type='filepath'),
284 |             gr.Dropdown(choices=[controlnet_model_dict[key] for key in controlnet_model_dict],
285 |                         value=controlnet_model_dict["sdxl-light-caption-30000"],
286 |                         label="Select ControlNet Model"),
287 |             gr.Dropdown(choices=["blip-image-captioning-large",
288 |                                  "blip-image-captioning-base",],
289 |                         value="blip-image-captioning-large",
290 |                         label="Select Image Captioning Model"),
291 |             gr.Textbox(label="Positive Prompt", placeholder="Text for positive prompt"),
292 |             gr.Textbox(value="low quality, bad quality, low contrast, black and white, bw, monochrome, grainy, blurry, historical, restored, desaturate",
293 |                        label="Negative Prompt", placeholder="Text for negative prompt"),
294 |         ],
295 |         outputs=[
296 |             gr.Image(label="Colorized image",
297 |                      value="example/UUColor_results/Hollywood-Sign.jpeg",
298 |                      format="jpeg"),
299 |             gr.Textbox(label="Captioning Result", show_copy_button=True)
300 |         ],
301 |         examples=images,
302 |         additional_inputs=[
303 |             # gr.Radio(choices=["Original", "Square"], value="Original",
304 |             #          label="Output resolution"),
305 |             # gr.Slider(minimum=128, maximum=512, value=256, step=128,
306 |             #           label="Height & Width",
307 |             #           info='Only effective when the "Square" output resolution is selected'),
308 |             gr.Slider(0, 1000, 123, label="Seed"),
label="Seed"), 309 | gr.Radio(choices=[1, 2, 4, 8], 310 | value=8, 311 | label="Inference Steps", 312 | info="1-step, 2-step, 4-step, or 8-step distilled models"), 313 | gr.Radio(choices=["no", "fp16", "bf16"], 314 | value="fp16", 315 | label="Mixed Precision", 316 | info="Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16)."), 317 | gr.Dropdown(choices=["stabilityai/stable-diffusion-xl-base-1.0"], 318 | value="stabilityai/stable-diffusion-xl-base-1.0", 319 | label="Base Model", 320 | info="Path to pretrained model or model identifier from huggingface.co/models."), 321 | gr.Dropdown(choices=["None"], 322 | value=None, 323 | label="VAE Model", 324 | info="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038."), 325 | gr.Dropdown(choices=["None"], 326 | value=None, 327 | label="Varient", 328 | info="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16"), 329 | gr.Dropdown(choices=["None"], 330 | value=None, 331 | label="Revision", 332 | info="Revision of pretrained model identifier from huggingface.co/models."), 333 | gr.Dropdown(choices=["ByteDance/SDXL-Lightning"], 334 | value="ByteDance/SDXL-Lightning", 335 | label="Repository", 336 | info="Repository from huggingface.co"), 337 | gr.Dropdown(choices=["sdxl_lightning_1step_unet.safetensors", 338 | "sdxl_lightning_2step_unet.safetensors", 339 | "sdxl_lightning_4step_unet.safetensors", 340 | "sdxl_lightning_8step_unet.safetensors"], 341 | value="sdxl_lightning_8step_unet.safetensors", 342 | label="Checkpoint", 343 | info="Available checkpoints from the repository. Caution! Checkpoint's 'N'step must match with inference steps"), 344 | ], 345 | title="Text-Guided Image Colorization", 346 | description="Upload an image and select a model to colorize it." 
347 | ) 348 | return interface 349 | 350 | def main(): 351 | # Launch the Gradio interface 352 | interface = create_interface() 353 | interface.launch() 354 | 355 | if __name__ == "__main__": 356 | main() 357 | -------------------------------------------------------------------------------- /images/000000022935_gray.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/images/000000022935_gray.jpg -------------------------------------------------------------------------------- /images/000000022935_green_shirt_on_right_girl.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/images/000000022935_green_shirt_on_right_girl.jpeg -------------------------------------------------------------------------------- /images/000000022935_purple_shirt_on_right_girl.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/images/000000022935_purple_shirt_on_right_girl.jpeg -------------------------------------------------------------------------------- /images/000000022935_red_shirt_on_right_girl.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/images/000000022935_red_shirt_on_right_girl.jpeg -------------------------------------------------------------------------------- /images/000000025560_color.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/images/000000025560_color.jpg -------------------------------------------------------------------------------- /images/000000025560_gray.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/images/000000025560_gray.jpg -------------------------------------------------------------------------------- /images/000000025560_gt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/images/000000025560_gt.jpg -------------------------------------------------------------------------------- /images/000000041633_black_car.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/images/000000041633_black_car.jpeg -------------------------------------------------------------------------------- /images/000000041633_bright_red_car.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/images/000000041633_bright_red_car.jpeg -------------------------------------------------------------------------------- /images/000000041633_dark_blue_car.jpeg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/images/000000041633_dark_blue_car.jpeg -------------------------------------------------------------------------------- /images/000000041633_gray.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/images/000000041633_gray.jpg -------------------------------------------------------------------------------- /images/000000065736_color.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/images/000000065736_color.jpg -------------------------------------------------------------------------------- /images/000000065736_gray.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/images/000000065736_gray.jpg -------------------------------------------------------------------------------- /images/000000065736_gt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/images/000000065736_gt.jpg -------------------------------------------------------------------------------- /images/000000091779_color.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/images/000000091779_color.jpg -------------------------------------------------------------------------------- /images/000000091779_gray.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/images/000000091779_gray.jpg -------------------------------------------------------------------------------- /images/000000091779_gt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/images/000000091779_gt.jpg -------------------------------------------------------------------------------- /images/000000092177_color.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/images/000000092177_color.jpg -------------------------------------------------------------------------------- /images/000000092177_gray.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/images/000000092177_gray.jpg -------------------------------------------------------------------------------- /images/000000092177_gt.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/images/000000092177_gt.jpg -------------------------------------------------------------------------------- /images/000000166426_color.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/images/000000166426_color.jpg -------------------------------------------------------------------------------- /images/000000166426_gray.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/images/000000166426_gray.jpg -------------------------------------------------------------------------------- /images/000000166426_gt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/images/000000166426_gt.jpg -------------------------------------------------------------------------------- /images/000000286708_gray.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/images/000000286708_gray.jpg -------------------------------------------------------------------------------- /images/000000286708_orange_hat.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/images/000000286708_orange_hat.jpeg -------------------------------------------------------------------------------- /images/000000286708_pink_hat.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/images/000000286708_pink_hat.jpeg -------------------------------------------------------------------------------- /images/000000286708_yellow_hat.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/images/000000286708_yellow_hat.jpeg -------------------------------------------------------------------------------- /images/framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/images/framework.jpg -------------------------------------------------------------------------------- /images/gradio_ui.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nick8592/text-guided-image-colorization/a5954bc7ca24acd0712fae6bd81da0bce2da7cee/images/gradio_ui.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate>=0.16.0 2 | # torch==1.13.1+cu117 3 | # torchvision==0.14.1+cu117 4 | transformers>=4.25.1 5 | ftfy 6 | tensorboard 7 | datasets 8 | bitsandbytes 9 | 
git+https://github.com/huggingface/diffusers 10 | -------------------------------------------------------------------------------- /train_controlnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | 16 | import argparse 17 | import contextlib 18 | import gc 19 | import logging 20 | import math 21 | import os 22 | import random 23 | import shutil 24 | from pathlib import Path 25 | 26 | import accelerate 27 | import numpy as np 28 | import torch 29 | import torch.nn.functional as F 30 | import torch.utils.checkpoint 31 | import transformers 32 | from accelerate import Accelerator 33 | from accelerate.logging import get_logger 34 | from accelerate.utils import ProjectConfiguration, set_seed 35 | from datasets import load_dataset 36 | from huggingface_hub import create_repo, upload_folder 37 | from packaging import version 38 | from PIL import Image 39 | from torchvision import transforms 40 | from tqdm.auto import tqdm 41 | from transformers import AutoTokenizer, PretrainedConfig 42 | 43 | import diffusers 44 | from diffusers import ( 45 | AutoencoderKL, 46 | ControlNetModel, 47 | DDPMScheduler, 48 | StableDiffusionControlNetPipeline, 49 | UNet2DConditionModel, 50 | UniPCMultistepScheduler, 51 | ) 52 | from diffusers.optimization import get_scheduler 53 | from diffusers.utils import check_min_version, is_wandb_available 54 | from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card 55 | from diffusers.utils.import_utils import is_xformers_available 56 | from diffusers.utils.torch_utils import is_compiled_module 57 | 58 | if is_wandb_available(): 59 | import wandb 60 | 61 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 62 | check_min_version("0.28.0.dev0") 63 | 64 | logger = get_logger(__name__) 65 | 66 | 67 | def image_grid(imgs, rows, cols): 68 | assert len(imgs) == rows * cols 69 | 70 | w, h = imgs[0].size 71 | grid = Image.new("RGB", size=(cols * w, rows * h)) 72 | 73 | for i, img in enumerate(imgs): 74 | grid.paste(img, box=(i % cols * w, i // cols * h)) 75 | return grid 76 | 77 | 78 | def log_validation( 79 | vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False 80 | ): 81 | logger.info("Running validation... 
") 82 | 83 | if not is_final_validation: 84 | controlnet = accelerator.unwrap_model(controlnet) 85 | else: 86 | controlnet = ControlNetModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype) 87 | 88 | pipeline = StableDiffusionControlNetPipeline.from_pretrained( 89 | args.pretrained_model_name_or_path, 90 | vae=vae, 91 | text_encoder=text_encoder, 92 | tokenizer=tokenizer, 93 | unet=unet, 94 | controlnet=controlnet, 95 | safety_checker=None, 96 | revision=args.revision, 97 | variant=args.variant, 98 | torch_dtype=weight_dtype, 99 | ) 100 | pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config) 101 | pipeline = pipeline.to(accelerator.device) 102 | pipeline.set_progress_bar_config(disable=True) 103 | 104 | if args.enable_xformers_memory_efficient_attention: 105 | pipeline.enable_xformers_memory_efficient_attention() 106 | 107 | if args.seed is None: 108 | generator = None 109 | else: 110 | generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) 111 | 112 | if len(args.validation_image) == len(args.validation_prompt): 113 | validation_images = args.validation_image 114 | validation_prompts = args.validation_prompt 115 | elif len(args.validation_image) == 1: 116 | validation_images = args.validation_image * len(args.validation_prompt) 117 | validation_prompts = args.validation_prompt 118 | elif len(args.validation_prompt) == 1: 119 | validation_images = args.validation_image 120 | validation_prompts = args.validation_prompt * len(args.validation_image) 121 | else: 122 | raise ValueError( 123 | "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`" 124 | ) 125 | 126 | image_logs = [] 127 | inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast("cuda") 128 | 129 | for validation_prompt, validation_image in zip(validation_prompts, validation_images): 130 | validation_image = Image.open(validation_image).convert("RGB").resize((512, 512)) # resize to prevent size mismatch when stacking 131 | 132 | images = [] 133 | 134 | for _ in range(args.num_validation_images): 135 | with inference_ctx: 136 | image = pipeline( 137 | validation_prompt, validation_image, num_inference_steps=20, generator=generator 138 | ).images[0] 139 | 140 | images.append(image) 141 | 142 | image_logs.append( 143 | {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt} 144 | ) 145 | 146 | tracker_key = "test" if is_final_validation else "validation" 147 | for tracker in accelerator.trackers: 148 | if tracker.name == "tensorboard": 149 | for log in image_logs: 150 | images = log["images"] 151 | validation_prompt = log["validation_prompt"] 152 | validation_image = log["validation_image"] 153 | 154 | formatted_images = [] 155 | 156 | formatted_images.append(np.asarray(validation_image)) 157 | 158 | for image in images: 159 | formatted_images.append(np.asarray(image)) 160 | 161 | formatted_images = np.stack(formatted_images) 162 | 163 | tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC") 164 | elif tracker.name == "wandb": 165 | formatted_images = [] 166 | 167 | for log in image_logs: 168 | images = log["images"] 169 | validation_prompt = log["validation_prompt"] 170 | validation_image = log["validation_image"] 171 | 172 | formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning")) 173 | 174 | for image in images: 175 | image = wandb.Image(image, caption=validation_prompt) 176 | 
formatted_images.append(image) 177 | 178 | tracker.log({tracker_key: formatted_images}) 179 | else: 180 | logger.warning(f"image logging not implemented for {tracker.name}") 181 | 182 | del pipeline 183 | gc.collect() 184 | torch.cuda.empty_cache() 185 | 186 | return image_logs 187 | 188 | 189 | def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): 190 | text_encoder_config = PretrainedConfig.from_pretrained( 191 | pretrained_model_name_or_path, 192 | subfolder="text_encoder", 193 | revision=revision, 194 | ) 195 | model_class = text_encoder_config.architectures[0] 196 | 197 | if model_class == "CLIPTextModel": 198 | from transformers import CLIPTextModel 199 | 200 | return CLIPTextModel 201 | elif model_class == "RobertaSeriesModelWithTransformation": 202 | from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation 203 | 204 | return RobertaSeriesModelWithTransformation 205 | else: 206 | raise ValueError(f"{model_class} is not supported.") 207 | 208 | 209 | def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None): 210 | img_str = "" 211 | if image_logs is not None: 212 | img_str = "You can find some example images below.\n\n" 213 | for i, log in enumerate(image_logs): 214 | images = log["images"] 215 | validation_prompt = log["validation_prompt"] 216 | validation_image = log["validation_image"] 217 | validation_image.save(os.path.join(repo_folder, "image_control.png")) 218 | img_str += f"prompt: {validation_prompt}\n" 219 | images = [validation_image] + images 220 | image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) 221 | img_str += f"![images_{i})](./images_{i}.png)\n" 222 | 223 | model_description = f""" 224 | # controlnet-{repo_id} 225 | 226 | These are controlnet weights trained on {base_model} with new type of conditioning. 227 | {img_str} 228 | """ 229 | model_card = load_or_create_model_card( 230 | repo_id_or_path=repo_id, 231 | from_training=True, 232 | license="creativeml-openrail-m", 233 | base_model=base_model, 234 | model_description=model_description, 235 | inference=True, 236 | ) 237 | 238 | tags = [ 239 | "stable-diffusion", 240 | "stable-diffusion-diffusers", 241 | "text-to-image", 242 | "diffusers", 243 | "controlnet", 244 | "diffusers-training", 245 | ] 246 | model_card = populate_model_card(model_card, tags=tags) 247 | 248 | model_card.save(os.path.join(repo_folder, "README.md")) 249 | 250 | 251 | def parse_args(input_args=None): 252 | parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.") 253 | parser.add_argument( 254 | "--pretrained_model_name_or_path", 255 | type=str, 256 | default=None, 257 | required=True, 258 | help="Path to pretrained model or model identifier from huggingface.co/models.", 259 | ) 260 | parser.add_argument( 261 | "--controlnet_model_name_or_path", 262 | type=str, 263 | default=None, 264 | help="Path to pretrained controlnet model or model identifier from huggingface.co/models." 265 | " If not specified controlnet weights are initialized from unet.", 266 | ) 267 | parser.add_argument( 268 | "--revision", 269 | type=str, 270 | default=None, 271 | required=False, 272 | help="Revision of pretrained model identifier from huggingface.co/models.", 273 | ) 274 | parser.add_argument( 275 | "--variant", 276 | type=str, 277 | default=None, 278 | help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' 
fp16", 279 | ) 280 | parser.add_argument( 281 | "--tokenizer_name", 282 | type=str, 283 | default=None, 284 | help="Pretrained tokenizer name or path if not the same as model_name", 285 | ) 286 | parser.add_argument( 287 | "--output_dir", 288 | type=str, 289 | default="controlnet-model", 290 | help="The output directory where the model predictions and checkpoints will be written.", 291 | ) 292 | parser.add_argument( 293 | "--cache_dir", 294 | type=str, 295 | default=None, 296 | help="The directory where the downloaded models and datasets will be stored.", 297 | ) 298 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 299 | parser.add_argument( 300 | "--resolution", 301 | type=int, 302 | default=512, 303 | help=( 304 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 305 | " resolution" 306 | ), 307 | ) 308 | parser.add_argument( 309 | "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." 310 | ) 311 | parser.add_argument("--num_train_epochs", type=int, default=1) 312 | parser.add_argument( 313 | "--max_train_steps", 314 | type=int, 315 | default=None, 316 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 317 | ) 318 | parser.add_argument( 319 | "--checkpointing_steps", 320 | type=int, 321 | default=500, 322 | help=( 323 | "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. " 324 | "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." 325 | "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components." 326 | "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step" 327 | "instructions." 328 | ), 329 | ) 330 | parser.add_argument( 331 | "--checkpoints_total_limit", 332 | type=int, 333 | default=None, 334 | help=("Max number of checkpoints to store."), 335 | ) 336 | parser.add_argument( 337 | "--resume_from_checkpoint", 338 | type=str, 339 | default=None, 340 | help=( 341 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 342 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 343 | ), 344 | ) 345 | parser.add_argument( 346 | "--gradient_accumulation_steps", 347 | type=int, 348 | default=1, 349 | help="Number of updates steps to accumulate before performing a backward/update pass.", 350 | ) 351 | parser.add_argument( 352 | "--gradient_checkpointing", 353 | action="store_true", 354 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 355 | ) 356 | parser.add_argument( 357 | "--learning_rate", 358 | type=float, 359 | default=5e-6, 360 | help="Initial learning rate (after the potential warmup period) to use.", 361 | ) 362 | parser.add_argument( 363 | "--scale_lr", 364 | action="store_true", 365 | default=False, 366 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 367 | ) 368 | parser.add_argument( 369 | "--lr_scheduler", 370 | type=str, 371 | default="constant", 372 | help=( 373 | 'The scheduler type to use. 
Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 374 | ' "constant", "constant_with_warmup"]' 375 | ), 376 | ) 377 | parser.add_argument( 378 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 379 | ) 380 | parser.add_argument( 381 | "--lr_num_cycles", 382 | type=int, 383 | default=1, 384 | help="Number of hard resets of the lr in cosine_with_restarts scheduler.", 385 | ) 386 | parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") 387 | parser.add_argument( 388 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 389 | ) 390 | parser.add_argument( 391 | "--dataloader_num_workers", 392 | type=int, 393 | default=0, 394 | help=( 395 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 396 | ), 397 | ) 398 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 399 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 400 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 401 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 402 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 403 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 404 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 405 | parser.add_argument( 406 | "--hub_model_id", 407 | type=str, 408 | default=None, 409 | help="The name of the repository to keep in sync with the local `output_dir`.", 410 | ) 411 | parser.add_argument( 412 | "--logging_dir", 413 | type=str, 414 | default="logs", 415 | help=( 416 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 417 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 418 | ), 419 | ) 420 | parser.add_argument( 421 | "--allow_tf32", 422 | action="store_true", 423 | help=( 424 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" 425 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 426 | ), 427 | ) 428 | parser.add_argument( 429 | "--report_to", 430 | type=str, 431 | default="tensorboard", 432 | help=( 433 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 434 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 435 | ), 436 | ) 437 | parser.add_argument( 438 | "--mixed_precision", 439 | type=str, 440 | default=None, 441 | choices=["no", "fp16", "bf16"], 442 | help=( 443 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 444 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 445 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 446 | ), 447 | ) 448 | parser.add_argument( 449 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 
450 | ) 451 | parser.add_argument( 452 | "--set_grads_to_none", 453 | action="store_true", 454 | help=( 455 | "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" 456 | " behaviors, so disable this argument if it causes any problems. More info:" 457 | " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" 458 | ), 459 | ) 460 | parser.add_argument( 461 | "--dataset_name", 462 | type=str, 463 | default=None, 464 | help=( 465 | "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," 466 | " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," 467 | " or to a folder containing files that 🤗 Datasets can understand." 468 | ), 469 | ) 470 | parser.add_argument( 471 | "--dataset_config_name", 472 | type=str, 473 | default=None, 474 | help="The config of the Dataset, leave as None if there's only one config.", 475 | ) 476 | parser.add_argument( 477 | "--dataset_revision", 478 | type=str, 479 | default='main', 480 | help="The revision of the Dataset, leave as 'main' by default.", 481 | ) 482 | parser.add_argument( 483 | "--train_data_dir", 484 | type=str, 485 | default=None, 486 | help=( 487 | "A folder containing the training data. Folder contents must follow the structure described in" 488 | " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" 489 | " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." 490 | ), 491 | ) 492 | parser.add_argument( 493 | "--image_column", type=str, default="image", help="The column of the dataset containing the target image." 494 | ) 495 | parser.add_argument( 496 | "--conditioning_image_column", 497 | type=str, 498 | default="conditioning_image", 499 | help="The column of the dataset containing the controlnet conditioning image.", 500 | ) 501 | parser.add_argument( 502 | "--caption_column", 503 | type=str, 504 | default="text", 505 | help="The column of the dataset containing a caption or a list of captions.", 506 | ) 507 | parser.add_argument( 508 | "--max_train_samples", 509 | type=int, 510 | default=None, 511 | help=( 512 | "For debugging purposes or quicker training, truncate the number of training examples to this " 513 | "value if set." 514 | ), 515 | ) 516 | parser.add_argument( 517 | "--proportion_empty_prompts", 518 | type=float, 519 | default=0, 520 | help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", 521 | ) 522 | parser.add_argument( 523 | "--validation_prompt", 524 | type=str, 525 | default=None, 526 | nargs="+", 527 | help=( 528 | "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`." 529 | " Provide either a matching number of `--validation_image`s, a single `--validation_image`" 530 | " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s." 531 | ), 532 | ) 533 | parser.add_argument( 534 | "--validation_image", 535 | type=str, 536 | default=None, 537 | nargs="+", 538 | help=( 539 | "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`" 540 | " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a" 541 | " a single `--validation_prompt` to be used with all `--validation_image`s, or a single" 542 | " `--validation_image` that will be used with all `--validation_prompt`s." 
543 | ), 544 | ) 545 | parser.add_argument( 546 | "--num_validation_images", 547 | type=int, 548 | default=4, 549 | help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair", 550 | ) 551 | parser.add_argument( 552 | "--validation_steps", 553 | type=int, 554 | default=100, 555 | help=( 556 | "Run validation every X steps. Validation consists of running the prompt" 557 | " `args.validation_prompt` multiple times: `args.num_validation_images`" 558 | " and logging the images." 559 | ), 560 | ) 561 | parser.add_argument( 562 | "--tracker_project_name", 563 | type=str, 564 | default="train_controlnet", 565 | help=( 566 | "The `project_name` argument passed to Accelerator.init_trackers for" 567 | " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" 568 | ), 569 | ) 570 | 571 | if input_args is not None: 572 | args = parser.parse_args(input_args) 573 | else: 574 | args = parser.parse_args() 575 | 576 | if args.dataset_name is None and args.train_data_dir is None: 577 | raise ValueError("Specify either `--dataset_name` or `--train_data_dir`") 578 | 579 | if args.dataset_name is not None and args.train_data_dir is not None: 580 | raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`") 581 | 582 | if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: 583 | raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") 584 | 585 | if args.validation_prompt is not None and args.validation_image is None: 586 | raise ValueError("`--validation_image` must be set if `--validation_prompt` is set") 587 | 588 | if args.validation_prompt is None and args.validation_image is not None: 589 | raise ValueError("`--validation_prompt` must be set if `--validation_image` is set") 590 | 591 | if ( 592 | args.validation_image is not None 593 | and args.validation_prompt is not None 594 | and len(args.validation_image) != 1 595 | and len(args.validation_prompt) != 1 596 | and len(args.validation_image) != len(args.validation_prompt) 597 | ): 598 | raise ValueError( 599 | "Must provide either 1 `--validation_image`, 1 `--validation_prompt`," 600 | " or the same number of `--validation_prompt`s and `--validation_image`s" 601 | ) 602 | 603 | if args.resolution % 8 != 0: 604 | raise ValueError( 605 | "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder." 606 | ) 607 | 608 | return args 609 | 610 | 611 | def make_train_dataset(args, tokenizer, accelerator): 612 | # Get the datasets: you can either provide your own training and evaluation files (see below) 613 | # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). 614 | 615 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently 616 | # download the dataset. 617 | if args.dataset_name is not None: 618 | # Downloading and loading a dataset from the hub. 619 | dataset = load_dataset( 620 | args.dataset_name, 621 | args.dataset_config_name, 622 | revision=args.dataset_revision, 623 | cache_dir=args.cache_dir, 624 | ) 625 | else: 626 | if args.train_data_dir is not None: 627 | dataset = load_dataset( 628 | args.train_data_dir, 629 | cache_dir=args.cache_dir, 630 | ) 631 | # See more about loading custom images at 632 | # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script 633 | 634 | # Preprocessing the datasets. 
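    # Note: in the colorization setup (see train_controlnet.sh), `--image_column` and
    # `--conditioning_image_column` both point to the same "file_name" column of the
    # nickpai/coco2017-colorization dataset; the conditioning copy only becomes a grayscale
    # hint later, via `conditioning_image_transforms` (Grayscale -> Resize -> CenterCrop -> ToTensor).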
635 | # We need to tokenize inputs and targets. 636 | column_names = dataset["train"].column_names 637 | 638 | # 6. Get the column names for input/target. 639 | if args.image_column is None: 640 | image_column = column_names[0] 641 | logger.info(f"image column defaulting to {image_column}") 642 | else: 643 | image_column = args.image_column 644 | if image_column not in column_names: 645 | raise ValueError( 646 | f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" 647 | ) 648 | 649 | if args.caption_column is None: 650 | caption_column = column_names[1] 651 | logger.info(f"caption column defaulting to {caption_column}") 652 | else: 653 | caption_column = args.caption_column 654 | if caption_column not in column_names: 655 | raise ValueError( 656 | f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" 657 | ) 658 | 659 | if args.conditioning_image_column is None: 660 | conditioning_image_column = column_names[2] 661 | logger.info(f"conditioning image column defaulting to {conditioning_image_column}") 662 | else: 663 | conditioning_image_column = args.conditioning_image_column 664 | if conditioning_image_column not in column_names: 665 | raise ValueError( 666 | f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" 667 | ) 668 | 669 | def tokenize_captions(examples, is_train=True): 670 | captions = [] 671 | for caption in examples[caption_column]: 672 | if random.random() < args.proportion_empty_prompts: 673 | captions.append("") 674 | elif isinstance(caption, str): 675 | captions.append(caption) 676 | elif isinstance(caption, (list, np.ndarray)): 677 | # take a random caption if there are multiple 678 | captions.append(random.choice(caption) if is_train else caption[0]) 679 | else: 680 | raise ValueError( 681 | f"Caption column `{caption_column}` should contain either strings or lists of strings." 
682 | ) 683 | inputs = tokenizer( 684 | captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" 685 | ) 686 | return inputs.input_ids 687 | 688 | image_transforms = transforms.Compose( 689 | [ 690 | transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), 691 | transforms.CenterCrop(args.resolution), 692 | transforms.ToTensor(), 693 | transforms.Normalize([0.5], [0.5]), 694 | ] 695 | ) 696 | 697 | conditioning_image_transforms = transforms.Compose( 698 | [ 699 | transforms.Grayscale(num_output_channels=3), # convert to grayscale image 700 | transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), 701 | transforms.CenterCrop(args.resolution), 702 | transforms.ToTensor(), 703 | ] 704 | ) 705 | 706 | def preprocess_train(examples): 707 | images = [Image.open(image).convert("RGB") for image in examples[image_column]] 708 | images = [image_transforms(image) for image in images] 709 | 710 | conditioning_images = [Image.open(image).convert("RGB") for image in examples[conditioning_image_column]] 711 | conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images] 712 | 713 | examples["pixel_values"] = images 714 | examples["conditioning_pixel_values"] = conditioning_images 715 | examples["input_ids"] = tokenize_captions(examples) 716 | 717 | return examples 718 | 719 | with accelerator.main_process_first(): 720 | if args.max_train_samples is not None: 721 | dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) 722 | # Set the training transforms 723 | train_dataset = dataset["train"].with_transform(preprocess_train) 724 | 725 | return train_dataset 726 | 727 | 728 | def collate_fn(examples): 729 | pixel_values = torch.stack([example["pixel_values"] for example in examples]) 730 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 731 | 732 | conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples]) 733 | conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float() 734 | 735 | input_ids = torch.stack([example["input_ids"] for example in examples]) 736 | 737 | return { 738 | "pixel_values": pixel_values, 739 | "conditioning_pixel_values": conditioning_pixel_values, 740 | "input_ids": input_ids, 741 | } 742 | 743 | 744 | def main(args): 745 | if args.report_to == "wandb" and args.hub_token is not None: 746 | raise ValueError( 747 | "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." 748 | " Please use `huggingface-cli login` to authenticate with the Hub." 749 | ) 750 | 751 | logging_dir = Path(args.output_dir, args.logging_dir) 752 | 753 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) 754 | 755 | accelerator = Accelerator( 756 | gradient_accumulation_steps=args.gradient_accumulation_steps, 757 | mixed_precision=args.mixed_precision, 758 | log_with=args.report_to, 759 | project_config=accelerator_project_config, 760 | ) 761 | 762 | # Disable AMP for MPS. 763 | if torch.backends.mps.is_available(): 764 | accelerator.native_amp = False 765 | 766 | # Make one log on every process with the configuration for debugging. 
767 | logging.basicConfig( 768 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 769 | datefmt="%m/%d/%Y %H:%M:%S", 770 | level=logging.INFO, 771 | ) 772 | logger.info(accelerator.state, main_process_only=False) 773 | if accelerator.is_local_main_process: 774 | transformers.utils.logging.set_verbosity_warning() 775 | diffusers.utils.logging.set_verbosity_info() 776 | else: 777 | transformers.utils.logging.set_verbosity_error() 778 | diffusers.utils.logging.set_verbosity_error() 779 | 780 | # If passed along, set the training seed now. 781 | if args.seed is not None: 782 | set_seed(args.seed) 783 | 784 | # Handle the repository creation 785 | if accelerator.is_main_process: 786 | if args.output_dir is not None: 787 | os.makedirs(args.output_dir, exist_ok=True) 788 | 789 | if args.push_to_hub: 790 | repo_id = create_repo( 791 | repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token 792 | ).repo_id 793 | 794 | # Load the tokenizer 795 | if args.tokenizer_name: 796 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) 797 | elif args.pretrained_model_name_or_path: 798 | tokenizer = AutoTokenizer.from_pretrained( 799 | args.pretrained_model_name_or_path, 800 | subfolder="tokenizer", 801 | revision=args.revision, 802 | use_fast=False, 803 | ) 804 | 805 | # import correct text encoder class 806 | text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) 807 | 808 | # Load scheduler and models 809 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") 810 | text_encoder = text_encoder_cls.from_pretrained( 811 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant 812 | ) 813 | vae = AutoencoderKL.from_pretrained( 814 | args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant 815 | ) 816 | unet = UNet2DConditionModel.from_pretrained( 817 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant 818 | ) 819 | 820 | if args.controlnet_model_name_or_path: 821 | logger.info("Loading existing controlnet weights") 822 | controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path) 823 | else: 824 | logger.info("Initializing controlnet weights from unet") 825 | controlnet = ControlNetModel.from_unet(unet) 826 | 827 | # Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files) 828 | def unwrap_model(model): 829 | model = accelerator.unwrap_model(model) 830 | model = model._orig_mod if is_compiled_module(model) else model 831 | return model 832 | 833 | # `accelerate` 0.16.0 will have better support for customized saving 834 | if version.parse(accelerate.__version__) >= version.parse("0.16.0"): 835 | # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format 836 | def save_model_hook(models, weights, output_dir): 837 | if accelerator.is_main_process: 838 | i = len(weights) - 1 839 | 840 | while len(weights) > 0: 841 | weights.pop() 842 | model = models[i] 843 | 844 | sub_dir = "controlnet" 845 | model.save_pretrained(os.path.join(output_dir, sub_dir)) 846 | 847 | i -= 1 848 | 849 | def load_model_hook(models, input_dir): 850 | while len(models) > 0: 851 | # pop models so that they are not loaded again 852 | model = models.pop() 853 | 854 | # load diffusers 
style into model 855 | load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet") 856 | model.register_to_config(**load_model.config) 857 | 858 | model.load_state_dict(load_model.state_dict()) 859 | del load_model 860 | 861 | accelerator.register_save_state_pre_hook(save_model_hook) 862 | accelerator.register_load_state_pre_hook(load_model_hook) 863 | 864 | vae.requires_grad_(False) 865 | unet.requires_grad_(False) 866 | text_encoder.requires_grad_(False) 867 | controlnet.train() 868 | 869 | if args.enable_xformers_memory_efficient_attention: 870 | if is_xformers_available(): 871 | import xformers 872 | 873 | xformers_version = version.parse(xformers.__version__) 874 | if xformers_version == version.parse("0.0.16"): 875 | logger.warning( 876 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 877 | ) 878 | unet.enable_xformers_memory_efficient_attention() 879 | controlnet.enable_xformers_memory_efficient_attention() 880 | else: 881 | raise ValueError("xformers is not available. Make sure it is installed correctly") 882 | 883 | if args.gradient_checkpointing: 884 | controlnet.enable_gradient_checkpointing() 885 | 886 | # Check that all trainable models are in full precision 887 | low_precision_error_string = ( 888 | " Please make sure to always have all model weights in full float32 precision when starting training - even if" 889 | " doing mixed precision training, copy of the weights should still be float32." 890 | ) 891 | 892 | if unwrap_model(controlnet).dtype != torch.float32: 893 | raise ValueError( 894 | f"Controlnet loaded as datatype {unwrap_model(controlnet).dtype}. {low_precision_error_string}" 895 | ) 896 | 897 | # Enable TF32 for faster training on Ampere GPUs, 898 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 899 | if args.allow_tf32: 900 | torch.backends.cuda.matmul.allow_tf32 = True 901 | 902 | if args.scale_lr: 903 | args.learning_rate = ( 904 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 905 | ) 906 | 907 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs 908 | if args.use_8bit_adam: 909 | try: 910 | import bitsandbytes as bnb 911 | except ImportError: 912 | raise ImportError( 913 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 914 | ) 915 | 916 | optimizer_class = bnb.optim.AdamW8bit 917 | else: 918 | optimizer_class = torch.optim.AdamW 919 | 920 | # Optimizer creation 921 | params_to_optimize = controlnet.parameters() 922 | optimizer = optimizer_class( 923 | params_to_optimize, 924 | lr=args.learning_rate, 925 | betas=(args.adam_beta1, args.adam_beta2), 926 | weight_decay=args.adam_weight_decay, 927 | eps=args.adam_epsilon, 928 | ) 929 | 930 | train_dataset = make_train_dataset(args, tokenizer, accelerator) 931 | 932 | train_dataloader = torch.utils.data.DataLoader( 933 | train_dataset, 934 | shuffle=True, 935 | collate_fn=collate_fn, 936 | batch_size=args.train_batch_size, 937 | num_workers=args.dataloader_num_workers, 938 | ) 939 | 940 | # Scheduler and math around the number of training steps. 
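    # Note: num_update_steps_per_epoch = ceil(len(train_dataloader) / gradient_accumulation_steps),
    # and when --max_train_steps is not set it defaults to num_train_epochs * num_update_steps_per_epoch.
    # Both values are recomputed after accelerator.prepare(), since the dataloader length
    # can change once it is sharded across processes.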
941 | overrode_max_train_steps = False 942 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 943 | if args.max_train_steps is None: 944 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 945 | overrode_max_train_steps = True 946 | 947 | lr_scheduler = get_scheduler( 948 | args.lr_scheduler, 949 | optimizer=optimizer, 950 | num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, 951 | num_training_steps=args.max_train_steps * accelerator.num_processes, 952 | num_cycles=args.lr_num_cycles, 953 | power=args.lr_power, 954 | ) 955 | 956 | # Prepare everything with our `accelerator`. 957 | controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 958 | controlnet, optimizer, train_dataloader, lr_scheduler 959 | ) 960 | 961 | # For mixed precision training we cast the text_encoder and vae weights to half-precision 962 | # as these models are only used for inference, keeping weights in full precision is not required. 963 | weight_dtype = torch.float32 964 | if accelerator.mixed_precision == "fp16": 965 | weight_dtype = torch.float16 966 | elif accelerator.mixed_precision == "bf16": 967 | weight_dtype = torch.bfloat16 968 | 969 | # Move vae, unet and text_encoder to device and cast to weight_dtype 970 | vae.to(accelerator.device, dtype=weight_dtype) 971 | unet.to(accelerator.device, dtype=weight_dtype) 972 | text_encoder.to(accelerator.device, dtype=weight_dtype) 973 | 974 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 975 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 976 | if overrode_max_train_steps: 977 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 978 | # Afterwards we recalculate our number of training epochs 979 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 980 | 981 | # We need to initialize the trackers we use, and also store our configuration. 982 | # The trackers initializes automatically on the main process. 983 | if accelerator.is_main_process: 984 | tracker_config = dict(vars(args)) 985 | 986 | # tensorboard cannot handle list types for config 987 | tracker_config.pop("validation_prompt") 988 | tracker_config.pop("validation_image") 989 | 990 | accelerator.init_trackers(args.tracker_project_name, config=tracker_config) 991 | 992 | # Train! 993 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 994 | 995 | logger.info("***** Running training *****") 996 | logger.info(f" Num examples = {len(train_dataset)}") 997 | logger.info(f" Num batches each epoch = {len(train_dataloader)}") 998 | logger.info(f" Num Epochs = {args.num_train_epochs}") 999 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 1000 | logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") 1001 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 1002 | logger.info(f" Total optimization steps = {args.max_train_steps}") 1003 | global_step = 0 1004 | first_epoch = 0 1005 | 1006 | # Potentially load in the weights and states from a previous save 1007 | if args.resume_from_checkpoint: 1008 | if args.resume_from_checkpoint != "latest": 1009 | path = os.path.basename(args.resume_from_checkpoint) 1010 | else: 1011 | # Get the most recent checkpoint 1012 | dirs = os.listdir(args.output_dir) 1013 | dirs = [d for d in dirs if d.startswith("checkpoint")] 1014 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 1015 | path = dirs[-1] if len(dirs) > 0 else None 1016 | 1017 | if path is None: 1018 | accelerator.print( 1019 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 1020 | ) 1021 | args.resume_from_checkpoint = None 1022 | initial_global_step = 0 1023 | else: 1024 | accelerator.print(f"Resuming from checkpoint {path}") 1025 | accelerator.load_state(os.path.join(args.output_dir, path)) 1026 | global_step = int(path.split("-")[1]) 1027 | 1028 | initial_global_step = global_step 1029 | first_epoch = global_step // num_update_steps_per_epoch 1030 | else: 1031 | initial_global_step = 0 1032 | 1033 | progress_bar = tqdm( 1034 | range(0, args.max_train_steps), 1035 | initial=initial_global_step, 1036 | desc="Steps", 1037 | # Only show the progress bar once on each machine. 1038 | disable=not accelerator.is_local_main_process, 1039 | ) 1040 | 1041 | image_logs = None 1042 | for epoch in range(first_epoch, args.num_train_epochs): 1043 | for step, batch in enumerate(train_dataloader): 1044 | with accelerator.accumulate(controlnet): 1045 | # Convert images to latent space 1046 | latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() 1047 | latents = latents * vae.config.scaling_factor 1048 | 1049 | # Sample noise that we'll add to the latents 1050 | noise = torch.randn_like(latents) 1051 | bsz = latents.shape[0] 1052 | # Sample a random timestep for each image 1053 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) 1054 | timesteps = timesteps.long() 1055 | 1056 | # Add noise to the latents according to the noise magnitude at each timestep 1057 | # (this is the forward diffusion process) 1058 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 1059 | 1060 | # Get the text embedding for conditioning 1061 | encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0] 1062 | 1063 | controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype) 1064 | 1065 | down_block_res_samples, mid_block_res_sample = controlnet( 1066 | noisy_latents, 1067 | timesteps, 1068 | encoder_hidden_states=encoder_hidden_states, 1069 | controlnet_cond=controlnet_image, 1070 | return_dict=False, 1071 | ) 1072 | 1073 | # Predict the noise residual 1074 | model_pred = unet( 1075 | noisy_latents, 1076 | timesteps, 1077 | encoder_hidden_states=encoder_hidden_states, 1078 | down_block_additional_residuals=[ 1079 | sample.to(dtype=weight_dtype) for sample in down_block_res_samples 1080 | ], 1081 | mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), 1082 | return_dict=False, 1083 | )[0] 1084 | 1085 | # Get the target for loss depending on the prediction type 1086 | if noise_scheduler.config.prediction_type == "epsilon": 1087 | 
target = noise 1088 | elif noise_scheduler.config.prediction_type == "v_prediction": 1089 | target = noise_scheduler.get_velocity(latents, noise, timesteps) 1090 | else: 1091 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 1092 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 1093 | 1094 | accelerator.backward(loss) 1095 | if accelerator.sync_gradients: 1096 | params_to_clip = controlnet.parameters() 1097 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) 1098 | optimizer.step() 1099 | lr_scheduler.step() 1100 | optimizer.zero_grad(set_to_none=args.set_grads_to_none) 1101 | 1102 | # Checks if the accelerator has performed an optimization step behind the scenes 1103 | if accelerator.sync_gradients: 1104 | progress_bar.update(1) 1105 | global_step += 1 1106 | 1107 | if accelerator.is_main_process: 1108 | if global_step % args.checkpointing_steps == 0: 1109 | # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` 1110 | if args.checkpoints_total_limit is not None: 1111 | checkpoints = os.listdir(args.output_dir) 1112 | checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] 1113 | checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) 1114 | 1115 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints 1116 | if len(checkpoints) >= args.checkpoints_total_limit: 1117 | num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 1118 | removing_checkpoints = checkpoints[0:num_to_remove] 1119 | 1120 | logger.info( 1121 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" 1122 | ) 1123 | logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") 1124 | 1125 | for removing_checkpoint in removing_checkpoints: 1126 | removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) 1127 | shutil.rmtree(removing_checkpoint) 1128 | 1129 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 1130 | accelerator.save_state(save_path) 1131 | logger.info(f"Saved state to {save_path}") 1132 | if args.validation_prompt is not None and global_step % args.validation_steps == 0: 1133 | image_logs = log_validation( 1134 | vae, 1135 | text_encoder, 1136 | tokenizer, 1137 | unet, 1138 | controlnet, 1139 | args, 1140 | accelerator, 1141 | weight_dtype, 1142 | global_step, 1143 | ) 1144 | 1145 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 1146 | progress_bar.set_postfix(**logs) 1147 | accelerator.log(logs, step=global_step) 1148 | 1149 | if global_step >= args.max_train_steps: 1150 | break 1151 | 1152 | # Create the pipeline using using the trained modules and save it. 1153 | accelerator.wait_for_everyone() 1154 | if accelerator.is_main_process: 1155 | controlnet = unwrap_model(controlnet) 1156 | controlnet.save_pretrained(args.output_dir) 1157 | 1158 | # Run a final round of validation. 
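        # A minimal usage sketch (hypothetical helper, never called in this script) showing how the
        # ControlNet saved just above could be reloaded for colorization inference with the same base
        # model; the image path and prompt arguments are placeholders, not values from this repository.
        def _example_colorize(image_path, prompt):
            from PIL import Image
            from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, UniPCMultistepScheduler

            trained = ControlNetModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype)
            pipe = StableDiffusionControlNetPipeline.from_pretrained(
                args.pretrained_model_name_or_path, controlnet=trained, torch_dtype=weight_dtype
            ).to(accelerator.device)
            pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
            # The ControlNet was trained on 3-channel grayscale conditioning images.
            gray = Image.open(image_path).convert("L").convert("RGB").resize((args.resolution, args.resolution))
            return pipe(prompt=prompt, image=gray, num_inference_steps=20).images[0]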
1159 | image_logs = None 1160 | if args.validation_prompt is not None: 1161 | image_logs = log_validation( 1162 | vae=vae, 1163 | text_encoder=text_encoder, 1164 | tokenizer=tokenizer, 1165 | unet=unet, 1166 | controlnet=None, 1167 | args=args, 1168 | accelerator=accelerator, 1169 | weight_dtype=weight_dtype, 1170 | step=global_step, 1171 | is_final_validation=True, 1172 | ) 1173 | 1174 | if args.push_to_hub: 1175 | save_model_card( 1176 | repo_id, 1177 | image_logs=image_logs, 1178 | base_model=args.pretrained_model_name_or_path, 1179 | repo_folder=args.output_dir, 1180 | ) 1181 | upload_folder( 1182 | repo_id=repo_id, 1183 | folder_path=args.output_dir, 1184 | commit_message="End of training", 1185 | ignore_patterns=["step_*", "epoch_*"], 1186 | ) 1187 | 1188 | accelerator.end_training() 1189 | 1190 | 1191 | if __name__ == "__main__": 1192 | args = parse_args() 1193 | main(args) 1194 | -------------------------------------------------------------------------------- /train_controlnet.sh: -------------------------------------------------------------------------------- 1 | # Original ControlNet paper: 2 | # "In the training process, we randomly replace 50% text prompts ct with empty strings. 3 | # This approach increases ControlNet’s ability to directly recognize semantics 4 | # in the input conditioning images (e.g., edges, poses, depth, etc.) as a replacement for the prompt." 5 | # https://civitai.com/articles/2078/play-in-control-controlnet-training-setup-guide 6 | 7 | # export MODEL_DIR="runwayml/stable-diffusion-v1-5" 8 | export MODEL_DIR="stabilityai/stable-diffusion-2-base" 9 | export OUTPUT_DIR="sd_v2_caption_kl_output" 10 | export DATASET="nickpai/coco2017-colorization" 11 | export REVISION="main" # option: main/caption-free 12 | export VAL_IMG_NAME="'./000000295478.jpg' './000000122962.jpg' './000000000285.jpg' './000000007991.jpg' './000000018837.jpg' './000000000724.jpg'" 13 | export VAL_PROMPT="'Woman walking a small dog behind her.' 'A group of children sitting at a long table eating pizza.' 'A close up picture of a bear face.' 'A plate on a table is filled with carrots and beans.' 'A large truck on a city street with two works sitting on top and one worker climbing in through door.' 'An upside down stop sign by the road.'" 14 | # export VAL_PROMPT="'Colorize this image as if it was taken with a color camera' 'Colorize this image' 'Add colors to this image' 'Make this image colorful' 'Colorize this grayscale image' 'Add colors to this image'" 15 | 16 | accelerate launch train_controlnet.py \ 17 | --pretrained_model_name_or_path=$MODEL_DIR \ 18 | --output_dir=$OUTPUT_DIR \ 19 | --seed=123123 \ 20 | --dataset_name=$DATASET \ 21 | --dataset_revision=$REVISION \ 22 | --image_column="file_name" \ 23 | --conditioning_image_column="file_name" \ 24 | --caption_column="captions" \ 25 | --max_train_samples=100000 \ 26 | --num_validation_images=1 \ 27 | --resolution=512 \ 28 | --num_train_epochs=5 \ 29 | --dataloader_num_workers=8 \ 30 | --learning_rate=1e-5 \ 31 | --validation_image './000000295478.jpg' './000000122962.jpg' './000000000285.jpg' './000000007991.jpg' './000000018837.jpg' './000000000724.jpg' \ 32 | --validation_prompt 'Woman walking a small dog behind her.' 'A group of children sitting at a long table eating pizza.' 'A close up picture of a bear face.' 'A plate on a table is filled with carrots and beans.' 'A large truck on a city street with two works sitting on top and one worker climbing in through door.' 'An upside down stop sign by the road.' 
\ 33 | --train_batch_size=2 \ 34 | --gradient_accumulation_steps=8 \ 35 | --proportion_empty_prompts=0 \ 36 | --validation_steps=500 \ 37 | --checkpointing_steps=2500 \ 38 | --mixed_precision="fp16" \ 39 | --gradient_checkpointing \ 40 | --use_8bit_adam -------------------------------------------------------------------------------- /train_controlnet_sdxl.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | 16 | import argparse 17 | import functools 18 | import gc 19 | import logging 20 | import math 21 | import os 22 | import random 23 | import shutil 24 | from contextlib import nullcontext 25 | from pathlib import Path 26 | 27 | import accelerate 28 | import numpy as np 29 | import torch 30 | import torch.nn.functional as F 31 | import torch.utils.checkpoint 32 | import transformers 33 | from accelerate import Accelerator 34 | from accelerate.logging import get_logger 35 | from accelerate.utils import ProjectConfiguration, set_seed 36 | from datasets import load_dataset 37 | from huggingface_hub import create_repo, upload_folder 38 | from packaging import version 39 | from PIL import Image 40 | from torchvision import transforms 41 | from tqdm.auto import tqdm 42 | from transformers import AutoTokenizer, PretrainedConfig 43 | 44 | import diffusers 45 | from diffusers import ( 46 | AutoencoderKL, 47 | ControlNetModel, 48 | DDPMScheduler, 49 | StableDiffusionXLControlNetPipeline, 50 | UNet2DConditionModel, 51 | UniPCMultistepScheduler, 52 | ) 53 | from diffusers.optimization import get_scheduler 54 | from diffusers.utils import check_min_version, is_wandb_available, make_image_grid 55 | from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card 56 | from diffusers.utils.import_utils import is_xformers_available 57 | from diffusers.utils.torch_utils import is_compiled_module 58 | 59 | 60 | if is_wandb_available(): 61 | import wandb 62 | 63 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 64 | check_min_version("0.28.0.dev0") 65 | 66 | logger = get_logger(__name__) 67 | 68 | 69 | def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False): 70 | logger.info("Running validation... 
") 71 | 72 | if not is_final_validation: 73 | controlnet = accelerator.unwrap_model(controlnet) 74 | pipeline = StableDiffusionXLControlNetPipeline.from_pretrained( 75 | args.pretrained_model_name_or_path, 76 | vae=vae, 77 | unet=unet, 78 | controlnet=controlnet, 79 | revision=args.revision, 80 | variant=args.variant, 81 | torch_dtype=weight_dtype, 82 | ) 83 | else: 84 | controlnet = ControlNetModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype) 85 | if args.pretrained_vae_model_name_or_path is not None: 86 | vae = AutoencoderKL.from_pretrained(args.pretrained_vae_model_name_or_path, torch_dtype=weight_dtype) 87 | else: 88 | vae = AutoencoderKL.from_pretrained( 89 | args.pretrained_model_name_or_path, subfolder="vae", torch_dtype=weight_dtype 90 | ) 91 | 92 | pipeline = StableDiffusionXLControlNetPipeline.from_pretrained( 93 | args.pretrained_model_name_or_path, 94 | vae=vae, 95 | controlnet=controlnet, 96 | revision=args.revision, 97 | variant=args.variant, 98 | torch_dtype=weight_dtype, 99 | ) 100 | 101 | pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config) 102 | pipeline = pipeline.to(accelerator.device) 103 | pipeline.set_progress_bar_config(disable=True) 104 | 105 | if args.enable_xformers_memory_efficient_attention: 106 | pipeline.enable_xformers_memory_efficient_attention() 107 | 108 | if args.seed is None: 109 | generator = None 110 | else: 111 | generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) 112 | 113 | if len(args.validation_image) == len(args.validation_prompt): 114 | validation_images = args.validation_image 115 | validation_prompts = args.validation_prompt 116 | elif len(args.validation_image) == 1: 117 | validation_images = args.validation_image * len(args.validation_prompt) 118 | validation_prompts = args.validation_prompt 119 | elif len(args.validation_prompt) == 1: 120 | validation_images = args.validation_image 121 | validation_prompts = args.validation_prompt * len(args.validation_image) 122 | else: 123 | raise ValueError( 124 | "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`" 125 | ) 126 | 127 | image_logs = [] 128 | if is_final_validation or torch.backends.mps.is_available(): 129 | autocast_ctx = nullcontext() 130 | else: 131 | autocast_ctx = torch.autocast(accelerator.device.type) 132 | 133 | for validation_prompt, validation_image in zip(validation_prompts, validation_images): 134 | validation_image = Image.open(validation_image).convert("RGB") 135 | validation_image = validation_image.resize((args.resolution, args.resolution)) 136 | 137 | images = [] 138 | 139 | for _ in range(args.num_validation_images): 140 | with autocast_ctx: 141 | image = pipeline( 142 | prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator 143 | ).images[0] 144 | images.append(image) 145 | 146 | image_logs.append( 147 | {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt} 148 | ) 149 | 150 | tracker_key = "test" if is_final_validation else "validation" 151 | for tracker in accelerator.trackers: 152 | if tracker.name == "tensorboard": 153 | for log in image_logs: 154 | images = log["images"] 155 | validation_prompt = log["validation_prompt"] 156 | validation_image = log["validation_image"] 157 | 158 | formatted_images = [] 159 | 160 | formatted_images.append(np.asarray(validation_image)) 161 | 162 | for image in images: 163 | formatted_images.append(np.asarray(image)) 164 | 165 | 
formatted_images = np.stack(formatted_images) 166 | 167 | tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC") 168 | elif tracker.name == "wandb": 169 | formatted_images = [] 170 | 171 | for log in image_logs: 172 | images = log["images"] 173 | validation_prompt = log["validation_prompt"] 174 | validation_image = log["validation_image"] 175 | 176 | formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning")) 177 | 178 | for image in images: 179 | image = wandb.Image(image, caption=validation_prompt) 180 | formatted_images.append(image) 181 | 182 | tracker.log({tracker_key: formatted_images}) 183 | else: 184 | logger.warning(f"image logging not implemented for {tracker.name}") 185 | 186 | del pipeline 187 | gc.collect() 188 | torch.cuda.empty_cache() 189 | 190 | return image_logs 191 | 192 | 193 | def import_model_class_from_model_name_or_path( 194 | pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" 195 | ): 196 | text_encoder_config = PretrainedConfig.from_pretrained( 197 | pretrained_model_name_or_path, subfolder=subfolder, revision=revision 198 | ) 199 | model_class = text_encoder_config.architectures[0] 200 | 201 | if model_class == "CLIPTextModel": 202 | from transformers import CLIPTextModel 203 | 204 | return CLIPTextModel 205 | elif model_class == "CLIPTextModelWithProjection": 206 | from transformers import CLIPTextModelWithProjection 207 | 208 | return CLIPTextModelWithProjection 209 | else: 210 | raise ValueError(f"{model_class} is not supported.") 211 | 212 | 213 | def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None): 214 | img_str = "" 215 | if image_logs is not None: 216 | img_str = "You can find some example images below.\n\n" 217 | for i, log in enumerate(image_logs): 218 | images = log["images"] 219 | validation_prompt = log["validation_prompt"] 220 | validation_image = log["validation_image"] 221 | validation_image.save(os.path.join(repo_folder, "image_control.png")) 222 | img_str += f"prompt: {validation_prompt}\n" 223 | images = [validation_image] + images 224 | make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) 225 | img_str += f"![images_{i})](./images_{i}.png)\n" 226 | 227 | model_description = f""" 228 | # controlnet-{repo_id} 229 | 230 | These are controlnet weights trained on {base_model} with new type of conditioning. 
231 | {img_str} 232 | """ 233 | 234 | model_card = load_or_create_model_card( 235 | repo_id_or_path=repo_id, 236 | from_training=True, 237 | license="openrail++", 238 | base_model=base_model, 239 | model_description=model_description, 240 | inference=True, 241 | ) 242 | 243 | tags = [ 244 | "stable-diffusion-xl", 245 | "stable-diffusion-xl-diffusers", 246 | "text-to-image", 247 | "diffusers", 248 | "controlnet", 249 | "diffusers-training", 250 | ] 251 | model_card = populate_model_card(model_card, tags=tags) 252 | 253 | model_card.save(os.path.join(repo_folder, "README.md")) 254 | 255 | 256 | def parse_args(input_args=None): 257 | parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.") 258 | parser.add_argument( 259 | "--pretrained_model_name_or_path", 260 | type=str, 261 | default=None, 262 | required=True, 263 | help="Path to pretrained model or model identifier from huggingface.co/models.", 264 | ) 265 | parser.add_argument( 266 | "--pretrained_vae_model_name_or_path", 267 | type=str, 268 | default=None, 269 | help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.", 270 | ) 271 | parser.add_argument( 272 | "--controlnet_model_name_or_path", 273 | type=str, 274 | default=None, 275 | help="Path to pretrained controlnet model or model identifier from huggingface.co/models." 276 | " If not specified controlnet weights are initialized from unet.", 277 | ) 278 | parser.add_argument( 279 | "--variant", 280 | type=str, 281 | default=None, 282 | help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", 283 | ) 284 | parser.add_argument( 285 | "--revision", 286 | type=str, 287 | default=None, 288 | required=False, 289 | help="Revision of pretrained model identifier from huggingface.co/models.", 290 | ) 291 | parser.add_argument( 292 | "--tokenizer_name", 293 | type=str, 294 | default=None, 295 | help="Pretrained tokenizer name or path if not the same as model_name", 296 | ) 297 | parser.add_argument( 298 | "--output_dir", 299 | type=str, 300 | default="controlnet-model", 301 | help="The output directory where the model predictions and checkpoints will be written.", 302 | ) 303 | parser.add_argument( 304 | "--cache_dir", 305 | type=str, 306 | default=None, 307 | help="The directory where the downloaded models and datasets will be stored.", 308 | ) 309 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 310 | parser.add_argument( 311 | "--resolution", 312 | type=int, 313 | default=512, 314 | help=( 315 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 316 | " resolution" 317 | ), 318 | ) 319 | parser.add_argument( 320 | "--crops_coords_top_left_h", 321 | type=int, 322 | default=0, 323 | help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."), 324 | ) 325 | parser.add_argument( 326 | "--crops_coords_top_left_w", 327 | type=int, 328 | default=0, 329 | help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."), 330 | ) 331 | parser.add_argument( 332 | "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." 
333 | ) 334 | parser.add_argument("--num_train_epochs", type=int, default=1) 335 | parser.add_argument( 336 | "--max_train_steps", 337 | type=int, 338 | default=None, 339 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 340 | ) 341 | parser.add_argument( 342 | "--checkpointing_steps", 343 | type=int, 344 | default=500, 345 | help=( 346 | "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. " 347 | "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." 348 | "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components." 349 | "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step" 350 | "instructions." 351 | ), 352 | ) 353 | parser.add_argument( 354 | "--checkpoints_total_limit", 355 | type=int, 356 | default=None, 357 | help=("Max number of checkpoints to store."), 358 | ) 359 | parser.add_argument( 360 | "--resume_from_checkpoint", 361 | type=str, 362 | default=None, 363 | help=( 364 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 365 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 366 | ), 367 | ) 368 | parser.add_argument( 369 | "--gradient_accumulation_steps", 370 | type=int, 371 | default=1, 372 | help="Number of updates steps to accumulate before performing a backward/update pass.", 373 | ) 374 | parser.add_argument( 375 | "--gradient_checkpointing", 376 | action="store_true", 377 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 378 | ) 379 | parser.add_argument( 380 | "--learning_rate", 381 | type=float, 382 | default=5e-6, 383 | help="Initial learning rate (after the potential warmup period) to use.", 384 | ) 385 | parser.add_argument( 386 | "--scale_lr", 387 | action="store_true", 388 | default=False, 389 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 390 | ) 391 | parser.add_argument( 392 | "--lr_scheduler", 393 | type=str, 394 | default="constant", 395 | help=( 396 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 397 | ' "constant", "constant_with_warmup"]' 398 | ), 399 | ) 400 | parser.add_argument( 401 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 402 | ) 403 | parser.add_argument( 404 | "--lr_num_cycles", 405 | type=int, 406 | default=1, 407 | help="Number of hard resets of the lr in cosine_with_restarts scheduler.", 408 | ) 409 | parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") 410 | parser.add_argument( 411 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 412 | ) 413 | parser.add_argument( 414 | "--dataloader_num_workers", 415 | type=int, 416 | default=0, 417 | help=( 418 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 
419 | ), 420 | ) 421 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 422 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 423 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 424 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 425 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 426 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 427 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 428 | parser.add_argument( 429 | "--hub_model_id", 430 | type=str, 431 | default=None, 432 | help="The name of the repository to keep in sync with the local `output_dir`.", 433 | ) 434 | parser.add_argument( 435 | "--logging_dir", 436 | type=str, 437 | default="logs", 438 | help=( 439 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 440 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 441 | ), 442 | ) 443 | parser.add_argument( 444 | "--allow_tf32", 445 | action="store_true", 446 | help=( 447 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" 448 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 449 | ), 450 | ) 451 | parser.add_argument( 452 | "--report_to", 453 | type=str, 454 | default="tensorboard", 455 | help=( 456 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 457 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 458 | ), 459 | ) 460 | parser.add_argument( 461 | "--mixed_precision", 462 | type=str, 463 | default=None, 464 | choices=["no", "fp16", "bf16"], 465 | help=( 466 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 467 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 468 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 469 | ), 470 | ) 471 | parser.add_argument( 472 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 473 | ) 474 | parser.add_argument( 475 | "--set_grads_to_none", 476 | action="store_true", 477 | help=( 478 | "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" 479 | " behaviors, so disable this argument if it causes any problems. More info:" 480 | " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" 481 | ), 482 | ) 483 | parser.add_argument( 484 | "--dataset_name", 485 | type=str, 486 | default=None, 487 | help=( 488 | "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," 489 | " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," 490 | " or to a folder containing files that 🤗 Datasets can understand." 
491 | ), 492 | ) 493 | parser.add_argument( 494 | "--dataset_config_name", 495 | type=str, 496 | default=None, 497 | help="The config of the Dataset, leave as None if there's only one config.", 498 | ) 499 | parser.add_argument( 500 | "--dataset_revision", 501 | type=str, 502 | default='main', 503 | help="The revision of the Dataset, leave as 'main' by default.", 504 | ) 505 | parser.add_argument( 506 | "--train_data_dir", 507 | type=str, 508 | default=None, 509 | help=( 510 | "A folder containing the training data. Folder contents must follow the structure described in" 511 | " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" 512 | " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." 513 | ), 514 | ) 515 | parser.add_argument( 516 | "--image_column", type=str, default="image", help="The column of the dataset containing the target image." 517 | ) 518 | parser.add_argument( 519 | "--conditioning_image_column", 520 | type=str, 521 | default="conditioning_image", 522 | help="The column of the dataset containing the controlnet conditioning image.", 523 | ) 524 | parser.add_argument( 525 | "--caption_column", 526 | type=str, 527 | default="text", 528 | help="The column of the dataset containing a caption or a list of captions.", 529 | ) 530 | parser.add_argument( 531 | "--max_train_samples", 532 | type=int, 533 | default=None, 534 | help=( 535 | "For debugging purposes or quicker training, truncate the number of training examples to this " 536 | "value if set." 537 | ), 538 | ) 539 | parser.add_argument( 540 | "--proportion_empty_prompts", 541 | type=float, 542 | default=0, 543 | help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", 544 | ) 545 | parser.add_argument( 546 | "--validation_prompt", 547 | type=str, 548 | default=None, 549 | nargs="+", 550 | help=( 551 | "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`." 552 | " Provide either a matching number of `--validation_image`s, a single `--validation_image`" 553 | " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s." 554 | ), 555 | ) 556 | parser.add_argument( 557 | "--validation_image", 558 | type=str, 559 | default=None, 560 | nargs="+", 561 | help=( 562 | "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`" 563 | " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a" 564 | " a single `--validation_prompt` to be used with all `--validation_image`s, or a single" 565 | " `--validation_image` that will be used with all `--validation_prompt`s." 566 | ), 567 | ) 568 | parser.add_argument( 569 | "--num_validation_images", 570 | type=int, 571 | default=4, 572 | help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair", 573 | ) 574 | parser.add_argument( 575 | "--validation_steps", 576 | type=int, 577 | default=100, 578 | help=( 579 | "Run validation every X steps. Validation consists of running the prompt" 580 | " `args.validation_prompt` multiple times: `args.num_validation_images`" 581 | " and logging the images." 
582 | ), 583 | ) 584 | parser.add_argument( 585 | "--tracker_project_name", 586 | type=str, 587 | default="sd_xl_train_controlnet", 588 | help=( 589 | "The `project_name` argument passed to Accelerator.init_trackers for" 590 | " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" 591 | ), 592 | ) 593 | 594 | if input_args is not None: 595 | args = parser.parse_args(input_args) 596 | else: 597 | args = parser.parse_args() 598 | 599 | if args.dataset_name is None and args.train_data_dir is None: 600 | raise ValueError("Specify either `--dataset_name` or `--train_data_dir`") 601 | 602 | if args.dataset_name is not None and args.train_data_dir is not None: 603 | raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`") 604 | 605 | if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: 606 | raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") 607 | 608 | if args.validation_prompt is not None and args.validation_image is None: 609 | raise ValueError("`--validation_image` must be set if `--validation_prompt` is set") 610 | 611 | if args.validation_prompt is None and args.validation_image is not None: 612 | raise ValueError("`--validation_prompt` must be set if `--validation_image` is set") 613 | 614 | if ( 615 | args.validation_image is not None 616 | and args.validation_prompt is not None 617 | and len(args.validation_image) != 1 618 | and len(args.validation_prompt) != 1 619 | and len(args.validation_image) != len(args.validation_prompt) 620 | ): 621 | raise ValueError( 622 | "Must provide either 1 `--validation_image`, 1 `--validation_prompt`," 623 | " or the same number of `--validation_prompt`s and `--validation_image`s" 624 | ) 625 | 626 | if args.resolution % 8 != 0: 627 | raise ValueError( 628 | "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder." 629 | ) 630 | 631 | return args 632 | 633 | 634 | def get_train_dataset(args, accelerator): 635 | # Get the datasets: you can either provide your own training and evaluation files (see below) 636 | # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). 637 | 638 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently 639 | # download the dataset. 640 | if args.dataset_name is not None: 641 | # Downloading and loading a dataset from the hub. 642 | dataset = load_dataset( 643 | args.dataset_name, 644 | args.dataset_config_name, 645 | revision=args.dataset_revision, 646 | cache_dir=args.cache_dir, 647 | ) 648 | else: 649 | if args.train_data_dir is not None: 650 | dataset = load_dataset( 651 | args.train_data_dir, 652 | cache_dir=args.cache_dir, 653 | ) 654 | # See more about loading custom images at 655 | # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script 656 | 657 | # Preprocessing the datasets. 658 | # We need to tokenize inputs and targets. 659 | column_names = dataset["train"].column_names 660 | 661 | # 6. Get the column names for input/target. 662 | if args.image_column is None: 663 | image_column = column_names[0] 664 | logger.info(f"image column defaulting to {image_column}") 665 | else: 666 | image_column = args.image_column 667 | if image_column not in column_names: 668 | raise ValueError( 669 | f"`--image_column` value '{args.image_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}" 670 | ) 671 | 672 | if args.caption_column is None: 673 | caption_column = column_names[1] 674 | logger.info(f"caption column defaulting to {caption_column}") 675 | else: 676 | caption_column = args.caption_column 677 | if caption_column not in column_names: 678 | raise ValueError( 679 | f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" 680 | ) 681 | 682 | if args.conditioning_image_column is None: 683 | conditioning_image_column = column_names[2] 684 | logger.info(f"conditioning image column defaulting to {conditioning_image_column}") 685 | else: 686 | conditioning_image_column = args.conditioning_image_column 687 | if conditioning_image_column not in column_names: 688 | raise ValueError( 689 | f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" 690 | ) 691 | 692 | with accelerator.main_process_first(): 693 | train_dataset = dataset["train"].shuffle(seed=args.seed) 694 | if args.max_train_samples is not None: 695 | train_dataset = train_dataset.select(range(args.max_train_samples)) 696 | return train_dataset 697 | 698 | 699 | # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt 700 | def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True): 701 | prompt_embeds_list = [] 702 | 703 | captions = [] 704 | for caption in prompt_batch: 705 | if random.random() < proportion_empty_prompts: 706 | captions.append("") 707 | elif isinstance(caption, str): 708 | captions.append(caption) 709 | elif isinstance(caption, (list, np.ndarray)): 710 | # take a random caption if there are multiple 711 | captions.append(random.choice(caption) if is_train else caption[0]) 712 | 713 | with torch.no_grad(): 714 | for tokenizer, text_encoder in zip(tokenizers, text_encoders): 715 | text_inputs = tokenizer( 716 | captions, 717 | padding="max_length", 718 | max_length=tokenizer.model_max_length, 719 | truncation=True, 720 | return_tensors="pt", 721 | ) 722 | text_input_ids = text_inputs.input_ids 723 | prompt_embeds = text_encoder( 724 | text_input_ids.to(text_encoder.device), 725 | output_hidden_states=True, 726 | ) 727 | 728 | # We are only ALWAYS interested in the pooled output of the final text encoder 729 | pooled_prompt_embeds = prompt_embeds[0] 730 | prompt_embeds = prompt_embeds.hidden_states[-2] 731 | bs_embed, seq_len, _ = prompt_embeds.shape 732 | prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) 733 | prompt_embeds_list.append(prompt_embeds) 734 | 735 | prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) 736 | pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) 737 | return prompt_embeds, pooled_prompt_embeds 738 | 739 | 740 | def prepare_train_dataset(dataset, accelerator): 741 | image_transforms = transforms.Compose( 742 | [ 743 | transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), 744 | transforms.CenterCrop(args.resolution), 745 | transforms.ToTensor(), 746 | transforms.Normalize([0.5], [0.5]), 747 | ] 748 | ) 749 | 750 | conditioning_image_transforms = transforms.Compose( 751 | [ 752 | transforms.Grayscale(num_output_channels=3), # convert to grayscale image 753 | transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), 754 | transforms.CenterCrop(args.resolution), 755 | transforms.ToTensor(), 756 | ] 757 | ) 758 | 759 | def 
preprocess_train(examples): 760 | images = [Image.open(image).convert("RGB") for image in examples[args.image_column]] 761 | images = [image_transforms(image) for image in images] 762 | 763 | conditioning_images = [Image.open(image).convert("RGB") for image in examples[args.conditioning_image_column]] 764 | conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images] 765 | 766 | examples["pixel_values"] = images 767 | examples["conditioning_pixel_values"] = conditioning_images 768 | 769 | return examples 770 | 771 | with accelerator.main_process_first(): 772 | dataset = dataset.with_transform(preprocess_train) 773 | 774 | return dataset 775 | 776 | 777 | def collate_fn(examples): 778 | pixel_values = torch.stack([example["pixel_values"] for example in examples]) 779 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 780 | 781 | conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples]) 782 | conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float() 783 | 784 | prompt_ids = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples]) 785 | 786 | add_text_embeds = torch.stack([torch.tensor(example["text_embeds"]) for example in examples]) 787 | add_time_ids = torch.stack([torch.tensor(example["time_ids"]) for example in examples]) 788 | 789 | return { 790 | "pixel_values": pixel_values, 791 | "conditioning_pixel_values": conditioning_pixel_values, 792 | "prompt_ids": prompt_ids, 793 | "unet_added_conditions": {"text_embeds": add_text_embeds, "time_ids": add_time_ids}, 794 | } 795 | 796 | 797 | def main(args): 798 | if args.report_to == "wandb" and args.hub_token is not None: 799 | raise ValueError( 800 | "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." 801 | " Please use `huggingface-cli login` to authenticate with the Hub." 802 | ) 803 | 804 | logging_dir = Path(args.output_dir, args.logging_dir) 805 | 806 | if torch.backends.mps.is_available() and args.mixed_precision == "bf16": 807 | # due to pytorch#99272, MPS does not yet support bfloat16. 808 | raise ValueError( 809 | "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." 810 | ) 811 | 812 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) 813 | 814 | accelerator = Accelerator( 815 | gradient_accumulation_steps=args.gradient_accumulation_steps, 816 | mixed_precision=args.mixed_precision, 817 | log_with=args.report_to, 818 | project_config=accelerator_project_config, 819 | ) 820 | 821 | # Disable AMP for MPS. 822 | if torch.backends.mps.is_available(): 823 | accelerator.native_amp = False 824 | 825 | # Make one log on every process with the configuration for debugging. 826 | logging.basicConfig( 827 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 828 | datefmt="%m/%d/%Y %H:%M:%S", 829 | level=logging.INFO, 830 | ) 831 | logger.info(accelerator.state, main_process_only=False) 832 | if accelerator.is_local_main_process: 833 | transformers.utils.logging.set_verbosity_warning() 834 | diffusers.utils.logging.set_verbosity_info() 835 | else: 836 | transformers.utils.logging.set_verbosity_error() 837 | diffusers.utils.logging.set_verbosity_error() 838 | 839 | # If passed along, set the training seed now. 
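    # (The shipped shell scripts pass --seed=123123.) accelerate's `set_seed` seeds Python's `random`,
    # NumPy and PyTorch (including CUDA) in each process, so the dataset shuffle above and the noise /
    # timesteps sampled in the training loop below are reproducible across runs.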
840 | if args.seed is not None: 841 | set_seed(args.seed) 842 | 843 | # Handle the repository creation 844 | if accelerator.is_main_process: 845 | if args.output_dir is not None: 846 | os.makedirs(args.output_dir, exist_ok=True) 847 | 848 | if args.push_to_hub: 849 | repo_id = create_repo( 850 | repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token 851 | ).repo_id 852 | 853 | # Load the tokenizers 854 | tokenizer_one = AutoTokenizer.from_pretrained( 855 | args.pretrained_model_name_or_path, 856 | subfolder="tokenizer", 857 | revision=args.revision, 858 | use_fast=False, 859 | ) 860 | tokenizer_two = AutoTokenizer.from_pretrained( 861 | args.pretrained_model_name_or_path, 862 | subfolder="tokenizer_2", 863 | revision=args.revision, 864 | use_fast=False, 865 | ) 866 | 867 | # import correct text encoder classes 868 | text_encoder_cls_one = import_model_class_from_model_name_or_path( 869 | args.pretrained_model_name_or_path, args.revision 870 | ) 871 | text_encoder_cls_two = import_model_class_from_model_name_or_path( 872 | args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" 873 | ) 874 | 875 | # Load scheduler and models 876 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") 877 | text_encoder_one = text_encoder_cls_one.from_pretrained( 878 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant 879 | ) 880 | text_encoder_two = text_encoder_cls_two.from_pretrained( 881 | args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant 882 | ) 883 | vae_path = ( 884 | args.pretrained_model_name_or_path 885 | if args.pretrained_vae_model_name_or_path is None 886 | else args.pretrained_vae_model_name_or_path 887 | ) 888 | vae = AutoencoderKL.from_pretrained( 889 | vae_path, 890 | subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, 891 | revision=args.revision, 892 | variant=args.variant, 893 | ) 894 | unet = UNet2DConditionModel.from_pretrained( 895 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant 896 | ) 897 | 898 | if args.controlnet_model_name_or_path: 899 | logger.info("Loading existing controlnet weights") 900 | controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path) 901 | else: 902 | logger.info("Initializing controlnet weights from unet") 903 | controlnet = ControlNetModel.from_unet(unet) 904 | 905 | def unwrap_model(model): 906 | model = accelerator.unwrap_model(model) 907 | model = model._orig_mod if is_compiled_module(model) else model 908 | return model 909 | 910 | # `accelerate` 0.16.0 will have better support for customized saving 911 | if version.parse(accelerate.__version__) >= version.parse("0.16.0"): 912 | # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format 913 | def save_model_hook(models, weights, output_dir): 914 | if accelerator.is_main_process: 915 | i = len(weights) - 1 916 | 917 | while len(weights) > 0: 918 | weights.pop() 919 | model = models[i] 920 | 921 | sub_dir = "controlnet" 922 | model.save_pretrained(os.path.join(output_dir, sub_dir)) 923 | 924 | i -= 1 925 | 926 | def load_model_hook(models, input_dir): 927 | while len(models) > 0: 928 | # pop models so that they are not loaded again 929 | model = models.pop() 930 | 931 | # load diffusers style into model 932 | load_model = 
ControlNetModel.from_pretrained(input_dir, subfolder="controlnet") 933 | model.register_to_config(**load_model.config) 934 | 935 | model.load_state_dict(load_model.state_dict()) 936 | del load_model 937 | 938 | accelerator.register_save_state_pre_hook(save_model_hook) 939 | accelerator.register_load_state_pre_hook(load_model_hook) 940 | 941 | vae.requires_grad_(False) 942 | unet.requires_grad_(False) 943 | text_encoder_one.requires_grad_(False) 944 | text_encoder_two.requires_grad_(False) 945 | controlnet.train() 946 | 947 | if args.enable_xformers_memory_efficient_attention: 948 | if is_xformers_available(): 949 | import xformers 950 | 951 | xformers_version = version.parse(xformers.__version__) 952 | if xformers_version == version.parse("0.0.16"): 953 | logger.warning( 954 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 955 | ) 956 | unet.enable_xformers_memory_efficient_attention() 957 | controlnet.enable_xformers_memory_efficient_attention() 958 | else: 959 | raise ValueError("xformers is not available. Make sure it is installed correctly") 960 | 961 | if args.gradient_checkpointing: 962 | controlnet.enable_gradient_checkpointing() 963 | unet.enable_gradient_checkpointing() 964 | 965 | # Check that all trainable models are in full precision 966 | low_precision_error_string = ( 967 | " Please make sure to always have all model weights in full float32 precision when starting training - even if" 968 | " doing mixed precision training, copy of the weights should still be float32." 969 | ) 970 | 971 | if unwrap_model(controlnet).dtype != torch.float32: 972 | raise ValueError( 973 | f"Controlnet loaded as datatype {unwrap_model(controlnet).dtype}. {low_precision_error_string}" 974 | ) 975 | 976 | # Enable TF32 for faster training on Ampere GPUs, 977 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 978 | if args.allow_tf32: 979 | torch.backends.cuda.matmul.allow_tf32 = True 980 | 981 | if args.scale_lr: 982 | args.learning_rate = ( 983 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 984 | ) 985 | 986 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs 987 | if args.use_8bit_adam: 988 | try: 989 | import bitsandbytes as bnb 990 | except ImportError: 991 | raise ImportError( 992 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 993 | ) 994 | 995 | optimizer_class = bnb.optim.AdamW8bit 996 | else: 997 | optimizer_class = torch.optim.AdamW 998 | 999 | # Optimizer creation 1000 | params_to_optimize = controlnet.parameters() 1001 | optimizer = optimizer_class( 1002 | params_to_optimize, 1003 | lr=args.learning_rate, 1004 | betas=(args.adam_beta1, args.adam_beta2), 1005 | weight_decay=args.adam_weight_decay, 1006 | eps=args.adam_epsilon, 1007 | ) 1008 | 1009 | # For mixed precision training we cast the text_encoder and vae weights to half-precision 1010 | # as these models are only used for inference, keeping weights in full precision is not required. 
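    # With `--mixed_precision="fp16"` (as used in train_controlnet_sdxl.sh) the frozen unet and both
    # text encoders are cast to torch.float16; the stock SDXL VAE, however, is kept in float32 below
    # because it is prone to producing NaNs in fp16. Passing --pretrained_vae_model_name_or_path
    # (e.g. an fp16-safe VAE such as madebyollin/sdxl-vae-fp16-fix) lets the VAE follow weight_dtype.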
1011 | weight_dtype = torch.float32 1012 | if accelerator.mixed_precision == "fp16": 1013 | weight_dtype = torch.float16 1014 | elif accelerator.mixed_precision == "bf16": 1015 | weight_dtype = torch.bfloat16 1016 | 1017 | # Move vae, unet and text_encoder to device and cast to weight_dtype 1018 | # The VAE is in float32 to avoid NaN losses. 1019 | if args.pretrained_vae_model_name_or_path is not None: 1020 | vae.to(accelerator.device, dtype=weight_dtype) 1021 | else: 1022 | vae.to(accelerator.device, dtype=torch.float32) 1023 | unet.to(accelerator.device, dtype=weight_dtype) 1024 | text_encoder_one.to(accelerator.device, dtype=weight_dtype) 1025 | text_encoder_two.to(accelerator.device, dtype=weight_dtype) 1026 | 1027 | # Here, we compute not just the text embeddings but also the additional embeddings 1028 | # needed for the SD XL UNet to operate. 1029 | def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizers, is_train=True): 1030 | original_size = (args.resolution, args.resolution) 1031 | target_size = (args.resolution, args.resolution) 1032 | crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w) 1033 | prompt_batch = batch[args.caption_column] 1034 | 1035 | prompt_embeds, pooled_prompt_embeds = encode_prompt( 1036 | prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train 1037 | ) 1038 | add_text_embeds = pooled_prompt_embeds 1039 | 1040 | # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids 1041 | add_time_ids = list(original_size + crops_coords_top_left + target_size) 1042 | add_time_ids = torch.tensor([add_time_ids]) 1043 | 1044 | prompt_embeds = prompt_embeds.to(accelerator.device) 1045 | add_text_embeds = add_text_embeds.to(accelerator.device) 1046 | add_time_ids = add_time_ids.repeat(len(prompt_batch), 1) 1047 | add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype) 1048 | unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} 1049 | 1050 | return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs} 1051 | 1052 | # Let's first compute all the embeddings so that we can free up the text encoders 1053 | # from memory. 1054 | text_encoders = [text_encoder_one, text_encoder_two] 1055 | tokenizers = [tokenizer_one, tokenizer_two] 1056 | train_dataset = get_train_dataset(args, accelerator) 1057 | compute_embeddings_fn = functools.partial( 1058 | compute_embeddings, 1059 | text_encoders=text_encoders, 1060 | tokenizers=tokenizers, 1061 | proportion_empty_prompts=args.proportion_empty_prompts, 1062 | ) 1063 | with accelerator.main_process_first(): 1064 | from datasets.fingerprint import Hasher 1065 | 1066 | # fingerprint used by the cache for the other processes to load the result 1067 | # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401 1068 | new_fingerprint = Hasher.hash(args) 1069 | train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint) 1070 | 1071 | del text_encoders, tokenizers 1072 | gc.collect() 1073 | torch.cuda.empty_cache() 1074 | 1075 | # Then get the training dataset ready to be passed to the dataloader. 
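    # After the `map` above, each cached example carries (roughly) three extra columns alongside the
    # original dataset columns: `prompt_embeds` — the concatenated penultimate hidden states of both
    # text encoders, shape (77, 2048); `text_embeds` — the pooled output of the second text encoder,
    # shape (1280,); and `time_ids` — original_size + crops_coords_top_left + target_size, i.e.
    # [512, 512, 0, 0, 512, 512] with --resolution=512 and the default crop coordinates used in
    # train_controlnet_sdxl.sh. `collate_fn` then stacks these into the `prompt_ids` and
    # `unet_added_conditions` entries consumed in the training loop.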
1076 | train_dataset = prepare_train_dataset(train_dataset, accelerator) 1077 | 1078 | train_dataloader = torch.utils.data.DataLoader( 1079 | train_dataset, 1080 | shuffle=True, 1081 | collate_fn=collate_fn, 1082 | batch_size=args.train_batch_size, 1083 | num_workers=args.dataloader_num_workers, 1084 | ) 1085 | 1086 | # Scheduler and math around the number of training steps. 1087 | overrode_max_train_steps = False 1088 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 1089 | if args.max_train_steps is None: 1090 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 1091 | overrode_max_train_steps = True 1092 | 1093 | lr_scheduler = get_scheduler( 1094 | args.lr_scheduler, 1095 | optimizer=optimizer, 1096 | num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, 1097 | num_training_steps=args.max_train_steps * accelerator.num_processes, 1098 | num_cycles=args.lr_num_cycles, 1099 | power=args.lr_power, 1100 | ) 1101 | 1102 | # Prepare everything with our `accelerator`. 1103 | controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 1104 | controlnet, optimizer, train_dataloader, lr_scheduler 1105 | ) 1106 | 1107 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 1108 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 1109 | if overrode_max_train_steps: 1110 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 1111 | # Afterwards we recalculate our number of training epochs 1112 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 1113 | 1114 | # We need to initialize the trackers we use, and also store our configuration. 1115 | # The trackers initializes automatically on the main process. 1116 | if accelerator.is_main_process: 1117 | tracker_config = dict(vars(args)) 1118 | 1119 | # tensorboard cannot handle list types for config 1120 | tracker_config.pop("validation_prompt") 1121 | tracker_config.pop("validation_image") 1122 | 1123 | accelerator.init_trackers(args.tracker_project_name, config=tracker_config) 1124 | 1125 | # Train! 1126 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 1127 | 1128 | logger.info("***** Running training *****") 1129 | logger.info(f" Num examples = {len(train_dataset)}") 1130 | logger.info(f" Num batches each epoch = {len(train_dataloader)}") 1131 | logger.info(f" Num Epochs = {args.num_train_epochs}") 1132 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 1133 | logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") 1134 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 1135 | logger.info(f" Total optimization steps = {args.max_train_steps}") 1136 | global_step = 0 1137 | first_epoch = 0 1138 | 1139 | # Potentially load in the weights and states from a previous save 1140 | if args.resume_from_checkpoint: 1141 | if args.resume_from_checkpoint != "latest": 1142 | path = os.path.basename(args.resume_from_checkpoint) 1143 | else: 1144 | # Get the most recent checkpoint 1145 | dirs = os.listdir(args.output_dir) 1146 | dirs = [d for d in dirs if d.startswith("checkpoint")] 1147 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 1148 | path = dirs[-1] if len(dirs) > 0 else None 1149 | 1150 | if path is None: 1151 | accelerator.print( 1152 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 1153 | ) 1154 | args.resume_from_checkpoint = None 1155 | initial_global_step = 0 1156 | else: 1157 | accelerator.print(f"Resuming from checkpoint {path}") 1158 | accelerator.load_state(os.path.join(args.output_dir, path)) 1159 | global_step = int(path.split("-")[1]) 1160 | 1161 | initial_global_step = global_step 1162 | first_epoch = global_step // num_update_steps_per_epoch 1163 | else: 1164 | initial_global_step = 0 1165 | 1166 | progress_bar = tqdm( 1167 | range(0, args.max_train_steps), 1168 | initial=initial_global_step, 1169 | desc="Steps", 1170 | # Only show the progress bar once on each machine. 1171 | disable=not accelerator.is_local_main_process, 1172 | ) 1173 | 1174 | image_logs = None 1175 | for epoch in range(first_epoch, args.num_train_epochs): 1176 | for step, batch in enumerate(train_dataloader): 1177 | with accelerator.accumulate(controlnet): 1178 | # Convert images to latent space 1179 | if args.pretrained_vae_model_name_or_path is not None: 1180 | pixel_values = batch["pixel_values"].to(dtype=weight_dtype) 1181 | else: 1182 | pixel_values = batch["pixel_values"] 1183 | latents = vae.encode(pixel_values).latent_dist.sample() 1184 | latents = latents * vae.config.scaling_factor 1185 | if args.pretrained_vae_model_name_or_path is None: 1186 | latents = latents.to(weight_dtype) 1187 | 1188 | # Sample noise that we'll add to the latents 1189 | noise = torch.randn_like(latents) 1190 | bsz = latents.shape[0] 1191 | 1192 | # Sample a random timestep for each image 1193 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) 1194 | timesteps = timesteps.long() 1195 | 1196 | # Add noise to the latents according to the noise magnitude at each timestep 1197 | # (this is the forward diffusion process) 1198 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 1199 | 1200 | # ControlNet conditioning. 
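                # Just above, `noise_scheduler.add_noise` implements the DDPM forward process:
                # noisy_latents = sqrt(alpha_bar_t) * latents + sqrt(1 - alpha_bar_t) * noise.
                # The conditioning input below is the 3-channel grayscale copy of the target image built by
                # `conditioning_image_transforms` (ToTensor without Normalize, so values stay in [0, 1]).
                # The ControlNet receives it at pixel resolution alongside the noisy latents and returns a set
                # of residuals matching the frozen UNet's down-block feature maps plus one mid-block residual,
                # which are injected via the `*_additional_residual(s)` arguments in the UNet call below.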
1201 | controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype) 1202 | down_block_res_samples, mid_block_res_sample = controlnet( 1203 | noisy_latents, 1204 | timesteps, 1205 | encoder_hidden_states=batch["prompt_ids"], 1206 | added_cond_kwargs=batch["unet_added_conditions"], 1207 | controlnet_cond=controlnet_image, 1208 | return_dict=False, 1209 | ) 1210 | 1211 | # Predict the noise residual 1212 | model_pred = unet( 1213 | noisy_latents, 1214 | timesteps, 1215 | encoder_hidden_states=batch["prompt_ids"], 1216 | added_cond_kwargs=batch["unet_added_conditions"], 1217 | down_block_additional_residuals=[ 1218 | sample.to(dtype=weight_dtype) for sample in down_block_res_samples 1219 | ], 1220 | mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), 1221 | return_dict=False, 1222 | )[0] 1223 | 1224 | # Get the target for loss depending on the prediction type 1225 | if noise_scheduler.config.prediction_type == "epsilon": 1226 | target = noise 1227 | elif noise_scheduler.config.prediction_type == "v_prediction": 1228 | target = noise_scheduler.get_velocity(latents, noise, timesteps) 1229 | else: 1230 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 1231 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 1232 | 1233 | accelerator.backward(loss) 1234 | if accelerator.sync_gradients: 1235 | params_to_clip = controlnet.parameters() 1236 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) 1237 | optimizer.step() 1238 | lr_scheduler.step() 1239 | optimizer.zero_grad(set_to_none=args.set_grads_to_none) 1240 | 1241 | # Checks if the accelerator has performed an optimization step behind the scenes 1242 | if accelerator.sync_gradients: 1243 | progress_bar.update(1) 1244 | global_step += 1 1245 | 1246 | if accelerator.is_main_process: 1247 | if global_step % args.checkpointing_steps == 0: 1248 | # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` 1249 | if args.checkpoints_total_limit is not None: 1250 | checkpoints = os.listdir(args.output_dir) 1251 | checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] 1252 | checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) 1253 | 1254 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints 1255 | if len(checkpoints) >= args.checkpoints_total_limit: 1256 | num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 1257 | removing_checkpoints = checkpoints[0:num_to_remove] 1258 | 1259 | logger.info( 1260 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" 1261 | ) 1262 | logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") 1263 | 1264 | for removing_checkpoint in removing_checkpoints: 1265 | removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) 1266 | shutil.rmtree(removing_checkpoint) 1267 | 1268 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 1269 | accelerator.save_state(save_path) 1270 | logger.info(f"Saved state to {save_path}") 1271 | 1272 | if args.validation_prompt is not None and global_step % args.validation_steps == 0: 1273 | image_logs = log_validation( 1274 | vae=vae, 1275 | unet=unet, 1276 | controlnet=controlnet, 1277 | args=args, 1278 | accelerator=accelerator, 1279 | weight_dtype=weight_dtype, 1280 | step=global_step, 1281 | ) 1282 | 1283 | logs = {"loss": loss.detach().item(), "lr": 
1284 |             progress_bar.set_postfix(**logs)
1285 |             accelerator.log(logs, step=global_step)
1286 | 
1287 |             if global_step >= args.max_train_steps:
1288 |                 break
1289 | 
1290 |     # Create the pipeline using the trained modules and save it.
1291 |     accelerator.wait_for_everyone()
1292 |     if accelerator.is_main_process:
1293 |         controlnet = unwrap_model(controlnet)
1294 |         controlnet.save_pretrained(args.output_dir)
1295 | 
1296 |         # Run a final round of validation.
1297 |         # Setting `vae`, `unet`, and `controlnet` to None to load automatically from `args.output_dir`.
1298 |         image_logs = None
1299 |         if args.validation_prompt is not None:
1300 |             image_logs = log_validation(
1301 |                 vae=None,
1302 |                 unet=None,
1303 |                 controlnet=None,
1304 |                 args=args,
1305 |                 accelerator=accelerator,
1306 |                 weight_dtype=weight_dtype,
1307 |                 step=global_step,
1308 |                 is_final_validation=True,
1309 |             )
1310 | 
1311 |         if args.push_to_hub:
1312 |             save_model_card(
1313 |                 repo_id,
1314 |                 image_logs=image_logs,
1315 |                 base_model=args.pretrained_model_name_or_path,
1316 |                 repo_folder=args.output_dir,
1317 |             )
1318 |             upload_folder(
1319 |                 repo_id=repo_id,
1320 |                 folder_path=args.output_dir,
1321 |                 commit_message="End of training",
1322 |                 ignore_patterns=["step_*", "epoch_*"],
1323 |             )
1324 | 
1325 |     accelerator.end_training()
1326 | 
1327 | 
1328 | if __name__ == "__main__":
1329 |     args = parse_args()
1330 |     main(args)
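A minimal sketch (not part of the training script) of how the ControlNet weights written by `controlnet.save_pretrained(args.output_dir)` above can be reloaded for colorization inference with diffusers' standard SDXL ControlNet pipeline. The checkpoint path reuses the `OUTPUT_DIR` from the launch script below purely for illustration, and the filename and prompt are placeholders; the project's own `eval_controlnet*.py` / `gradio_ui.py` scripts may apply additional pre- and post-processing.

```python
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
from diffusers.utils import load_image

# Final weights are saved directly into args.output_dir; intermediate ones live
# under checkpoint-<step>/controlnet. Either can be loaded with from_pretrained.
controlnet = ControlNetModel.from_pretrained("sdxl_caption_output", torch_dtype=torch.float16)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to("cuda")

# The conditioning input is the grayscale version of the image to be colorized.
gray = load_image("input.jpg").convert("L").convert("RGB")  # illustrative filename
result = pipe(
    prompt="a red car parked on a city street",  # optional color guidance
    image=gray,
    num_inference_steps=30,
).images[0]
result.save("colorized.png")
```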
--------------------------------------------------------------------------------
/train_controlnet_sdxl.sh:
--------------------------------------------------------------------------------
1 | # Original ControlNet paper:
2 | # "In the training process, we randomly replace 50% text prompts ct with empty strings.
3 | # This approach increases ControlNet’s ability to directly recognize semantics
4 | # in the input conditioning images (e.g., edges, poses, depth, etc.) as a replacement for the prompt."
5 | # https://civitai.com/articles/2078/play-in-control-controlnet-training-setup-guide
6 | 
7 | export MODEL_DIR="stabilityai/stable-diffusion-xl-base-1.0"
8 | export OUTPUT_DIR="sdxl_caption_output"
9 | export DATASET="nickpai/coco2017-colorization"
10 | export REVISION="main" # option: main/caption-free
11 | export VAL_IMG_NAME="'./000000295478.jpg' './000000122962.jpg' './000000000285.jpg' './000000007991.jpg' './000000018837.jpg' './000000000724.jpg'"
12 | export VAL_PROMPT="'Woman walking a small dog behind her.' 'A group of children sitting at a long table eating pizza.' 'A close up picture of a bear face.' 'A plate on a table is filled with carrots and beans.' 'A large truck on a city street with two works sitting on top and one worker climbing in through door.' 'An upside down stop sign by the road.'"
13 | # export VAL_PROMPT="'Colorize this image as if it was taken with a color camera' 'Colorize this image' 'Add colors to this image' 'Make this image colorful' 'Colorize this grayscale image' 'Add colors to this image'"
14 | 
15 | accelerate launch train_controlnet_sdxl.py \
16 | --pretrained_model_name_or_path=$MODEL_DIR \
17 | --output_dir=$OUTPUT_DIR \
18 | --seed=123123 \
19 | --dataset_name=$DATASET \
20 | --dataset_revision=$REVISION \
21 | --image_column="file_name" \
22 | --conditioning_image_column="file_name" \
23 | --caption_column="captions" \
24 | --max_train_samples=100000 \
25 | --num_validation_images=1 \
26 | --resolution=512 \
27 | --num_train_epochs=5 \
28 | --dataloader_num_workers=8 \
29 | --learning_rate=1e-5 \
30 | --train_batch_size=2 \
31 | --gradient_accumulation_steps=8 \
32 | --proportion_empty_prompts=0 \
33 | --validation_steps=500 \
34 | --checkpointing_steps=2500 \
35 | --mixed_precision="fp16" \
36 | --gradient_checkpointing \
37 | --use_8bit_adam \
38 | --enable_xformers_memory_efficient_attention
39 | 
40 | # --validation_image './000000295478.jpg' './000000122962.jpg' './000000000285.jpg' './000000007991.jpg' './000000018837.jpg' './000000000724.jpg' \
41 | # --validation_prompt 'Woman walking a small dog behind her.' 'A group of children sitting at a long table eating pizza.' 'A close up picture of a bear face.' 'A plate on a table is filled with carrots and beans.' 'A large truck on a city street with two works sitting on top and one worker climbing in through door.' 'An upside down stop sign by the road.' \
42 | 
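The header comment above quotes the ControlNet paper's 50% empty-prompt trick, while both launch scripts set `--proportion_empty_prompts=0`, i.e. every caption is kept. Below is a minimal sketch of how such caption dropout is typically applied when captions are prepared for tokenization; the helper name is hypothetical and not taken from `train_controlnet_sdxl.py`.

```python
import random

def maybe_drop_caption(caption: str, proportion_empty_prompts: float) -> str:
    """With probability `proportion_empty_prompts`, replace the caption with an empty string."""
    return "" if random.random() < proportion_empty_prompts else caption

captions = [
    "A close up picture of a bear face.",
    "An upside down stop sign by the road.",
]
# The paper-style setting would be 0.5; these scripts use 0, so captions pass through unchanged.
dropped = [maybe_drop_caption(c, proportion_empty_prompts=0.5) for c in captions]
print(dropped)
```

Raising the proportion pushes the model to rely more on the conditioning image itself, which is the effect the quoted paper note describes.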
--------------------------------------------------------------------------------
/train_controlnet_sdxl_light.sh:
--------------------------------------------------------------------------------
1 | # Original ControlNet paper:
2 | # "In the training process, we randomly replace 50% text prompts ct with empty strings.
3 | # This approach increases ControlNet’s ability to directly recognize semantics
4 | # in the input conditioning images (e.g., edges, poses, depth, etc.) as a replacement for the prompt."
5 | # https://civitai.com/articles/2078/play-in-control-controlnet-training-setup-guide
6 | 
7 | export MODEL_DIR="stabilityai/stable-diffusion-xl-base-1.0"
8 | export REPO="ByteDance/SDXL-Lightning"
9 | export INFERENCE_STEP=8
10 | export CKPT="sdxl_lightning_8step_unet.safetensors" # caution!!! the checkpoint's "N"-step variant must match INFERENCE_STEP
11 | export OUTPUT_DIR="test"
12 | export PROJECT_NAME="train_sdxl_light_controlnet"
13 | export DATASET="nickpai/coco2017-colorization"
14 | export REVISION="custom-caption" # option: main/caption-free/custom-caption
15 | export VAL_IMG_NAME="'./000000295478.jpg' './000000122962.jpg' './000000000285.jpg' './000000007991.jpg' './000000018837.jpg' './000000000724.jpg'"
16 | export VAL_PROMPT="'Woman walking a small dog behind her.' 'A group of children sitting at a long table eating pizza.' 'A close up picture of a bear face.' 'A plate on a table is filled with carrots and beans.' 'A large truck on a city street with two works sitting on top and one worker climbing in through door.' 'An upside down stop sign by the road.'"
17 | # export VAL_PROMPT="'Colorize this image as if it was taken with a color camera' 'Colorize this image' 'Add colors to this image' 'Make this image colorful' 'Colorize this grayscale image' 'Add colors to this image'"
18 | 
19 | accelerate launch train_controlnet_sdxl_light.py \
20 | --pretrained_model_name_or_path=$MODEL_DIR \
21 | --output_dir=$OUTPUT_DIR \
22 | --tracker_project_name=$PROJECT_NAME \
23 | --seed=123123 \
24 | --dataset_name=$DATASET \
25 | --dataset_revision=$REVISION \
26 | --image_column="file_name" \
27 | --conditioning_image_column="file_name" \
28 | --caption_column="captions" \
29 | --max_train_samples=100000 \
30 | --num_validation_images=1 \
31 | --resolution=512 \
32 | --num_train_epochs=5 \
33 | --dataloader_num_workers=8 \
34 | --learning_rate=1e-5 \
35 | --train_batch_size=2 \
36 | --gradient_accumulation_steps=8 \
37 | --proportion_empty_prompts=0 \
38 | --validation_steps=500 \
39 | --checkpointing_steps=2500 \
40 | --mixed_precision="fp16" \
41 | --gradient_checkpointing \
42 | --use_8bit_adam \
43 | --repo=$REPO \
44 | --ckpt=$CKPT \
45 | --num_inference_steps=$INFERENCE_STEP \
46 | --enable_xformers_memory_efficient_attention
47 | 
48 | # --validation_image './000000295478.jpg' './000000122962.jpg' './000000000285.jpg' './000000007991.jpg' './000000018837.jpg' './000000000724.jpg' \
49 | # --validation_prompt 'Woman walking a small dog behind her.' 'A group of children sitting at a long table eating pizza.' 'A close up picture of a bear face.' 'A plate on a table is filled with carrots and beans.' 'A large truck on a city street with two works sitting on top and one worker climbing in through door.' 'An upside down stop sign by the road.' \
50 | 
--------------------------------------------------------------------------------
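For context on the `$REPO` / `$CKPT` / `$INFERENCE_STEP` variables and the caution about matching step counts: SDXL-Lightning ships a separate distilled UNet per step count, and an "Nstep" checkpoint is meant to be run with exactly N inference steps. Below is a minimal sketch of the usual SDXL-Lightning loading recipe (plain text-to-image, without the ControlNet trained here), adapted from ByteDance's published usage; treat it as an assumption about how `train_controlnet_sdxl_light.py` consumes these variables rather than a copy of its code.

```python
import torch
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

base = "stabilityai/stable-diffusion-xl-base-1.0"  # $MODEL_DIR
repo = "ByteDance/SDXL-Lightning"                  # $REPO
ckpt = "sdxl_lightning_8step_unet.safetensors"     # $CKPT (the 8-step file pairs with 8 inference steps)

# Swap the base UNet for the distilled Lightning UNet.
config = UNet2DConditionModel.load_config(base, subfolder="unet")
unet = UNet2DConditionModel.from_config(config).to("cuda", torch.float16)
unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))

pipe = StableDiffusionXLPipeline.from_pretrained(
    base, unet=unet, torch_dtype=torch.float16, variant="fp16"
).to("cuda")
# Lightning checkpoints expect "trailing" timestep spacing.
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")

image = pipe("A close up picture of a bear face.", num_inference_steps=8, guidance_scale=0).images[0]
image.save("lightning_sample.png")
```

Running a checkpoint with a different step count than it was distilled for (e.g. the 8-step UNet at 4 steps) typically degrades quality, which is what the `caution!!!` comment in the script warns about.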