├── .gitignore ├── LICENSE ├── README.md ├── assets ├── README.md ├── cat_statue │ ├── 1.jpeg │ ├── 2.jpeg │ ├── 3.jpeg │ ├── 4.jpeg │ ├── 6.jpeg │ └── 7.jpeg ├── mug_skulls │ ├── 1.jpeg │ ├── 2.jpeg │ ├── 3.jpeg │ └── 4.jpeg ├── outputs │ ├── ti.png │ └── xti_v1.png └── paper.png ├── inference.py ├── prompt_plus ├── __init__.py ├── prompt_plus_pipeline_stable_diffusion.py └── prompt_plus_unet_2d_condition.py ├── requirements.txt ├── scripts ├── app.py └── textual_inversion.py └── train_p_plus.py /.gitignore: -------------------------------------------------------------------------------- 1 | # build artifacts 2 | 3 | .eggs/ 4 | .mypy_cache 5 | *.egg-info/ 6 | build/ 7 | dist/ 8 | pip-wheel-metadata/ 9 | 10 | 11 | # dev tools 12 | 13 | .envrc 14 | .python-version 15 | .idea 16 | .venv/ 17 | .vscode/ 18 | /*.iml 19 | 20 | 21 | # jupyter notebooks 22 | 23 | .ipynb_checkpoints 24 | 25 | 26 | # miscellaneous 27 | 28 | .cache/ 29 | doc/_build/ 30 | *.swp 31 | .DS_Store 32 | 33 | 34 | # python 35 | 36 | *.pyc 37 | *.pyo 38 | __pycache__ 39 | 40 | 41 | # testing and continuous integration 42 | 43 | .coverage 44 | .pytest_cache/ 45 | .benchmarks 46 | 47 | # custom 48 | *.ipynb 49 | data 50 | private 51 | wandb 52 | models 53 | *.sh 54 | xti_cat 55 | grid.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Makoto Shing 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # prompt-plus-pytorch 2 | Open In Colab 3 | 4 | An implementation of [P+: Extended Textual Conditioning in Text-to-Image Generation](https://prompt-plus.github.io/) using d🧨ffusers. 5 | 6 | My summary can be found [here](https://twitter.com/mk1stats/status/1637785231729262592). 7 | 8 | ![paper.png](https://prompt-plus.github.io/files/inversion_examples.jpeg) 9 | 10 | ## Current Status 11 | I still can't get better results than Textual Inversion. 12 | The hyper-parameters are exactly the same as for Textual Inversion, except for the number of training steps, as described in section 4.2.2 of the paper.
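To make the comparison concrete, here is a minimal, hypothetical sketch of what is optimized in each case: plain Textual Inversion learns a single token embedding that every cross-attention layer sees, while P+ (Extended Textual Inversion) learns one embedding per cross-attention layer; this repository currently optimizes all of them jointly. The layer count, embedding size, and learning rate below are illustrative assumptions, not values read from this code.

```python
import torch

# Illustrative sizes only: SD v1.x text embeddings are 768-dimensional and the
# UNet is assumed here to have 16 cross-attention layers.
n_layers, embed_dim = 16, 768

# Textual Inversion: one learned embedding shared by every cross-attention layer.
ti_embed = torch.nn.Parameter(torch.randn(embed_dim) * 1e-2)

# P+ / Extended Textual Inversion: one learned embedding per cross-attention layer.
xti_embeds = torch.nn.ParameterList(
    [torch.nn.Parameter(torch.randn(embed_dim) * 1e-2) for _ in range(n_layers)]
)

# Joint optimization of all per-layer embeddings (optimizing them independently,
# as the paper describes, would instead run one optimization per layer).
optimizer = torch.optim.AdamW(xti_embeds.parameters(), lr=5e-3)
```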
13 | 14 | **Textual inversion:** 15 | ![ti](assets/outputs/ti.png) 16 | **Extended Textual Inversion:** 17 | ![xti](assets/outputs/xti_v1.png) 18 | 19 | Does this mean we need n_layers x 500 training steps in total? My current implementation trains all embeddings jointly, whereas the paper states: 20 | > This optimization is applied independently to each cross-attention layer. 21 | 22 | ## Installation 23 | ```commandline 24 | git clone https://github.com/mkshing/prompt-plus-pytorch 25 | pip install -r requirements.txt 26 | ``` 27 | 28 | ## Training 29 | ```commandline 30 | accelerate launch train_p_plus.py \ 31 | --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" \ 32 | --train_data_dir="assets/cat_statue" \ 33 | --learnable_property="object" \ 34 | --placeholder_token="" --initializer_token="toy" \ 35 | --resolution=512 \ 36 | --train_batch_size=1 \ 37 | --gradient_accumulation_steps=8 \ 38 | --max_train_steps=500 \ 39 | --learning_rate=5.0e-03 \ 40 | --lr_scheduler="constant" \ 41 | --lr_warmup_steps=0 \ 42 | --output_dir="xti_cat" \ 43 | --report_to "wandb" \ 44 | --only_save_embeds \ 45 | --enable_xformers_memory_efficient_attention 46 | ``` 47 | 48 | ## Inference 49 | 50 | ```python 51 | from prompt_plus import PPlusStableDiffusionPipeline 52 | 53 | pipe = PPlusStableDiffusionPipeline.from_learned_embed( 54 | pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4", 55 | learned_embed_name_or_path="learned-embed.bin path" 56 | ) 57 | prompt = "A backpack" 58 | image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0] 59 | image.save("cat-backpack.png") 60 | ``` 61 | The paper also proposes "Style Mixing" to combine two learned embeddings. 62 | ```python 63 | pipe = PPlusStableDiffusionPipeline.from_learned_embed( 64 | pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4", 65 | learned_embed_name_or_path=["learned-embed 1", "learned-embed 2"], 66 | style_mixing_k_K=(5, 10), 67 | ) 68 | ``` 69 | I also made a pipeline for plain Textual Inversion to make comparisons easy.
70 | ```python 71 | from prompt_plus import TextualInversionStableDiffusionPipeline 72 | 73 | pipe = TextualInversionStableDiffusionPipeline.from_learned_embed( 74 | pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4", 75 | learned_embed_name_or_path="sd-concepts-library/cat-toy", 76 | ) 77 | prompt = "A backpack" 78 | images = pipe(prompt, num_inference_steps=50, guidance_scale=7.5) 79 | ``` 80 | 81 | If you want to run inference from the command line: 82 | ```commandline 83 | python inference.py \ 84 | --pretrained_model_name_or_path "CompVis/stable-diffusion-v1-4" \ 85 | --learned_embed_name_or_path "xti_cat" \ 86 | --prompt "A backpack" \ 87 | --float16 \ 88 | --seed 1000 89 | ``` 90 | ## Citation 91 | 92 | ```bibtex 93 | @article{voynov2023P+, 94 | title={P+: Extended Textual Conditioning in Text-to-Image Generation}, 95 | author={Voynov, Andrey and Chu, Qinghao and Cohen-Or, Daniel and Aberman, Kfir}, 96 | booktitle={arXiv preprint}, 97 | year={2023}, 98 | url={https://arxiv.org/abs/2303.09522} 99 | } 100 | ``` 101 | 102 | ## Reference 103 | - [diffusers Textual Inversion code](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion) 104 | 105 | ## TODO 106 | - [x] Training 107 | - [x] Inference 108 | - [x] Style Mixing 109 | - [ ] Regularization -------------------------------------------------------------------------------- /assets/README.md: -------------------------------------------------------------------------------- 1 | `cat_statue` and `mug_skulls` are taken from the [original Textual Inversion repository](https://github.com/rinongal/textual_inversion#pretrained-models--data) 2 | and resized to 512x512 with the following code. 3 | ```python 4 | import os 5 | from PIL import Image 6 | 7 | image_dir = "image-path" 8 | for file_path in os.listdir(image_dir): 9 | image_path = os.path.join(image_dir, file_path) 10 | Image.open(image_path).resize((512, 512)).save(image_path) 11 | ``` -------------------------------------------------------------------------------- /assets/cat_statue/1.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkshing/prompt-plus-pytorch/7862c3f9b99ef2258907e1126c08967dd05ed91a/assets/cat_statue/1.jpeg -------------------------------------------------------------------------------- /assets/cat_statue/2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkshing/prompt-plus-pytorch/7862c3f9b99ef2258907e1126c08967dd05ed91a/assets/cat_statue/2.jpeg -------------------------------------------------------------------------------- /assets/cat_statue/3.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkshing/prompt-plus-pytorch/7862c3f9b99ef2258907e1126c08967dd05ed91a/assets/cat_statue/3.jpeg -------------------------------------------------------------------------------- /assets/cat_statue/4.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkshing/prompt-plus-pytorch/7862c3f9b99ef2258907e1126c08967dd05ed91a/assets/cat_statue/4.jpeg -------------------------------------------------------------------------------- /assets/cat_statue/6.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkshing/prompt-plus-pytorch/7862c3f9b99ef2258907e1126c08967dd05ed91a/assets/cat_statue/6.jpeg
-------------------------------------------------------------------------------- /assets/cat_statue/7.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkshing/prompt-plus-pytorch/7862c3f9b99ef2258907e1126c08967dd05ed91a/assets/cat_statue/7.jpeg -------------------------------------------------------------------------------- /assets/mug_skulls/1.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkshing/prompt-plus-pytorch/7862c3f9b99ef2258907e1126c08967dd05ed91a/assets/mug_skulls/1.jpeg -------------------------------------------------------------------------------- /assets/mug_skulls/2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkshing/prompt-plus-pytorch/7862c3f9b99ef2258907e1126c08967dd05ed91a/assets/mug_skulls/2.jpeg -------------------------------------------------------------------------------- /assets/mug_skulls/3.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkshing/prompt-plus-pytorch/7862c3f9b99ef2258907e1126c08967dd05ed91a/assets/mug_skulls/3.jpeg -------------------------------------------------------------------------------- /assets/mug_skulls/4.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkshing/prompt-plus-pytorch/7862c3f9b99ef2258907e1126c08967dd05ed91a/assets/mug_skulls/4.jpeg -------------------------------------------------------------------------------- /assets/outputs/ti.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkshing/prompt-plus-pytorch/7862c3f9b99ef2258907e1126c08967dd05ed91a/assets/outputs/ti.png -------------------------------------------------------------------------------- /assets/outputs/xti_v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkshing/prompt-plus-pytorch/7862c3f9b99ef2258907e1126c08967dd05ed91a/assets/outputs/xti_v1.png -------------------------------------------------------------------------------- /assets/paper.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkshing/prompt-plus-pytorch/7862c3f9b99ef2258907e1126c08967dd05ed91a/assets/paper.png -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from PIL import Image 3 | import torch 4 | from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline 5 | from diffusers.utils import is_xformers_available 6 | from prompt_plus import TextualInversionStableDiffusionPipeline, PPlusStableDiffusionPipeline 7 | 8 | 9 | def image_grid(imgs, rows, cols): 10 | assert len(imgs) == rows * cols 11 | w, h = imgs[0].size 12 | grid = Image.new('RGB', size=(cols * w, rows * h)) 13 | for i, img in enumerate(imgs): 14 | grid.paste(img, box=(i % cols * w, i // cols * h)) 15 | return grid 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("--pretrained_model_name_or_path", type=str, help="model name or path", default="runwayml/stable-diffusion-v1-5") 21 | parser.add_argument("--learned_embed_name_or_path", type=str, help="model 
path for learned embedding") 22 | parser.add_argument("--is_textual_inversion", action="store_true", help="Load textual inversion embeds") 23 | parser.add_argument("--original_pipe", action="store_true", help="load standard pipeline") 24 | parser.add_argument("--device", type=str, help="Device on which Stable Diffusion will be run", choices=["cpu", "cuda"], default=None) 25 | parser.add_argument("--float16", action="store_true", help="load float16") 26 | # diffusers config 27 | parser.add_argument("--prompt", type=str, nargs="?", default="a photo of *s", help="the prompt to render") 28 | parser.add_argument("--num_inference_steps", type=int, default=30, help="number of ddim sampling steps") 29 | parser.add_argument("--guidance_scale", type=float, default=7.5, help="unconditional guidance scale") 30 | parser.add_argument("--num_images_per_prompt", type=int, default=3, help="number of images per prompt") 31 | parser.add_argument("--height", type=int, default=512, help="image height, in pixel space",) 32 | parser.add_argument("--width", type=int, default=512, help="image width, in pixel space",) 33 | parser.add_argument("--seed", type=int, default=None, help="the seed (for reproducible sampling)") 34 | opt = parser.parse_args() 35 | return opt 36 | 37 | 38 | def main(): 39 | args = parse_args() 40 | if args.device is None: 41 | args.device = "cuda" if torch.cuda.is_available() else "cpu" 42 | print(f"device: {args.device}") 43 | 44 | # load model 45 | if args.is_textual_inversion or not args.original_pipe: 46 | if args.is_textual_inversion: 47 | Pipeline = TextualInversionStableDiffusionPipeline 48 | else: 49 | Pipeline = PPlusStableDiffusionPipeline 50 | pipe = Pipeline.from_learned_embed( 51 | pretrained_model_name_or_path=args.pretrained_model_name_or_path, 52 | learned_embed_name_or_path=args.learned_embed_name_or_path, 53 | torch_dtype=torch.float16 if args.float16 else None, 54 | ).to(args.device) 55 | else: 56 | print("loading the original pipeline") 57 | pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, torch_dtype=torch.float16 if args.float16 else None).to(args.device) 58 | pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") 59 | if is_xformers_available(): 60 | pipe.enable_xformers_memory_efficient_attention() 61 | print("loaded pipeline") 62 | # run! 
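    # A torch.Generator seeded from --seed (when given) makes sampling reproducible;
    # the pipeline then renders `num_images_per_prompt` images for the prompt, and the
    # results are tiled into a single 1 x N grid saved as grid.png.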
63 | generator = None 64 | if args.seed: 65 | print(f"Using seed: {args.seed}") 66 | generator = torch.Generator(device=args.device).manual_seed(args.seed) 67 | images = pipe( 68 | args.prompt, 69 | num_inference_steps=args.num_inference_steps, 70 | guidance_scale=args.guidance_scale, 71 | generator=generator, 72 | num_images_per_prompt=args.num_images_per_prompt, 73 | height=args.height, 74 | width=args.width 75 | ).images 76 | grid_image = image_grid(images, 1, args.num_images_per_prompt) 77 | grid_image.save("grid.png") 78 | print("DONE!") 79 | 80 | 81 | if __name__ == '__main__': 82 | main() 83 | 84 | -------------------------------------------------------------------------------- /prompt_plus/__init__.py: -------------------------------------------------------------------------------- 1 | from .prompt_plus_unet_2d_condition import PPlusUNet2DConditionModel 2 | from .prompt_plus_pipeline_stable_diffusion import PPlusStableDiffusionPipeline, TextualInversionStableDiffusionPipeline 3 | -------------------------------------------------------------------------------- /prompt_plus/prompt_plus_pipeline_stable_diffusion.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from typing import Optional, List, Union, Callable, Dict, Any, Tuple 4 | import torch 5 | from transformers import CLIPTextModel, CLIPTokenizer 6 | from diffusers import StableDiffusionPipeline 7 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput 8 | from diffusers.utils import logging 9 | from huggingface_hub import hf_hub_download 10 | from prompt_plus.prompt_plus_unet_2d_condition import PPlusUNet2DConditionModel 11 | 12 | 13 | logger = logging.get_logger(__name__) 14 | 15 | 16 | class TextualInversionStableDiffusionPipeline(StableDiffusionPipeline): 17 | @classmethod 18 | def from_learned_embed( 19 | cls, 20 | pretrained_model_name_or_path: Union[str, os.PathLike], 21 | learned_embed_name_or_path: Union[str, os.PathLike], 22 | **kwargs 23 | ): 24 | if os.path.exists(learned_embed_name_or_path): 25 | embeds_path = os.path.join(learned_embed_name_or_path, "learned_embeds.bin") if os.path.isdir(learned_embed_name_or_path) else learned_embed_name_or_path 26 | # token_path = os.path.join(model_dir, "token_identifier.txt") 27 | else: 28 | # download 29 | embeds_path = hf_hub_download(repo_id=learned_embed_name_or_path, filename="learned_embeds.bin") 30 | # token_path = hf_hub_download(repo_id=learned_embed_name_or_path, filename="token_identifier.txt") 31 | 32 | text_encoder = CLIPTextModel.from_pretrained( 33 | pretrained_model_name_or_path, subfolder="text_encoder", **kwargs 34 | ) 35 | tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer", **kwargs) 36 | loaded_learned_embeds = torch.load(embeds_path, map_location="cpu") 37 | # separate token and the embeds 38 | trained_token = list(loaded_learned_embeds.keys())[0] 39 | embeds = loaded_learned_embeds[trained_token] 40 | 41 | # cast to dtype of text_encoder 42 | dtype = text_encoder.get_input_embeddings().weight.dtype 43 | embeds.to(dtype) 44 | 45 | # add the token in tokenizer 46 | # token = token if token is not None else trained_token 47 | num_added_tokens = tokenizer.add_tokens(trained_token) 48 | if num_added_tokens == 0: 49 | raise ValueError( 50 | f"The tokenizer already contains the token {trained_token}. 
Please pass a different `token` that is not already in the tokenizer.") 51 | 52 | # resize the token embeddings 53 | text_encoder.resize_token_embeddings(len(tokenizer)) 54 | 55 | # get the id for the token and assign the embeds 56 | token_id = tokenizer.convert_tokens_to_ids(trained_token) 57 | text_encoder.get_input_embeddings().weight.data[token_id] = embeds 58 | print(f"placeholder_token: {trained_token}") 59 | return super().from_pretrained( 60 | pretrained_model_name_or_path=pretrained_model_name_or_path, 61 | text_encoder=text_encoder, 62 | tokenizer=tokenizer, 63 | **kwargs 64 | ) 65 | 66 | 67 | def _load_embed_from_name_or_path(learned_embed_name_or_path): 68 | if os.path.exists(learned_embed_name_or_path): 69 | embeds_path = os.path.join(learned_embed_name_or_path, "learned_embeds.bin") if os.path.isdir( 70 | learned_embed_name_or_path) else learned_embed_name_or_path 71 | # config_path = os.path.join(model_dir, "config.json") 72 | else: 73 | # download 74 | embeds_path = hf_hub_download(repo_id=learned_embed_name_or_path, filename="learned_embeds.bin") 75 | # config_path = hf_hub_download(repo_id=pretrained_model_name_or_path, filename="config.json") 76 | # with open(config_path, "r", encoding="utf-8") as f: 77 | # config = json.load(f) 78 | # load 79 | loaded_learned_embeds = torch.load(embeds_path, map_location="cpu") 80 | return loaded_learned_embeds 81 | 82 | 83 | def load_embed_from_name_or_path(learned_embed_name_or_path, style_mixing_k_K=None): 84 | if isinstance(learned_embed_name_or_path, str): 85 | assert style_mixing_k_K is None, "You inputted only one learned embed but `style_mixing_k_K` was specified!" 86 | return _load_embed_from_name_or_path(learned_embed_name_or_path) 87 | else: 88 | assert len(learned_embed_name_or_path) == 2, "Only 2 embeds are supported for now but it's especially possible." 89 | k, K = style_mixing_k_K 90 | embeds = [] 91 | for p in learned_embed_name_or_path: 92 | embeds.append(_load_embed_from_name_or_path(p)) 93 | # use first embeds tokens to align 94 | tokens = list(embeds[0].keys()) 95 | n = len(tokens) 96 | assert k < n, f"k must be lower than n={n}" 97 | assert K < n, f"K must be lower than n={n}" 98 | loaded_learned_embeds = dict() 99 | for i in range(n): 100 | if i <= k or K > i: 101 | embed_idx = 0 102 | else: 103 | embed_idx = 1 104 | embed = list(embeds[embed_idx].values())[i] 105 | loaded_learned_embeds[tokens[i]] = embed 106 | return loaded_learned_embeds 107 | 108 | 109 | class PPlusStableDiffusionPipeline(StableDiffusionPipeline): 110 | @classmethod 111 | def from_learned_embed( 112 | cls, 113 | pretrained_model_name_or_path: Union[str, os.PathLike], 114 | learned_embed_name_or_path: Optional[Union[str, os.PathLike, List[str]]] = None, 115 | style_mixing_k_K: Optional[Tuple[int]] = None, 116 | loaded_learned_embeds: Optional[Dict[str, torch.Tensor]] = None, 117 | **kwargs, 118 | ): 119 | text_encoder = CLIPTextModel.from_pretrained( 120 | pretrained_model_name_or_path, subfolder="text_encoder", **kwargs 121 | ) 122 | tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer", **kwargs) 123 | if loaded_learned_embeds is None: 124 | loaded_learned_embeds = load_embed_from_name_or_path(learned_embed_name_or_path, style_mixing_k_K) 125 | new_tokens = list(loaded_learned_embeds.keys()) 126 | # easy validation for textual inversion 127 | assert len(new_tokens) > 1, "You might want to load textual inversion pipeline!" 
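        # A P+ checkpoint stores one learned embedding per cross-attention layer, keyed by
        # per-layer placeholder tokens (e.g. "<placeholder>-0", "<placeholder>-1", ...),
        # whereas a plain Textual Inversion checkpoint holds a single entry, hence the
        # check above. Each per-layer token is added to the tokenizer below and its vector
        # is written into the resized input-embedding matrix of the text encoder.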
128 | # cast to dtype of text_encoder 129 | dtype = text_encoder.get_input_embeddings().weight.dtype 130 | # resize the token embeddings 131 | text_encoder.resize_token_embeddings(len(tokenizer)+len(new_tokens)) 132 | 133 | for token in new_tokens: 134 | embeds = loaded_learned_embeds[token] 135 | embeds.to(dtype) 136 | # add the token in tokenizer 137 | # token = token if token is not None else trained_token 138 | num_added_tokens = tokenizer.add_tokens(token) 139 | if num_added_tokens == 0: 140 | raise ValueError( 141 | f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer.") 142 | # get the id for the token and assign the embeds 143 | token_id = tokenizer.convert_tokens_to_ids(token) 144 | text_encoder.get_input_embeddings().weight.data[token_id] = loaded_learned_embeds[token] 145 | # store placeholder_token to text_encoder config 146 | text_encoder.config.placeholder_token = "-".join(new_tokens[0].split("-")[:-1]) 147 | text_encoder.config.placeholder_tokens = new_tokens 148 | print(f"placeholder_token: {text_encoder.config.placeholder_token}") 149 | unet = PPlusUNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet", **kwargs) 150 | return super().from_pretrained( 151 | pretrained_model_name_or_path=pretrained_model_name_or_path, 152 | unet=unet, 153 | text_encoder=text_encoder, 154 | tokenizer=tokenizer, 155 | **kwargs 156 | ) 157 | 158 | def _encode_prompt( 159 | self, 160 | prompt, 161 | device, 162 | num_images_per_prompt, 163 | do_classifier_free_guidance, 164 | negative_prompt=None, 165 | prompt_embeds: Optional[torch.FloatTensor] = None, 166 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 167 | ): 168 | assert isinstance(prompt, str), "Currently, only string `prompt` is supported!" 
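        # For every per-layer placeholder token, the base placeholder in the prompt is
        # replaced by that layer's token, the text is tokenized and encoded with the CLIP
        # text encoder, and (under classifier-free guidance) concatenated with the
        # unconditional embedding. The result is one conditioning tensor per
        # cross-attention layer, collected in `encoder_hidden_states_list` and consumed
        # by PPlusUNet2DConditionModel.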
169 | if prompt is not None and isinstance(prompt, str): 170 | batch_size = 1 171 | elif prompt is not None and isinstance(prompt, list): 172 | batch_size = len(prompt) 173 | else: 174 | batch_size = prompt_embeds.shape[0] 175 | 176 | if prompt_embeds is None: 177 | encoder_hidden_states_list = [] 178 | for token in self.text_encoder.config.placeholder_tokens: 179 | one_prompt = prompt.replace(self.text_encoder.config.placeholder_token, token) 180 | text_inputs = self.tokenizer( 181 | one_prompt, 182 | padding="max_length", 183 | max_length=self.tokenizer.model_max_length, 184 | truncation=True, 185 | return_tensors="pt", 186 | ) 187 | text_input_ids = text_inputs.input_ids 188 | untruncated_ids = self.tokenizer(one_prompt, padding="longest", return_tensors="pt").input_ids 189 | 190 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( 191 | text_input_ids, untruncated_ids 192 | ): 193 | removed_text = self.tokenizer.batch_decode( 194 | untruncated_ids[:, self.tokenizer.model_max_length - 1: -1] 195 | ) 196 | logger.warning( 197 | "The following part of your input was truncated because CLIP can only handle sequences up to" 198 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 199 | ) 200 | 201 | if hasattr(self.text_encoder.config, 202 | "use_attention_mask") and self.text_encoder.config.use_attention_mask: 203 | attention_mask = text_inputs.attention_mask.to(device) 204 | else: 205 | attention_mask = None 206 | 207 | prompt_embeds = self.text_encoder( 208 | text_input_ids.to(device), 209 | attention_mask=attention_mask, 210 | ) 211 | prompt_embeds = prompt_embeds[0] 212 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 213 | 214 | bs_embed, seq_len, _ = prompt_embeds.shape 215 | # duplicate text embeddings for each generation per prompt, using mps friendly method 216 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 217 | prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) 218 | 219 | # get unconditional embeddings for classifier free guidance 220 | if do_classifier_free_guidance: 221 | uncond_tokens: List[str] 222 | if negative_prompt is None: 223 | uncond_tokens = [""] * batch_size 224 | elif type(prompt) is not type(negative_prompt): 225 | raise TypeError( 226 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 227 | f" {type(prompt)}." 228 | ) 229 | elif isinstance(negative_prompt, str): 230 | uncond_tokens = [negative_prompt] 231 | elif batch_size != len(negative_prompt): 232 | raise ValueError( 233 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 234 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 235 | " the batch size of `prompt`." 
236 | ) 237 | else: 238 | uncond_tokens = negative_prompt 239 | 240 | max_length = prompt_embeds.shape[1] 241 | uncond_input = self.tokenizer( 242 | uncond_tokens, 243 | padding="max_length", 244 | max_length=max_length, 245 | truncation=True, 246 | return_tensors="pt", 247 | ) 248 | 249 | if hasattr(self.text_encoder.config, 250 | "use_attention_mask") and self.text_encoder.config.use_attention_mask: 251 | attention_mask = uncond_input.attention_mask.to(device) 252 | else: 253 | attention_mask = None 254 | 255 | negative_prompt_embeds = self.text_encoder( 256 | uncond_input.input_ids.to(device), 257 | attention_mask=attention_mask, 258 | ) 259 | negative_prompt_embeds = negative_prompt_embeds[0] 260 | 261 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 262 | seq_len = negative_prompt_embeds.shape[1] 263 | 264 | negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 265 | 266 | negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) 267 | negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, 268 | -1) 269 | 270 | # For classifier free guidance, we need to do two forward passes. 271 | # Here we concatenate the unconditional and text embeddings into a single batch 272 | # to avoid doing two forward passes 273 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) 274 | 275 | encoder_hidden_states_list.append(prompt_embeds) 276 | else: 277 | # trust you! 278 | encoder_hidden_states_list = prompt_embeds 279 | return encoder_hidden_states_list 280 | 281 | @torch.no_grad() 282 | def __call__( 283 | self, 284 | prompt: Union[str, List[str]] = None, 285 | height: Optional[int] = None, 286 | width: Optional[int] = None, 287 | num_inference_steps: int = 50, 288 | guidance_scale: float = 7.5, 289 | negative_prompt: Optional[Union[str, List[str]]] = None, 290 | num_images_per_prompt: Optional[int] = 1, 291 | eta: float = 0.0, 292 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 293 | latents: Optional[torch.FloatTensor] = None, 294 | prompt_embeds: Optional[torch.FloatTensor] = None, 295 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 296 | output_type: Optional[str] = "pil", 297 | return_dict: bool = True, 298 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 299 | callback_steps: int = 1, 300 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 301 | ): 302 | # 0. Default height and width to unet 303 | height = height or self.unet.config.sample_size * self.vae_scale_factor 304 | width = width or self.unet.config.sample_size * self.vae_scale_factor 305 | 306 | # 1. Check inputs. Raise error if not correct 307 | self.check_inputs( 308 | prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds 309 | ) 310 | 311 | # 2. Define call parameters 312 | if prompt is not None and isinstance(prompt, str): 313 | batch_size = 1 314 | elif prompt is not None and isinstance(prompt, list): 315 | batch_size = len(prompt) 316 | else: 317 | batch_size = prompt_embeds.shape[0] 318 | 319 | device = self._execution_device 320 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 321 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 322 | # corresponds to doing no classifier free guidance. 323 | do_classifier_free_guidance = guidance_scale > 1.0 324 | 325 | # 3. 
Encode input prompt 326 | encoder_hidden_states_list = self._encode_prompt( 327 | prompt, 328 | device, 329 | num_images_per_prompt, 330 | do_classifier_free_guidance, 331 | negative_prompt, 332 | prompt_embeds=prompt_embeds, 333 | negative_prompt_embeds=negative_prompt_embeds, 334 | ) 335 | 336 | # 4. Prepare timesteps 337 | self.scheduler.set_timesteps(num_inference_steps, device=device) 338 | timesteps = self.scheduler.timesteps 339 | 340 | # 5. Prepare latent variables 341 | num_channels_latents = self.unet.in_channels 342 | latents = self.prepare_latents( 343 | batch_size * num_images_per_prompt, 344 | num_channels_latents, 345 | height, 346 | width, 347 | encoder_hidden_states_list[0].dtype, 348 | device, 349 | generator, 350 | latents, 351 | ) 352 | 353 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline 354 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 355 | 356 | # 7. Denoising loop 357 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 358 | with self.progress_bar(total=num_inference_steps) as progress_bar: 359 | for i, t in enumerate(timesteps): 360 | # expand the latents if we are doing classifier free guidance 361 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 362 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 363 | 364 | # predict the noise residual 365 | noise_pred = self.unet( 366 | latent_model_input, 367 | t, 368 | encoder_hidden_states_list=encoder_hidden_states_list, 369 | cross_attention_kwargs=cross_attention_kwargs, 370 | ).sample 371 | 372 | # perform guidance 373 | if do_classifier_free_guidance: 374 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 375 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 376 | 377 | # compute the previous noisy sample x_t -> x_t-1 378 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample 379 | 380 | # call the callback, if provided 381 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 382 | progress_bar.update() 383 | if callback is not None and i % callback_steps == 0: 384 | callback(i, t, latents) 385 | 386 | if output_type == "latent": 387 | image = latents 388 | has_nsfw_concept = None 389 | elif output_type == "pil": 390 | # 8. Post-processing 391 | image = self.decode_latents(latents) 392 | 393 | # 9. Run safety checker 394 | image, has_nsfw_concept = self.run_safety_checker(image, device, encoder_hidden_states_list[0].dtype) 395 | 396 | # 10. Convert to PIL 397 | image = self.numpy_to_pil(image) 398 | else: 399 | # 8. Post-processing 400 | image = self.decode_latents(latents) 401 | 402 | # 9. 
Run safety checker 403 | image, has_nsfw_concept = self.run_safety_checker(image, device, encoder_hidden_states_list[0].dtype) 404 | 405 | # Offload last model to CPU 406 | if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: 407 | self.final_offload_hook.offload() 408 | 409 | if not return_dict: 410 | return (image, has_nsfw_concept) 411 | 412 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) 413 | 414 | -------------------------------------------------------------------------------- /prompt_plus/prompt_plus_unet_2d_condition.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, Any, Dict, Tuple, List 2 | import torch 3 | from diffusers import UNet2DConditionModel 4 | from diffusers.models.unet_2d_condition import UNet2DConditionOutput 5 | from diffusers.utils import logging 6 | 7 | 8 | logger = logging.get_logger(__name__) 9 | 10 | 11 | class PPlusUNet2DConditionModel(UNet2DConditionModel): 12 | def forward( 13 | self, 14 | sample: torch.FloatTensor, 15 | timestep: Union[torch.Tensor, float, int], 16 | encoder_hidden_states: torch.Tensor = None, 17 | ######################################### 18 | encoder_hidden_states_list: List[torch.Tensor] = None, 19 | ######################################### 20 | class_labels: Optional[torch.Tensor] = None, 21 | timestep_cond: Optional[torch.Tensor] = None, 22 | attention_mask: Optional[torch.Tensor] = None, 23 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 24 | down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, 25 | mid_block_additional_residual: Optional[torch.Tensor] = None, 26 | return_dict: bool = True, 27 | ): 28 | if encoder_hidden_states is None and encoder_hidden_states_list is None: 29 | raise ValueError("You must input either `encoder_hidden_states` or `encoder_hidden_states_list`!") 30 | if encoder_hidden_states_list is not None: 31 | select_idx = 0 32 | 33 | # By default samples have to be AT least a multiple of the overall upsampling factor. 34 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). 35 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 36 | # on the fly if necessary. 37 | default_overall_up_factor = 2 ** self.num_upsamplers 38 | 39 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 40 | forward_upsample_size = False 41 | upsample_size = None 42 | 43 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): 44 | logger.info("Forward upsample size to force interpolation output size.") 45 | forward_upsample_size = True 46 | 47 | # prepare attention_mask 48 | if attention_mask is not None: 49 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 50 | attention_mask = attention_mask.unsqueeze(1) 51 | 52 | # 0. center input if necessary 53 | if self.config.center_input_sample: 54 | sample = 2 * sample - 1.0 55 | 56 | # 1. time 57 | timesteps = timestep 58 | if not torch.is_tensor(timesteps): 59 | # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can 60 | # This would be a good case for the `match` statement (Python 3.10+) 61 | is_mps = sample.device.type == "mps" 62 | if isinstance(timestep, float): 63 | dtype = torch.float32 if is_mps else torch.float64 64 | else: 65 | dtype = torch.int32 if is_mps else torch.int64 66 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 67 | elif len(timesteps.shape) == 0: 68 | timesteps = timesteps[None].to(sample.device) 69 | 70 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 71 | timesteps = timesteps.expand(sample.shape[0]) 72 | 73 | t_emb = self.time_proj(timesteps) 74 | 75 | # timesteps does not contain any weights and will always return f32 tensors 76 | # but time_embedding might actually be running in fp16. so we need to cast here. 77 | # there might be better ways to encapsulate this. 78 | t_emb = t_emb.to(dtype=self.dtype) 79 | 80 | emb = self.time_embedding(t_emb, timestep_cond) 81 | 82 | if self.class_embedding is not None: 83 | if class_labels is None: 84 | raise ValueError("class_labels should be provided when num_class_embeds > 0") 85 | 86 | if self.config.class_embed_type == "timestep": 87 | class_labels = self.time_proj(class_labels) 88 | 89 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) 90 | emb = emb + class_emb 91 | 92 | # 2. pre-process 93 | sample = self.conv_in(sample) 94 | 95 | # 3. down 96 | down_block_res_samples = (sample,) 97 | for downsample_block in self.down_blocks: 98 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 99 | sample, res_samples = downsample_block( 100 | hidden_states=sample, 101 | temb=emb, 102 | encoder_hidden_states=encoder_hidden_states if encoder_hidden_states_list is None else encoder_hidden_states_list[select_idx], 103 | attention_mask=attention_mask, 104 | cross_attention_kwargs=cross_attention_kwargs, 105 | ) 106 | if encoder_hidden_states_list is not None: 107 | select_idx += 1 108 | else: 109 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 110 | 111 | down_block_res_samples += res_samples 112 | 113 | if down_block_additional_residuals is not None: 114 | new_down_block_res_samples = () 115 | 116 | for down_block_res_sample, down_block_additional_residual in zip( 117 | down_block_res_samples, down_block_additional_residuals 118 | ): 119 | down_block_res_sample = down_block_res_sample + down_block_additional_residual 120 | new_down_block_res_samples += (down_block_res_sample,) 121 | 122 | down_block_res_samples = new_down_block_res_samples 123 | 124 | # 4. mid 125 | if self.mid_block is not None: 126 | sample = self.mid_block( 127 | sample, 128 | emb, 129 | encoder_hidden_states=encoder_hidden_states if encoder_hidden_states_list is None else 130 | encoder_hidden_states_list[select_idx], 131 | attention_mask=attention_mask, 132 | cross_attention_kwargs=cross_attention_kwargs, 133 | ) 134 | if encoder_hidden_states_list is not None: 135 | select_idx += 1 136 | 137 | if mid_block_additional_residual is not None: 138 | sample = sample + mid_block_additional_residual 139 | 140 | # 5. 
up 141 | for i, upsample_block in enumerate(self.up_blocks): 142 | is_final_block = i == len(self.up_blocks) - 1 143 | 144 | res_samples = down_block_res_samples[-len(upsample_block.resnets):] 145 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 146 | 147 | # if we have not reached the final block and need to forward the 148 | # upsample size, we do it here 149 | if not is_final_block and forward_upsample_size: 150 | upsample_size = down_block_res_samples[-1].shape[2:] 151 | 152 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 153 | sample = upsample_block( 154 | hidden_states=sample, 155 | temb=emb, 156 | res_hidden_states_tuple=res_samples, 157 | encoder_hidden_states=encoder_hidden_states if encoder_hidden_states_list is None else 158 | encoder_hidden_states_list[select_idx], 159 | cross_attention_kwargs=cross_attention_kwargs, 160 | upsample_size=upsample_size, 161 | attention_mask=attention_mask, 162 | ) 163 | if encoder_hidden_states_list is not None: 164 | select_idx += 1 165 | else: 166 | sample = upsample_block( 167 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size 168 | ) 169 | 170 | # 6. post-process 171 | if self.conv_norm_out: 172 | sample = self.conv_norm_out(sample) 173 | sample = self.conv_act(sample) 174 | sample = self.conv_out(sample) 175 | 176 | if not return_dict: 177 | return (sample,) 178 | 179 | return UNet2DConditionOutput(sample=sample) 180 | 181 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers[torch] 2 | accelerate 3 | torchvision 4 | transformers>=4.25.1 5 | ftfy 6 | tensorboard 7 | Jinja2 8 | wandb 9 | natsort 10 | safetensors 11 | datasets 12 | bitsandbytes -------------------------------------------------------------------------------- /scripts/app.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from functools import lru_cache 4 | import subprocess 5 | import torch 6 | import gradio as gr 7 | from diffusers.utils import is_xformers_available 8 | 9 | 10 | device = "cuda" if torch.cuda.is_available() else "cpu" 11 | print(f"device: {device}") 12 | 13 | 14 | def gitclone(url, target_dir=None, branch_arg=None): 15 | run_args = ["git", "clone"] 16 | if branch_arg: 17 | run_args.extend(["-b", branch_arg]) 18 | run_args.append(url) 19 | if target_dir: 20 | run_args.append(target_dir) 21 | res = subprocess.run(run_args, stdout=subprocess.PIPE).stdout.decode("utf-8") 22 | print(res) 23 | 24 | 25 | def pipi(modulestr): 26 | res = subprocess.run( 27 | ["pip", "install", modulestr], stdout=subprocess.PIPE 28 | ).stdout.decode("utf-8") 29 | print(res) 30 | 31 | 32 | try: 33 | proj_dir = os.path.dirname(__file__) 34 | sys.path.append(proj_dir) 35 | from prompt_plus import PPlusStableDiffusionPipeline 36 | except ImportError: 37 | GITHUB_SECRET = os.environ.get("GITHUB_SECRET") 38 | gitclone("https://github.com/mkshing/prompt-plus-pytorch" if GITHUB_SECRET is None else f"https://{GITHUB_SECRET}@github.com/mkshing/prompt-plus-pytorch") 39 | from prompt_plus import PPlusStableDiffusionPipeline 40 | 41 | 42 | @lru_cache(maxsize=3) 43 | def load_pipe(pretrained_model_name_or_path, learned_embed_name_or_path): 44 | pipe = PPlusStableDiffusionPipeline.from_learned_embed( 45 | pretrained_model_name_or_path=pretrained_model_name_or_path, 46 | 
learned_embed_name_or_path=learned_embed_name_or_path, 47 | revision="fp16", torch_dtype=torch.float16 48 | ) 49 | if is_xformers_available(): 50 | pipe.enable_xformers_memory_efficient_attention() 51 | return pipe 52 | 53 | 54 | def txt2img_func(pretrained_model_name_or_path, learned_embed_name_or_path, prompt, n_samples=4, scale=7.5, steps=25, width=512, height=512, seed="random"): 55 | n_samples = int(n_samples) 56 | scale = float(scale) 57 | steps = int(steps) 58 | width = int(width) 59 | height = int(height) 60 | generator = torch.Generator(device=device) 61 | if seed == "random": 62 | seed = generator.seed() 63 | else: 64 | seed = int(seed) 65 | generator = generator.manual_seed(int(seed)) 66 | pipe = load_pipe(pretrained_model_name_or_path, learned_embed_name_or_path).to(device) 67 | images = pipe( 68 | prompt, 69 | num_inference_steps=steps, 70 | guidance_scale=scale, 71 | generator=generator, 72 | num_images_per_prompt=n_samples, 73 | height=height, 74 | width=width 75 | ).images 76 | return images 77 | 78 | 79 | with gr.Blocks() as demo: 80 | gr.Markdown("# P+: Extended Textual Conditioning in Text-to-Image Generation") 81 | pretrained_model_name_or_path = gr.Textbox(label="pre-trained model name or path", value="runwayml/stable-diffusion-v1-5") 82 | learned_embed_name_or_path = gr.Textbox(label="learned embedding name or path") 83 | with gr.Row(): 84 | with gr.Column(): 85 | # input 86 | prompt = gr.Textbox(label="Prompt") 87 | n_samples = gr.Number(value=3, label="n_samples") 88 | cfg_scale = gr.Slider(minimum=0.0, maximum=20, value=7.5, label="cfg_scale", step=0.5) 89 | steps = gr.Number(value=30, label="steps") 90 | width = gr.Slider(minimum=128, maximum=1024, value=512, label="width", step=64) 91 | height = gr.Slider(minimum=128, maximum=1024, value=512, label="height", step=64) 92 | seed = gr.Textbox(value='random', 93 | placeholder="If you fix seed, you get same outputs all the time. You can set as integer like 42.", 94 | label="seed") 95 | 96 | # button 97 | button = gr.Button(value="Generate!") 98 | with gr.Column(): 99 | # output 100 | out_images = gr.Gallery(label="Output") 101 | button.click( 102 | txt2img_func, 103 | inputs=[pretrained_model_name_or_path, learned_embed_name_or_path, prompt, n_samples, cfg_scale, steps, width, height, seed], 104 | outputs=[out_images], 105 | api_name="txt2img" 106 | ) 107 | 108 | demo.launch() 109 | -------------------------------------------------------------------------------- /scripts/textual_inversion.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | 16 | import argparse 17 | import logging 18 | import math 19 | import os 20 | import random 21 | import warnings 22 | from pathlib import Path 23 | from typing import Optional 24 | 25 | import numpy as np 26 | import PIL 27 | import torch 28 | import torch.nn.functional as F 29 | import torch.utils.checkpoint 30 | import transformers 31 | from accelerate import Accelerator 32 | from accelerate.logging import get_logger 33 | from accelerate.utils import ProjectConfiguration, set_seed 34 | from huggingface_hub import HfFolder, Repository, create_repo, whoami 35 | 36 | # TODO: remove and import from diffusers.utils when the new version of diffusers is released 37 | from packaging import version 38 | from PIL import Image 39 | from torch.utils.data import Dataset 40 | from torchvision import transforms 41 | from tqdm.auto import tqdm 42 | from transformers import CLIPTextModel, CLIPTokenizer 43 | 44 | import diffusers 45 | from diffusers import ( 46 | AutoencoderKL, 47 | DDPMScheduler, 48 | DiffusionPipeline, 49 | DPMSolverMultistepScheduler, 50 | StableDiffusionPipeline, 51 | UNet2DConditionModel, 52 | ) 53 | from diffusers.optimization import get_scheduler 54 | from diffusers.utils import check_min_version, is_wandb_available 55 | from diffusers.utils.import_utils import is_xformers_available 56 | 57 | 58 | if is_wandb_available(): 59 | import wandb 60 | 61 | if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): 62 | PIL_INTERPOLATION = { 63 | "linear": PIL.Image.Resampling.BILINEAR, 64 | "bilinear": PIL.Image.Resampling.BILINEAR, 65 | "bicubic": PIL.Image.Resampling.BICUBIC, 66 | "lanczos": PIL.Image.Resampling.LANCZOS, 67 | "nearest": PIL.Image.Resampling.NEAREST, 68 | } 69 | else: 70 | PIL_INTERPOLATION = { 71 | "linear": PIL.Image.LINEAR, 72 | "bilinear": PIL.Image.BILINEAR, 73 | "bicubic": PIL.Image.BICUBIC, 74 | "lanczos": PIL.Image.LANCZOS, 75 | "nearest": PIL.Image.NEAREST, 76 | } 77 | # ------------------------------------------------------------------------------ 78 | 79 | 80 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 81 | # check_min_version("0.15.0.dev0") 82 | 83 | logger = get_logger(__name__) 84 | 85 | 86 | def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch): 87 | logger.info( 88 | f"Running validation... \n Generating {args.num_validation_images} images with prompt:" 89 | f" {args.validation_prompt}." 
90 | ) 91 | # create pipeline (note: unet and vae are loaded again in float32) 92 | pipeline = DiffusionPipeline.from_pretrained( 93 | args.pretrained_model_name_or_path, 94 | text_encoder=accelerator.unwrap_model(text_encoder), 95 | tokenizer=tokenizer, 96 | unet=unet, 97 | vae=vae, 98 | revision=args.revision, 99 | torch_dtype=weight_dtype, 100 | ) 101 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) 102 | pipeline = pipeline.to(accelerator.device) 103 | pipeline.set_progress_bar_config(disable=True) 104 | 105 | # run inference 106 | generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed) 107 | images = [] 108 | for _ in range(args.num_validation_images): 109 | with torch.autocast("cuda"): 110 | image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] 111 | images.append(image) 112 | 113 | for tracker in accelerator.trackers: 114 | if tracker.name == "tensorboard": 115 | np_images = np.stack([np.asarray(img) for img in images]) 116 | tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") 117 | if tracker.name == "wandb": 118 | tracker.log( 119 | { 120 | "validation": [ 121 | wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) 122 | ] 123 | } 124 | ) 125 | 126 | del pipeline 127 | torch.cuda.empty_cache() 128 | 129 | 130 | def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path): 131 | logger.info("Saving embeddings") 132 | learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id] 133 | learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()} 134 | torch.save(learned_embeds_dict, save_path) 135 | 136 | 137 | def parse_args(): 138 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 139 | parser.add_argument( 140 | "--save_steps", 141 | type=int, 142 | default=500, 143 | help="Save learned_embeds.bin every X updates steps.", 144 | ) 145 | parser.add_argument( 146 | "--only_save_embeds", 147 | action="store_true", 148 | default=False, 149 | help="Save only the embeddings for the new concept.", 150 | ) 151 | parser.add_argument( 152 | "--pretrained_model_name_or_path", 153 | type=str, 154 | default=None, 155 | required=True, 156 | help="Path to pretrained model or model identifier from huggingface.co/models.", 157 | ) 158 | parser.add_argument( 159 | "--revision", 160 | type=str, 161 | default=None, 162 | required=False, 163 | help="Revision of pretrained model identifier from huggingface.co/models.", 164 | ) 165 | parser.add_argument( 166 | "--tokenizer_name", 167 | type=str, 168 | default=None, 169 | help="Pretrained tokenizer name or path if not the same as model_name", 170 | ) 171 | parser.add_argument( 172 | "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data." 173 | ) 174 | parser.add_argument( 175 | "--placeholder_token", 176 | type=str, 177 | default=None, 178 | required=True, 179 | help="A token to use as a placeholder for the concept.", 180 | ) 181 | parser.add_argument( 182 | "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word." 
183 | ) 184 | parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'") 185 | parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.") 186 | parser.add_argument( 187 | "--output_dir", 188 | type=str, 189 | default="text-inversion-model", 190 | help="The output directory where the model predictions and checkpoints will be written.", 191 | ) 192 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 193 | parser.add_argument( 194 | "--resolution", 195 | type=int, 196 | default=512, 197 | help=( 198 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 199 | " resolution" 200 | ), 201 | ) 202 | parser.add_argument( 203 | "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution." 204 | ) 205 | parser.add_argument( 206 | "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." 207 | ) 208 | parser.add_argument("--num_train_epochs", type=int, default=100) 209 | parser.add_argument( 210 | "--max_train_steps", 211 | type=int, 212 | default=5000, 213 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 214 | ) 215 | parser.add_argument( 216 | "--gradient_accumulation_steps", 217 | type=int, 218 | default=1, 219 | help="Number of updates steps to accumulate before performing a backward/update pass.", 220 | ) 221 | parser.add_argument( 222 | "--gradient_checkpointing", 223 | action="store_true", 224 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 225 | ) 226 | parser.add_argument( 227 | "--learning_rate", 228 | type=float, 229 | default=1e-4, 230 | help="Initial learning rate (after the potential warmup period) to use.", 231 | ) 232 | parser.add_argument( 233 | "--scale_lr", 234 | action="store_true", 235 | default=False, 236 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 237 | ) 238 | parser.add_argument( 239 | "--lr_scheduler", 240 | type=str, 241 | default="constant", 242 | help=( 243 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 244 | ' "constant", "constant_with_warmup"]' 245 | ), 246 | ) 247 | parser.add_argument( 248 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 249 | ) 250 | parser.add_argument( 251 | "--dataloader_num_workers", 252 | type=int, 253 | default=0, 254 | help=( 255 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 
256 | ), 257 | ) 258 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 259 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 260 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 261 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 262 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 263 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 264 | parser.add_argument( 265 | "--hub_model_id", 266 | type=str, 267 | default=None, 268 | help="The name of the repository to keep in sync with the local `output_dir`.", 269 | ) 270 | parser.add_argument( 271 | "--logging_dir", 272 | type=str, 273 | default="logs", 274 | help=( 275 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 276 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 277 | ), 278 | ) 279 | parser.add_argument( 280 | "--mixed_precision", 281 | type=str, 282 | default="no", 283 | choices=["no", "fp16", "bf16"], 284 | help=( 285 | "Whether to use mixed precision. Choose" 286 | "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." 287 | "and an Nvidia Ampere GPU." 288 | ), 289 | ) 290 | parser.add_argument( 291 | "--allow_tf32", 292 | action="store_true", 293 | help=( 294 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" 295 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 296 | ), 297 | ) 298 | parser.add_argument( 299 | "--report_to", 300 | type=str, 301 | default="tensorboard", 302 | help=( 303 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 304 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 305 | ), 306 | ) 307 | parser.add_argument( 308 | "--validation_prompt", 309 | type=str, 310 | default=None, 311 | help="A prompt that is used during validation to verify that the model is learning.", 312 | ) 313 | parser.add_argument( 314 | "--num_validation_images", 315 | type=int, 316 | default=4, 317 | help="Number of images that should be generated during validation with `validation_prompt`.", 318 | ) 319 | parser.add_argument( 320 | "--validation_steps", 321 | type=int, 322 | default=100, 323 | help=( 324 | "Run validation every X steps. Validation consists of running the prompt" 325 | " `args.validation_prompt` multiple times: `args.num_validation_images`" 326 | " and logging the images." 327 | ), 328 | ) 329 | parser.add_argument( 330 | "--validation_epochs", 331 | type=int, 332 | default=None, 333 | help=( 334 | "Deprecated in favor of validation_steps. Run validation every X epochs. Validation consists of running the prompt" 335 | " `args.validation_prompt` multiple times: `args.num_validation_images`" 336 | " and logging the images." 337 | ), 338 | ) 339 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 340 | parser.add_argument( 341 | "--checkpointing_steps", 342 | type=int, 343 | default=500, 344 | help=( 345 | "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" 346 | " training using `--resume_from_checkpoint`." 
347 | ), 348 | ) 349 | parser.add_argument( 350 | "--checkpoints_total_limit", 351 | type=int, 352 | default=None, 353 | help=( 354 | "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." 355 | " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" 356 | " for more docs" 357 | ), 358 | ) 359 | parser.add_argument( 360 | "--resume_from_checkpoint", 361 | type=str, 362 | default=None, 363 | help=( 364 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 365 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 366 | ), 367 | ) 368 | parser.add_argument( 369 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 370 | ) 371 | 372 | args = parser.parse_args() 373 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 374 | if env_local_rank != -1 and env_local_rank != args.local_rank: 375 | args.local_rank = env_local_rank 376 | 377 | if args.train_data_dir is None: 378 | raise ValueError("You must specify a train data directory.") 379 | 380 | return args 381 | 382 | 383 | imagenet_templates_small = [ 384 | "a photo of a {}", 385 | "a rendering of a {}", 386 | "a cropped photo of the {}", 387 | "the photo of a {}", 388 | "a photo of a clean {}", 389 | "a photo of a dirty {}", 390 | "a dark photo of the {}", 391 | "a photo of my {}", 392 | "a photo of the cool {}", 393 | "a close-up photo of a {}", 394 | "a bright photo of the {}", 395 | "a cropped photo of a {}", 396 | "a photo of the {}", 397 | "a good photo of the {}", 398 | "a photo of one {}", 399 | "a close-up photo of the {}", 400 | "a rendition of the {}", 401 | "a photo of the clean {}", 402 | "a rendition of a {}", 403 | "a photo of a nice {}", 404 | "a good photo of a {}", 405 | "a photo of the nice {}", 406 | "a photo of the small {}", 407 | "a photo of the weird {}", 408 | "a photo of the large {}", 409 | "a photo of a cool {}", 410 | "a photo of a small {}", 411 | ] 412 | 413 | imagenet_style_templates_small = [ 414 | "a painting in the style of {}", 415 | "a rendering in the style of {}", 416 | "a cropped painting in the style of {}", 417 | "the painting in the style of {}", 418 | "a clean painting in the style of {}", 419 | "a dirty painting in the style of {}", 420 | "a dark painting in the style of {}", 421 | "a picture in the style of {}", 422 | "a cool painting in the style of {}", 423 | "a close-up painting in the style of {}", 424 | "a bright painting in the style of {}", 425 | "a cropped painting in the style of {}", 426 | "a good painting in the style of {}", 427 | "a close-up painting in the style of {}", 428 | "a rendition in the style of {}", 429 | "a nice painting in the style of {}", 430 | "a small painting in the style of {}", 431 | "a weird painting in the style of {}", 432 | "a large painting in the style of {}", 433 | ] 434 | 435 | 436 | class TextualInversionDataset(Dataset): 437 | def __init__( 438 | self, 439 | data_root, 440 | tokenizer, 441 | learnable_property="object", # [object, style] 442 | size=512, 443 | repeats=100, 444 | interpolation="bicubic", 445 | flip_p=0.5, 446 | set="train", 447 | placeholder_token="*", 448 | center_crop=False, 449 | ): 450 | self.data_root = data_root 451 | self.tokenizer = tokenizer 452 | self.learnable_property = learnable_property 453 | self.size = size 454 | self.placeholder_token = placeholder_token 455 | 
self.center_crop = center_crop 456 | self.flip_p = flip_p 457 | 458 | self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)] 459 | 460 | self.num_images = len(self.image_paths) 461 | self._length = self.num_images 462 | 463 | if set == "train": 464 | self._length = self.num_images * repeats 465 | 466 | self.interpolation = { 467 | "linear": PIL_INTERPOLATION["linear"], 468 | "bilinear": PIL_INTERPOLATION["bilinear"], 469 | "bicubic": PIL_INTERPOLATION["bicubic"], 470 | "lanczos": PIL_INTERPOLATION["lanczos"], 471 | }[interpolation] 472 | 473 | self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small 474 | self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p) 475 | 476 | def __len__(self): 477 | return self._length 478 | 479 | def __getitem__(self, i): 480 | example = {} 481 | image = Image.open(self.image_paths[i % self.num_images]) 482 | 483 | if not image.mode == "RGB": 484 | image = image.convert("RGB") 485 | 486 | placeholder_string = self.placeholder_token 487 | text = random.choice(self.templates).format(placeholder_string) 488 | 489 | example["input_ids"] = self.tokenizer( 490 | text, 491 | padding="max_length", 492 | truncation=True, 493 | max_length=self.tokenizer.model_max_length, 494 | return_tensors="pt", 495 | ).input_ids[0] 496 | 497 | # default to score-sde preprocessing 498 | img = np.array(image).astype(np.uint8) 499 | 500 | if self.center_crop: 501 | crop = min(img.shape[0], img.shape[1]) 502 | ( 503 | h, 504 | w, 505 | ) = ( 506 | img.shape[0], 507 | img.shape[1], 508 | ) 509 | img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2] 510 | 511 | image = Image.fromarray(img) 512 | image = image.resize((self.size, self.size), resample=self.interpolation) 513 | 514 | image = self.flip_transform(image) 515 | image = np.array(image).astype(np.uint8) 516 | image = (image / 127.5 - 1.0).astype(np.float32) 517 | 518 | example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1) 519 | return example 520 | 521 | 522 | def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): 523 | if token is None: 524 | token = HfFolder.get_token() 525 | if organization is None: 526 | username = whoami(token)["name"] 527 | return f"{username}/{model_id}" 528 | else: 529 | return f"{organization}/{model_id}" 530 | 531 | 532 | def main(): 533 | args = parse_args() 534 | logging_dir = os.path.join(args.output_dir, args.logging_dir) 535 | 536 | accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit) 537 | 538 | accelerator = Accelerator( 539 | gradient_accumulation_steps=args.gradient_accumulation_steps, 540 | mixed_precision=args.mixed_precision, 541 | log_with=args.report_to, 542 | logging_dir=logging_dir, 543 | project_config=accelerator_project_config, 544 | ) 545 | 546 | if args.report_to == "wandb": 547 | if not is_wandb_available(): 548 | raise ImportError("Make sure to install wandb if you want to use it for logging during training.") 549 | 550 | # Make one log on every process with the configuration for debugging. 
551 | logging.basicConfig( 552 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 553 | datefmt="%m/%d/%Y %H:%M:%S", 554 | level=logging.INFO, 555 | ) 556 | logger.info(accelerator.state, main_process_only=False) 557 | if accelerator.is_local_main_process: 558 | transformers.utils.logging.set_verbosity_warning() 559 | diffusers.utils.logging.set_verbosity_info() 560 | else: 561 | transformers.utils.logging.set_verbosity_error() 562 | diffusers.utils.logging.set_verbosity_error() 563 | 564 | # If passed along, set the training seed now. 565 | if args.seed is not None: 566 | set_seed(args.seed) 567 | 568 | # Handle the repository creation 569 | if accelerator.is_main_process: 570 | if args.push_to_hub: 571 | if args.hub_model_id is None: 572 | repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) 573 | else: 574 | repo_name = args.hub_model_id 575 | create_repo(repo_name, exist_ok=True, token=args.hub_token) 576 | repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) 577 | 578 | with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: 579 | if "step_*" not in gitignore: 580 | gitignore.write("step_*\n") 581 | if "epoch_*" not in gitignore: 582 | gitignore.write("epoch_*\n") 583 | elif args.output_dir is not None: 584 | os.makedirs(args.output_dir, exist_ok=True) 585 | 586 | # Load tokenizer 587 | if args.tokenizer_name: 588 | tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) 589 | elif args.pretrained_model_name_or_path: 590 | tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") 591 | 592 | # Load scheduler and models 593 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") 594 | text_encoder = CLIPTextModel.from_pretrained( 595 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision 596 | ) 597 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) 598 | unet = UNet2DConditionModel.from_pretrained( 599 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision 600 | ) 601 | 602 | # Add the placeholder token in tokenizer 603 | num_added_tokens = tokenizer.add_tokens(args.placeholder_token) 604 | if num_added_tokens == 0: 605 | raise ValueError( 606 | f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different" 607 | " `placeholder_token` that is not already in the tokenizer." 
608 | ) 609 | 610 | # Convert the initializer_token, placeholder_token to ids 611 | token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) 612 | # Check if initializer_token is a single token or a sequence of tokens 613 | if len(token_ids) > 1: 614 | raise ValueError("The initializer token must be a single token.") 615 | 616 | initializer_token_id = token_ids[0] 617 | placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) 618 | 619 | # Resize the token embeddings as we are adding new special tokens to the tokenizer 620 | text_encoder.resize_token_embeddings(len(tokenizer)) 621 | 622 | # Initialise the newly added placeholder token with the embeddings of the initializer token 623 | token_embeds = text_encoder.get_input_embeddings().weight.data 624 | token_embeds[placeholder_token_id] = token_embeds[initializer_token_id] 625 | 626 | # Freeze vae and unet 627 | vae.requires_grad_(False) 628 | unet.requires_grad_(False) 629 | # Freeze all parameters except for the token embeddings in text encoder 630 | text_encoder.text_model.encoder.requires_grad_(False) 631 | text_encoder.text_model.final_layer_norm.requires_grad_(False) 632 | text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) 633 | 634 | if args.gradient_checkpointing: 635 | # Keep unet in train mode if we are using gradient checkpointing to save memory. 636 | # The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode. 637 | unet.train() 638 | text_encoder.gradient_checkpointing_enable() 639 | unet.enable_gradient_checkpointing() 640 | 641 | if args.enable_xformers_memory_efficient_attention: 642 | if is_xformers_available(): 643 | import xformers 644 | 645 | xformers_version = version.parse(xformers.__version__) 646 | if xformers_version == version.parse("0.0.16"): 647 | logger.warn( 648 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 649 | ) 650 | unet.enable_xformers_memory_efficient_attention() 651 | else: 652 | raise ValueError("xformers is not available. 
Make sure it is installed correctly") 653 | 654 | # Enable TF32 for faster training on Ampere GPUs, 655 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 656 | if args.allow_tf32: 657 | torch.backends.cuda.matmul.allow_tf32 = True 658 | 659 | if args.scale_lr: 660 | args.learning_rate = ( 661 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 662 | ) 663 | 664 | # Initialize the optimizer 665 | optimizer = torch.optim.AdamW( 666 | text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings 667 | lr=args.learning_rate, 668 | betas=(args.adam_beta1, args.adam_beta2), 669 | weight_decay=args.adam_weight_decay, 670 | eps=args.adam_epsilon, 671 | ) 672 | 673 | # Dataset and DataLoaders creation: 674 | train_dataset = TextualInversionDataset( 675 | data_root=args.train_data_dir, 676 | tokenizer=tokenizer, 677 | size=args.resolution, 678 | placeholder_token=args.placeholder_token, 679 | repeats=args.repeats, 680 | learnable_property=args.learnable_property, 681 | center_crop=args.center_crop, 682 | set="train", 683 | ) 684 | train_dataloader = torch.utils.data.DataLoader( 685 | train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers 686 | ) 687 | if args.validation_epochs is not None: 688 | warnings.warn( 689 | f"FutureWarning: You are doing logging with validation_epochs={args.validation_epochs}." 690 | " Deprecated validation_epochs in favor of `validation_steps`" 691 | f"Setting `args.validation_steps` to {args.validation_epochs * len(train_dataset)}", 692 | FutureWarning, 693 | stacklevel=2, 694 | ) 695 | args.validation_steps = args.validation_epochs * len(train_dataset) 696 | 697 | # Scheduler and math around the number of training steps. 698 | overrode_max_train_steps = False 699 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 700 | if args.max_train_steps is None: 701 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 702 | overrode_max_train_steps = True 703 | 704 | lr_scheduler = get_scheduler( 705 | args.lr_scheduler, 706 | optimizer=optimizer, 707 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 708 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 709 | ) 710 | 711 | # Prepare everything with our `accelerator`. 712 | text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 713 | text_encoder, optimizer, train_dataloader, lr_scheduler 714 | ) 715 | 716 | # For mixed precision training we cast the unet and vae weights to half-precision 717 | # as these models are only used for inference, keeping weights in full precision is not required. 718 | weight_dtype = torch.float32 719 | if accelerator.mixed_precision == "fp16": 720 | weight_dtype = torch.float16 721 | elif accelerator.mixed_precision == "bf16": 722 | weight_dtype = torch.bfloat16 723 | 724 | # Move vae and unet to device and cast to weight_dtype 725 | unet.to(accelerator.device, dtype=weight_dtype) 726 | vae.to(accelerator.device, dtype=weight_dtype) 727 | 728 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
729 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 730 | if overrode_max_train_steps: 731 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 732 | # Afterwards we recalculate our number of training epochs 733 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 734 | 735 | # We need to initialize the trackers we use, and also store our configuration. 736 | # The trackers initializes automatically on the main process. 737 | if accelerator.is_main_process: 738 | accelerator.init_trackers("p_plust_xti", config=vars(args)) 739 | 740 | # Train! 741 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 742 | 743 | logger.info("***** Running training *****") 744 | logger.info(f" Num examples = {len(train_dataset)}") 745 | logger.info(f" Num Epochs = {args.num_train_epochs}") 746 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 747 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 748 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 749 | logger.info(f" Total optimization steps = {args.max_train_steps}") 750 | global_step = 0 751 | first_epoch = 0 752 | # Potentially load in the weights and states from a previous save 753 | if args.resume_from_checkpoint: 754 | if args.resume_from_checkpoint != "latest": 755 | path = os.path.basename(args.resume_from_checkpoint) 756 | else: 757 | # Get the most recent checkpoint 758 | dirs = os.listdir(args.output_dir) 759 | dirs = [d for d in dirs if d.startswith("checkpoint")] 760 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 761 | path = dirs[-1] if len(dirs) > 0 else None 762 | 763 | if path is None: 764 | accelerator.print( 765 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 766 | ) 767 | args.resume_from_checkpoint = None 768 | else: 769 | accelerator.print(f"Resuming from checkpoint {path}") 770 | accelerator.load_state(os.path.join(args.output_dir, path)) 771 | global_step = int(path.split("-")[1]) 772 | 773 | resume_global_step = global_step * args.gradient_accumulation_steps 774 | first_epoch = global_step // num_update_steps_per_epoch 775 | resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) 776 | 777 | # Only show the progress bar once on each machine. 
778 | progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) 779 | progress_bar.set_description("Steps") 780 | 781 | # keep original embeddings as reference 782 | orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone() 783 | 784 | for epoch in range(first_epoch, args.num_train_epochs): 785 | text_encoder.train() 786 | for step, batch in enumerate(train_dataloader): 787 | # Skip steps until we reach the resumed step 788 | if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: 789 | if step % args.gradient_accumulation_steps == 0: 790 | progress_bar.update(1) 791 | continue 792 | 793 | with accelerator.accumulate(text_encoder): 794 | # Convert images to latent space 795 | latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach() 796 | latents = latents * vae.config.scaling_factor 797 | 798 | # Sample noise that we'll add to the latents 799 | noise = torch.randn_like(latents) 800 | bsz = latents.shape[0] 801 | # Sample a random timestep for each image 802 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) 803 | timesteps = timesteps.long() 804 | 805 | # Add noise to the latents according to the noise magnitude at each timestep 806 | # (this is the forward diffusion process) 807 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 808 | 809 | # Get the text embedding for conditioning 810 | encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype) 811 | 812 | # Predict the noise residual 813 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 814 | 815 | # Get the target for loss depending on the prediction type 816 | if noise_scheduler.config.prediction_type == "epsilon": 817 | target = noise 818 | elif noise_scheduler.config.prediction_type == "v_prediction": 819 | target = noise_scheduler.get_velocity(latents, noise, timesteps) 820 | else: 821 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 822 | 823 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 824 | 825 | accelerator.backward(loss) 826 | 827 | optimizer.step() 828 | lr_scheduler.step() 829 | optimizer.zero_grad() 830 | 831 | # Let's make sure we don't update any embedding weights besides the newly added token 832 | index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id 833 | with torch.no_grad(): 834 | accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ 835 | index_no_updates 836 | ] = orig_embeds_params[index_no_updates] 837 | 838 | # Checks if the accelerator has performed an optimization step behind the scenes 839 | if accelerator.sync_gradients: 840 | progress_bar.update(1) 841 | global_step += 1 842 | if global_step % args.save_steps == 0: 843 | save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin") 844 | save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path) 845 | 846 | if accelerator.is_main_process: 847 | if global_step % args.checkpointing_steps == 0: 848 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 849 | accelerator.save_state(save_path) 850 | logger.info(f"Saved state to {save_path}") 851 | 852 | if args.validation_prompt is not None and global_step % args.validation_steps == 0: 853 | log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, 
weight_dtype, epoch) 854 | 855 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 856 | progress_bar.set_postfix(**logs) 857 | accelerator.log(logs, step=global_step) 858 | 859 | if global_step >= args.max_train_steps: 860 | break 861 | # Create the pipeline using using the trained modules and save it. 862 | accelerator.wait_for_everyone() 863 | if accelerator.is_main_process: 864 | if args.push_to_hub and args.only_save_embeds: 865 | logger.warn("Enabling full model saving because --push_to_hub=True was specified.") 866 | save_full_model = True 867 | else: 868 | save_full_model = not args.only_save_embeds 869 | if save_full_model: 870 | pipeline = StableDiffusionPipeline.from_pretrained( 871 | args.pretrained_model_name_or_path, 872 | text_encoder=accelerator.unwrap_model(text_encoder), 873 | vae=vae, 874 | unet=unet, 875 | tokenizer=tokenizer, 876 | ) 877 | pipeline.save_pretrained(args.output_dir) 878 | # Save the newly trained embeddings 879 | save_path = os.path.join(args.output_dir, "learned_embeds.bin") 880 | save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path) 881 | 882 | if args.push_to_hub: 883 | repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) 884 | 885 | accelerator.end_training() 886 | 887 | 888 | if __name__ == "__main__": 889 | main() 890 | -------------------------------------------------------------------------------- /train_p_plus.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import math 4 | import os 5 | import json 6 | import random 7 | import warnings 8 | from pathlib import Path 9 | from typing import Optional 10 | 11 | import numpy as np 12 | import PIL 13 | import torch 14 | import torch.nn.functional as F 15 | import torch.utils.checkpoint 16 | import transformers 17 | from accelerate import Accelerator 18 | from accelerate.logging import get_logger 19 | from accelerate.utils import ProjectConfiguration, set_seed 20 | from huggingface_hub import HfFolder, Repository, create_repo, whoami 21 | 22 | # TODO: remove and import from diffusers.utils when the new version of diffusers is released 23 | from packaging import version 24 | from PIL import Image 25 | from torch.utils.data import Dataset 26 | from torchvision import transforms 27 | from tqdm.auto import tqdm 28 | from transformers import CLIPTextModel, CLIPTokenizer 29 | 30 | import diffusers 31 | from diffusers import ( 32 | AutoencoderKL, 33 | DDPMScheduler, 34 | DiffusionPipeline, 35 | DPMSolverMultistepScheduler, 36 | StableDiffusionPipeline, 37 | ) 38 | from diffusers.optimization import get_scheduler 39 | from diffusers.utils import check_min_version, is_wandb_available 40 | from diffusers.utils.import_utils import is_xformers_available 41 | from prompt_plus import PPlusUNet2DConditionModel, PPlusStableDiffusionPipeline 42 | 43 | 44 | if is_wandb_available(): 45 | import wandb 46 | 47 | if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): 48 | PIL_INTERPOLATION = { 49 | "linear": PIL.Image.Resampling.BILINEAR, 50 | "bilinear": PIL.Image.Resampling.BILINEAR, 51 | "bicubic": PIL.Image.Resampling.BICUBIC, 52 | "lanczos": PIL.Image.Resampling.LANCZOS, 53 | "nearest": PIL.Image.Resampling.NEAREST, 54 | } 55 | else: 56 | PIL_INTERPOLATION = { 57 | "linear": PIL.Image.LINEAR, 58 | "bilinear": PIL.Image.BILINEAR, 59 | "bicubic": PIL.Image.BICUBIC, 60 | "lanczos": PIL.Image.LANCZOS, 61 | "nearest": 
PIL.Image.NEAREST, 62 | } 63 | # ------------------------------------------------------------------------------ 64 | 65 | 66 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 67 | # check_min_version("0.15.0.dev0") 68 | 69 | logger = get_logger(__name__) 70 | 71 | 72 | def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch): 73 | logger.info( 74 | f"Running validation... \n Generating {args.num_validation_images} images with prompt:" 75 | f" {args.validation_prompt}." 76 | ) 77 | # create pipeline (note: unet and vae are loaded again in float32) 78 | pipeline = DiffusionPipeline.from_pretrained( 79 | args.pretrained_model_name_or_path, 80 | text_encoder=accelerator.unwrap_model(text_encoder), 81 | tokenizer=tokenizer, 82 | unet=unet, 83 | vae=vae, 84 | revision=args.revision, 85 | torch_dtype=weight_dtype, 86 | ) 87 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) 88 | pipeline = pipeline.to(accelerator.device) 89 | pipeline.set_progress_bar_config(disable=True) 90 | 91 | # run inference 92 | generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed) 93 | images = [] 94 | for _ in range(args.num_validation_images): 95 | with torch.autocast("cuda"): 96 | image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] 97 | images.append(image) 98 | 99 | for tracker in accelerator.trackers: 100 | if tracker.name == "tensorboard": 101 | np_images = np.stack([np.asarray(img) for img in images]) 102 | tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") 103 | if tracker.name == "wandb": 104 | tracker.log( 105 | { 106 | "validation": [ 107 | wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) 108 | ] 109 | } 110 | ) 111 | 112 | del pipeline 113 | torch.cuda.empty_cache() 114 | 115 | 116 | def save_progress(text_encoder, placeholder_tokens, placeholder_token_ids, accelerator, args, save_path): 117 | logger.info("Saving embeddings") 118 | learned_embeds_dict = dict() 119 | for placeholder_token, placeholder_token_id in zip(placeholder_tokens, placeholder_token_ids): 120 | learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id] 121 | learned_embeds_dict[placeholder_token] = learned_embeds.detach().cpu() 122 | torch.save(learned_embeds_dict, save_path) 123 | with open(os.path.join(os.path.dirname(save_path), "config.json"), "w") as f: 124 | json.dump(args.__dict__, f, indent=2) 125 | 126 | 127 | def parse_args(): 128 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 129 | parser.add_argument( 130 | "--save_steps", 131 | type=int, 132 | default=500, 133 | help="Save learned_embeds.bin every X updates steps.", 134 | ) 135 | parser.add_argument( 136 | "--only_save_embeds", 137 | action="store_true", 138 | default=False, 139 | help="Save only the embeddings for the new concept.", 140 | ) 141 | parser.add_argument( 142 | "--pretrained_model_name_or_path", 143 | type=str, 144 | default=None, 145 | required=True, 146 | help="Path to pretrained model or model identifier from huggingface.co/models.", 147 | ) 148 | parser.add_argument( 149 | "--revision", 150 | type=str, 151 | default=None, 152 | required=False, 153 | help="Revision of pretrained model identifier from huggingface.co/models.", 154 | ) 155 | parser.add_argument( 156 | "--tokenizer_name", 
157 | type=str, 158 | default=None, 159 | help="Pretrained tokenizer name or path if not the same as model_name", 160 | ) 161 | parser.add_argument( 162 | "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data." 163 | ) 164 | parser.add_argument( 165 | "--placeholder_token", 166 | type=str, 167 | default=None, 168 | required=True, 169 | help="A token to use as a placeholder for the concept.", 170 | ) 171 | parser.add_argument( 172 | "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word." 173 | ) 174 | parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'") 175 | parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.") 176 | parser.add_argument( 177 | "--output_dir", 178 | type=str, 179 | default="text-inversion-model", 180 | help="The output directory where the model predictions and checkpoints will be written.", 181 | ) 182 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 183 | parser.add_argument( 184 | "--resolution", 185 | type=int, 186 | default=512, 187 | help=( 188 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 189 | " resolution" 190 | ), 191 | ) 192 | parser.add_argument( 193 | "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution." 194 | ) 195 | parser.add_argument( 196 | "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." 197 | ) 198 | parser.add_argument("--num_train_epochs", type=int, default=100) 199 | parser.add_argument( 200 | "--max_train_steps", 201 | type=int, 202 | default=5000, 203 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 204 | ) 205 | parser.add_argument( 206 | "--gradient_accumulation_steps", 207 | type=int, 208 | default=1, 209 | help="Number of updates steps to accumulate before performing a backward/update pass.", 210 | ) 211 | parser.add_argument( 212 | "--gradient_checkpointing", 213 | action="store_true", 214 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 215 | ) 216 | parser.add_argument( 217 | "--learning_rate", 218 | type=float, 219 | default=1e-4, 220 | help="Initial learning rate (after the potential warmup period) to use.", 221 | ) 222 | parser.add_argument( 223 | "--scale_lr", 224 | action="store_true", 225 | default=False, 226 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 227 | ) 228 | parser.add_argument( 229 | "--lr_scheduler", 230 | type=str, 231 | default="constant", 232 | help=( 233 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 234 | ' "constant", "constant_with_warmup"]' 235 | ), 236 | ) 237 | parser.add_argument( 238 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 239 | ) 240 | parser.add_argument( 241 | "--dataloader_num_workers", 242 | type=int, 243 | default=0, 244 | help=( 245 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 
246 | ), 247 | ) 248 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 249 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 250 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 251 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 252 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 253 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 254 | parser.add_argument( 255 | "--hub_model_id", 256 | type=str, 257 | default=None, 258 | help="The name of the repository to keep in sync with the local `output_dir`.", 259 | ) 260 | parser.add_argument( 261 | "--logging_dir", 262 | type=str, 263 | default="logs", 264 | help=( 265 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 266 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 267 | ), 268 | ) 269 | parser.add_argument( 270 | "--mixed_precision", 271 | type=str, 272 | default="no", 273 | choices=["no", "fp16", "bf16"], 274 | help=( 275 | "Whether to use mixed precision. Choose" 276 | "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." 277 | "and an Nvidia Ampere GPU." 278 | ), 279 | ) 280 | parser.add_argument( 281 | "--allow_tf32", 282 | action="store_true", 283 | help=( 284 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" 285 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 286 | ), 287 | ) 288 | parser.add_argument( 289 | "--report_to", 290 | type=str, 291 | default="tensorboard", 292 | help=( 293 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 294 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 295 | ), 296 | ) 297 | parser.add_argument( 298 | "--validation_prompt", 299 | type=str, 300 | default=None, 301 | help="A prompt that is used during validation to verify that the model is learning.", 302 | ) 303 | parser.add_argument( 304 | "--num_validation_images", 305 | type=int, 306 | default=4, 307 | help="Number of images that should be generated during validation with `validation_prompt`.", 308 | ) 309 | parser.add_argument( 310 | "--validation_steps", 311 | type=int, 312 | default=100, 313 | help=( 314 | "Run validation every X steps. Validation consists of running the prompt" 315 | " `args.validation_prompt` multiple times: `args.num_validation_images`" 316 | " and logging the images." 317 | ), 318 | ) 319 | parser.add_argument( 320 | "--validation_epochs", 321 | type=int, 322 | default=None, 323 | help=( 324 | "Deprecated in favor of validation_steps. Run validation every X epochs. Validation consists of running the prompt" 325 | " `args.validation_prompt` multiple times: `args.num_validation_images`" 326 | " and logging the images." 327 | ), 328 | ) 329 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 330 | parser.add_argument( 331 | "--checkpointing_steps", 332 | type=int, 333 | default=500, 334 | help=( 335 | "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" 336 | " training using `--resume_from_checkpoint`." 
337 | ), 338 | ) 339 | parser.add_argument( 340 | "--checkpoints_total_limit", 341 | type=int, 342 | default=None, 343 | help=( 344 | "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." 345 | " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" 346 | " for more docs" 347 | ), 348 | ) 349 | parser.add_argument( 350 | "--resume_from_checkpoint", 351 | type=str, 352 | default=None, 353 | help=( 354 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 355 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 356 | ), 357 | ) 358 | parser.add_argument( 359 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 360 | ) 361 | 362 | args = parser.parse_args() 363 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 364 | if env_local_rank != -1 and env_local_rank != args.local_rank: 365 | args.local_rank = env_local_rank 366 | 367 | if args.train_data_dir is None: 368 | raise ValueError("You must specify a train data directory.") 369 | 370 | return args 371 | 372 | 373 | imagenet_templates_small = [ 374 | "a photo of a {}", 375 | "a rendering of a {}", 376 | "a cropped photo of the {}", 377 | "the photo of a {}", 378 | "a photo of a clean {}", 379 | "a photo of a dirty {}", 380 | "a dark photo of the {}", 381 | "a photo of my {}", 382 | "a photo of the cool {}", 383 | "a close-up photo of a {}", 384 | "a bright photo of the {}", 385 | "a cropped photo of a {}", 386 | "a photo of the {}", 387 | "a good photo of the {}", 388 | "a photo of one {}", 389 | "a close-up photo of the {}", 390 | "a rendition of the {}", 391 | "a photo of the clean {}", 392 | "a rendition of a {}", 393 | "a photo of a nice {}", 394 | "a good photo of a {}", 395 | "a photo of the nice {}", 396 | "a photo of the small {}", 397 | "a photo of the weird {}", 398 | "a photo of the large {}", 399 | "a photo of a cool {}", 400 | "a photo of a small {}", 401 | ] 402 | 403 | imagenet_style_templates_small = [ 404 | "a painting in the style of {}", 405 | "a rendering in the style of {}", 406 | "a cropped painting in the style of {}", 407 | "the painting in the style of {}", 408 | "a clean painting in the style of {}", 409 | "a dirty painting in the style of {}", 410 | "a dark painting in the style of {}", 411 | "a picture in the style of {}", 412 | "a cool painting in the style of {}", 413 | "a close-up painting in the style of {}", 414 | "a bright painting in the style of {}", 415 | "a cropped painting in the style of {}", 416 | "a good painting in the style of {}", 417 | "a close-up painting in the style of {}", 418 | "a rendition in the style of {}", 419 | "a nice painting in the style of {}", 420 | "a small painting in the style of {}", 421 | "a weird painting in the style of {}", 422 | "a large painting in the style of {}", 423 | ] 424 | 425 | 426 | class TextualInversionDataset(Dataset): 427 | def __init__( 428 | self, 429 | data_root, 430 | tokenizer, 431 | learnable_property="object", # [object, style] 432 | size=512, 433 | repeats=100, 434 | interpolation="bicubic", 435 | flip_p=0.5, 436 | set="train", 437 | placeholder_tokens=None, 438 | center_crop=False, 439 | ): 440 | assert isinstance(placeholder_tokens, list) 441 | self.data_root = data_root 442 | self.tokenizer = tokenizer 443 | self.learnable_property = learnable_property 444 | self.size = size 445 
| self.placeholder_tokens = placeholder_tokens 446 | self.center_crop = center_crop 447 | self.flip_p = flip_p 448 | 449 | self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)] 450 | 451 | self.num_images = len(self.image_paths) 452 | self._length = self.num_images 453 | 454 | if set == "train": 455 | self._length = self.num_images * repeats 456 | 457 | self.interpolation = { 458 | "linear": PIL_INTERPOLATION["linear"], 459 | "bilinear": PIL_INTERPOLATION["bilinear"], 460 | "bicubic": PIL_INTERPOLATION["bicubic"], 461 | "lanczos": PIL_INTERPOLATION["lanczos"], 462 | }[interpolation] 463 | 464 | self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small 465 | self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p) 466 | 467 | def __len__(self): 468 | return self._length 469 | 470 | def __getitem__(self, i): 471 | example = {} 472 | image = Image.open(self.image_paths[i % self.num_images]) 473 | 474 | if not image.mode == "RGB": 475 | image = image.convert("RGB") 476 | 477 | template = random.choice(self.templates) 478 | text = [template.format(placeholder_string) for placeholder_string in self.placeholder_tokens] 479 | example["input_ids"] = self.tokenizer( 480 | text, 481 | padding="max_length", 482 | truncation=True, 483 | max_length=self.tokenizer.model_max_length, 484 | return_tensors="pt", 485 | ).input_ids 486 | # (num_new_tokens, seq_length) 487 | 488 | # default to score-sde preprocessing 489 | img = np.array(image).astype(np.uint8) 490 | 491 | if self.center_crop: 492 | crop = min(img.shape[0], img.shape[1]) 493 | ( 494 | h, 495 | w, 496 | ) = ( 497 | img.shape[0], 498 | img.shape[1], 499 | ) 500 | img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2] 501 | 502 | image = Image.fromarray(img) 503 | image = image.resize((self.size, self.size), resample=self.interpolation) 504 | 505 | image = self.flip_transform(image) 506 | image = np.array(image).astype(np.uint8) 507 | image = (image / 127.5 - 1.0).astype(np.float32) 508 | 509 | example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1) 510 | return example 511 | 512 | 513 | def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): 514 | if token is None: 515 | token = HfFolder.get_token() 516 | if organization is None: 517 | username = whoami(token)["name"] 518 | return f"{username}/{model_id}" 519 | else: 520 | return f"{organization}/{model_id}" 521 | 522 | 523 | def main(): 524 | args = parse_args() 525 | logging_dir = os.path.join(args.output_dir, args.logging_dir) 526 | 527 | accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit) 528 | 529 | accelerator = Accelerator( 530 | gradient_accumulation_steps=args.gradient_accumulation_steps, 531 | mixed_precision=args.mixed_precision, 532 | log_with=args.report_to, 533 | logging_dir=logging_dir, 534 | project_config=accelerator_project_config, 535 | ) 536 | 537 | if args.report_to == "wandb": 538 | if not is_wandb_available(): 539 | raise ImportError("Make sure to install wandb if you want to use it for logging during training.") 540 | 541 | # Make one log on every process with the configuration for debugging. 
542 | logging.basicConfig( 543 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 544 | datefmt="%m/%d/%Y %H:%M:%S", 545 | level=logging.INFO, 546 | ) 547 | logger.info(accelerator.state, main_process_only=False) 548 | if accelerator.is_local_main_process: 549 | transformers.utils.logging.set_verbosity_warning() 550 | diffusers.utils.logging.set_verbosity_info() 551 | else: 552 | transformers.utils.logging.set_verbosity_error() 553 | diffusers.utils.logging.set_verbosity_error() 554 | 555 | # If passed along, set the training seed now. 556 | if args.seed is not None: 557 | set_seed(args.seed) 558 | 559 | # Handle the repository creation 560 | if accelerator.is_main_process: 561 | if args.push_to_hub: 562 | if args.hub_model_id is None: 563 | repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) 564 | else: 565 | repo_name = args.hub_model_id 566 | create_repo(repo_name, exist_ok=True, token=args.hub_token) 567 | repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) 568 | 569 | with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: 570 | if "step_*" not in gitignore: 571 | gitignore.write("step_*\n") 572 | if "epoch_*" not in gitignore: 573 | gitignore.write("epoch_*\n") 574 | elif args.output_dir is not None: 575 | os.makedirs(args.output_dir, exist_ok=True) 576 | 577 | # Load tokenizer 578 | if args.tokenizer_name: 579 | tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) 580 | elif args.pretrained_model_name_or_path: 581 | tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") 582 | 583 | # Load scheduler and models 584 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") 585 | text_encoder = CLIPTextModel.from_pretrained( 586 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision 587 | ) 588 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) 589 | unet = PPlusUNet2DConditionModel.from_pretrained( 590 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision 591 | ) 592 | 593 | # Convert the initializer_token, placeholder_token to ids 594 | token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) 595 | # Check if initializer_token is a single token or a sequence of tokens 596 | if len(token_ids) > 1: 597 | raise ValueError("The initializer token must be a single token.") 598 | 599 | initializer_token_id = token_ids[0] 600 | 601 | # TODO: more flexible (16 cross attention layers for stable diffusion) 602 | num_cross_attn_layers = 16 603 | # Resize the token embeddings as we are adding new special tokens to the tokenizer 604 | text_encoder.resize_token_embeddings(len(tokenizer)+num_cross_attn_layers) 605 | # Initialise the newly added placeholder token with the embeddings of the initializer token 606 | token_embeds = text_encoder.get_input_embeddings().weight.data 607 | # Add the placeholder token in tokenizer 608 | placeholder_tokens = [] 609 | placeholder_token_ids = [] 610 | for i in range(num_cross_attn_layers): 611 | placeholder_token = f"{args.placeholder_token}-{i}" 612 | num_added_tokens = tokenizer.add_tokens(placeholder_token) 613 | if num_added_tokens == 0: 614 | raise ValueError( 615 | f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different" 616 | " `placeholder_token` that is not already in the tokenizer." 
617 | ) 618 | placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token) 619 | token_embeds[placeholder_token_id] = token_embeds[initializer_token_id] 620 | placeholder_tokens.append(placeholder_token) 621 | placeholder_token_ids.append(placeholder_token_id) 622 | 623 | # Freeze vae and unet 624 | vae.requires_grad_(False) 625 | unet.requires_grad_(False) 626 | # Freeze all parameters except for the token embeddings in text encoder 627 | text_encoder.text_model.encoder.requires_grad_(False) 628 | text_encoder.text_model.final_layer_norm.requires_grad_(False) 629 | text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) 630 | 631 | if args.gradient_checkpointing: 632 | # Keep unet in train mode if we are using gradient checkpointing to save memory. 633 | # The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode. 634 | unet.train() 635 | text_encoder.gradient_checkpointing_enable() 636 | unet.enable_gradient_checkpointing() 637 | 638 | if args.enable_xformers_memory_efficient_attention: 639 | if is_xformers_available(): 640 | import xformers 641 | 642 | xformers_version = version.parse(xformers.__version__) 643 | if xformers_version == version.parse("0.0.16"): 644 | logger.warn( 645 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 646 | ) 647 | unet.enable_xformers_memory_efficient_attention() 648 | else: 649 | raise ValueError("xformers is not available. Make sure it is installed correctly") 650 | 651 | # Enable TF32 for faster training on Ampere GPUs, 652 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 653 | if args.allow_tf32: 654 | torch.backends.cuda.matmul.allow_tf32 = True 655 | 656 | if args.scale_lr: 657 | args.learning_rate = ( 658 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 659 | ) 660 | 661 | # Initialize the optimizer 662 | optimizer = torch.optim.AdamW( 663 | text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings 664 | lr=args.learning_rate, 665 | betas=(args.adam_beta1, args.adam_beta2), 666 | weight_decay=args.adam_weight_decay, 667 | eps=args.adam_epsilon, 668 | ) 669 | 670 | # Dataset and DataLoaders creation: 671 | train_dataset = TextualInversionDataset( 672 | data_root=args.train_data_dir, 673 | tokenizer=tokenizer, 674 | size=args.resolution, 675 | placeholder_tokens=placeholder_tokens, 676 | repeats=args.repeats, 677 | learnable_property=args.learnable_property, 678 | center_crop=args.center_crop, 679 | set="train", 680 | ) 681 | train_dataloader = torch.utils.data.DataLoader( 682 | train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers 683 | ) 684 | if args.validation_epochs is not None: 685 | warnings.warn( 686 | f"FutureWarning: You are doing logging with validation_epochs={args.validation_epochs}." 687 | " Deprecated validation_epochs in favor of `validation_steps`" 688 | f"Setting `args.validation_steps` to {args.validation_epochs * len(train_dataset)}", 689 | FutureWarning, 690 | stacklevel=2, 691 | ) 692 | args.validation_steps = args.validation_epochs * len(train_dataset) 693 | 694 | # Scheduler and math around the number of training steps. 
695 | overrode_max_train_steps = False 696 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 697 | if args.max_train_steps is None: 698 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 699 | overrode_max_train_steps = True 700 | 701 | lr_scheduler = get_scheduler( 702 | args.lr_scheduler, 703 | optimizer=optimizer, 704 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 705 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 706 | ) 707 | 708 | # Prepare everything with our `accelerator`. 709 | text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 710 | text_encoder, optimizer, train_dataloader, lr_scheduler 711 | ) 712 | 713 | # For mixed precision training we cast the unet and vae weights to half-precision 714 | # as these models are only used for inference, keeping weights in full precision is not required. 715 | weight_dtype = torch.float32 716 | if accelerator.mixed_precision == "fp16": 717 | weight_dtype = torch.float16 718 | elif accelerator.mixed_precision == "bf16": 719 | weight_dtype = torch.bfloat16 720 | 721 | # Move vae and unet to device and cast to weight_dtype 722 | unet.to(accelerator.device, dtype=weight_dtype) 723 | vae.to(accelerator.device, dtype=weight_dtype) 724 | 725 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 726 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 727 | if overrode_max_train_steps: 728 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 729 | # Afterwards we recalculate our number of training epochs 730 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 731 | 732 | # We need to initialize the trackers we use, and also store our configuration. 733 | # The trackers initializes automatically on the main process. 734 | if accelerator.is_main_process: 735 | accelerator.init_trackers("p_plus_xti", config=vars(args)) 736 | 737 | # Train! 738 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 739 | 740 | logger.info("***** Running training *****") 741 | logger.info(f" Num examples = {len(train_dataset)}") 742 | logger.info(f" Num Epochs = {args.num_train_epochs}") 743 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 744 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 745 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 746 | logger.info(f" Total optimization steps = {args.max_train_steps}") 747 | global_step = 0 748 | first_epoch = 0 749 | # Potentially load in the weights and states from a previous save 750 | if args.resume_from_checkpoint: 751 | if args.resume_from_checkpoint != "latest": 752 | path = os.path.basename(args.resume_from_checkpoint) 753 | else: 754 | # Get the most recent checkpoint 755 | dirs = os.listdir(args.output_dir) 756 | dirs = [d for d in dirs if d.startswith("checkpoint")] 757 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 758 | path = dirs[-1] if len(dirs) > 0 else None 759 | 760 | if path is None: 761 | accelerator.print( 762 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 
763 | ) 764 | args.resume_from_checkpoint = None 765 | else: 766 | accelerator.print(f"Resuming from checkpoint {path}") 767 | accelerator.load_state(os.path.join(args.output_dir, path)) 768 | global_step = int(path.split("-")[1]) 769 | 770 | resume_global_step = global_step * args.gradient_accumulation_steps 771 | first_epoch = global_step // num_update_steps_per_epoch 772 | resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) 773 | 774 | # Only show the progress bar once on each machine. 775 | progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) 776 | progress_bar.set_description("Steps") 777 | 778 | # keep original embeddings as reference 779 | orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone() 780 | 781 | for epoch in range(first_epoch, args.num_train_epochs): 782 | text_encoder.train() 783 | for step, batch in enumerate(train_dataloader): 784 | # Skip steps until we reach the resumed step 785 | if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: 786 | if step % args.gradient_accumulation_steps == 0: 787 | progress_bar.update(1) 788 | continue 789 | 790 | with accelerator.accumulate(text_encoder): 791 | # Convert images to latent space 792 | latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach() 793 | latents = latents * vae.config.scaling_factor 794 | 795 | # Sample noise that we'll add to the latents 796 | noise = torch.randn_like(latents) 797 | bsz = latents.shape[0] 798 | # Sample a random timestep for each image 799 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) 800 | timesteps = timesteps.long() 801 | 802 | # Add noise to the latents according to the noise magnitude at each timestep 803 | # (this is the forward diffusion process) 804 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 805 | 806 | # Get the text embedding for conditioning 807 | encoder_hidden_states_list = [] 808 | num_new_tokens = batch["input_ids"].size(1) 809 | for i in range(num_new_tokens): 810 | encoder_hidden_states = text_encoder(batch["input_ids"][:, i, :])[0].to(dtype=weight_dtype) 811 | encoder_hidden_states_list.append(encoder_hidden_states) 812 | 813 | # Predict the noise residual 814 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states_list=encoder_hidden_states_list).sample 815 | 816 | # Get the target for loss depending on the prediction type 817 | if noise_scheduler.config.prediction_type == "epsilon": 818 | target = noise 819 | elif noise_scheduler.config.prediction_type == "v_prediction": 820 | target = noise_scheduler.get_velocity(latents, noise, timesteps) 821 | else: 822 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 823 | 824 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 825 | 826 | accelerator.backward(loss) 827 | 828 | optimizer.step() 829 | lr_scheduler.step() 830 | optimizer.zero_grad() 831 | 832 | # Let's make sure we don't update any embedding weights besides the newly added tokens: 833 | # the mask is True for every token that should be kept frozen (i.e. everything except the placeholders). 834 | vocab = torch.arange(len(tokenizer)) 835 | index_no_updates = ~torch.any(torch.stack([torch.eq(vocab, placeholder_token_id) for placeholder_token_id in placeholder_token_ids], dim=0), dim=0) 836 | 837 | with torch.no_grad(): 838 | accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ 839 | 
index_no_updates 840 | ] = orig_embeds_params[index_no_updates] 841 | 842 | # Checks if the accelerator has performed an optimization step behind the scenes 843 | if accelerator.sync_gradients: 844 | progress_bar.update(1) 845 | global_step += 1 846 | if global_step % args.save_steps == 0: 847 | save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin") 848 | save_progress(text_encoder, placeholder_tokens, placeholder_token_ids, accelerator, args, save_path) 849 | 850 | if accelerator.is_main_process: 851 | if global_step % args.checkpointing_steps == 0: 852 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 853 | accelerator.save_state(save_path) 854 | logger.info(f"Saved state to {save_path}") 855 | 856 | if args.validation_prompt is not None and global_step % args.validation_steps == 0: 857 | log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch) 858 | 859 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 860 | progress_bar.set_postfix(**logs) 861 | accelerator.log(logs, step=global_step) 862 | 863 | if global_step >= args.max_train_steps: 864 | break 865 | # Create the pipeline using using the trained modules and save it. 866 | accelerator.wait_for_everyone() 867 | if accelerator.is_main_process: 868 | if args.push_to_hub and args.only_save_embeds: 869 | logger.warn("Enabling full model saving because --push_to_hub=True was specified.") 870 | save_full_model = True 871 | else: 872 | save_full_model = not args.only_save_embeds 873 | if save_full_model: 874 | pipeline = StableDiffusionPipeline.from_pretrained( 875 | args.pretrained_model_name_or_path, 876 | text_encoder=accelerator.unwrap_model(text_encoder), 877 | vae=vae, 878 | unet=unet, 879 | tokenizer=tokenizer, 880 | ) 881 | pipeline.save_pretrained(args.output_dir) 882 | # Save the newly trained embeddings 883 | save_path = os.path.join(args.output_dir, "learned_embeds.bin") 884 | save_progress(text_encoder, placeholder_tokens, placeholder_token_ids, accelerator, args, save_path) 885 | if args.push_to_hub: 886 | repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) 887 | 888 | accelerator.end_training() 889 | 890 | 891 | if __name__ == "__main__": 892 | main() 893 | --------------------------------------------------------------------------------
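
To make the artefact written by `save_progress` in `train_p_plus.py` concrete, here is a minimal, hypothetical sketch of loading `learned_embeds.bin` (a dict with one embedding per cross-attention layer, keyed `f"{placeholder_token}-{i}"` for `i` in `0..15`) back into a fresh tokenizer and text encoder, and then rebuilding the per-layer text conditioning the same way the training loop does. This is *not* the repo's `PPlusStableDiffusionPipeline.from_learned_embed`; the base-model name, file path, and prompt template below are assumptions for illustration only.

```python
import torch
from transformers import CLIPTextModel, CLIPTokenizer

# Assumptions: same base model as training and an --output_dir of "xti_cat".
pretrained_model = "CompVis/stable-diffusion-v1-4"
learned_embed_path = "xti_cat/learned_embeds.bin"

tokenizer = CLIPTokenizer.from_pretrained(pretrained_model, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(pretrained_model, subfolder="text_encoder")

# Dict saved by save_progress: {f"{placeholder_token}-{i}": 1-D embedding tensor}.
learned_embeds = torch.load(learned_embed_path, map_location="cpu")

# Sort the per-layer tokens by their "-{i}" suffix so index i matches cross-attention layer i.
placeholder_tokens = sorted(learned_embeds.keys(), key=lambda t: int(t.rsplit("-", 1)[-1]))

# Register the per-layer tokens and copy the trained embeddings into the text encoder.
num_added = tokenizer.add_tokens(placeholder_tokens)
assert num_added == len(placeholder_tokens), "some placeholder tokens already exist in the tokenizer"
text_encoder.resize_token_embeddings(len(tokenizer))
token_embeds = text_encoder.get_input_embeddings().weight.data
for token in placeholder_tokens:
    token_embeds[tokenizer.convert_tokens_to_ids(token)] = learned_embeds[token]

# Mirror the training-time conditioning: format the prompt once per layer token and
# encode each variant separately, yielding one hidden-state tensor per cross-attention layer.
prompt_template = "a photo of a {}"  # assumption: any prompt containing the placeholder
encoder_hidden_states_list = []
with torch.no_grad():
    for token in placeholder_tokens:
        input_ids = tokenizer(
            prompt_template.format(token),
            padding="max_length",
            truncation=True,
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids
        encoder_hidden_states_list.append(text_encoder(input_ids)[0])
```

The resulting list has the same shape as the `encoder_hidden_states_list` built inside the training loop, which is what `PPlusUNet2DConditionModel` expects as conditioning.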
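The P+ mechanism exercised by `encoder_hidden_states_list=` in the training loop is that every cross-attention layer of the UNet attends to its own text conditioning instead of a single shared one. The repo implements this inside `PPlusUNet2DConditionModel` (not shown in this listing); the toy module below is only a sketch of that per-layer routing under assumed dimensions, not the actual implementation.

```python
import torch
import torch.nn as nn


class ToyPerLayerCrossAttention(nn.Module):
    """Illustration only: layer i attends to encoder_hidden_states_list[i]."""

    def __init__(self, num_layers: int = 16, dim: int = 320, cond_dim: int = 768):
        super().__init__()
        # One key/value projection and one attention block per cross-attention layer.
        self.to_kv = nn.ModuleList(nn.Linear(cond_dim, dim * 2) for _ in range(num_layers))
        self.attn = nn.ModuleList(
            nn.MultiheadAttention(dim, num_heads=8, batch_first=True) for _ in range(num_layers)
        )

    def forward(self, hidden_states, encoder_hidden_states_list):
        # hidden_states: (batch, tokens, dim) image features; one conditioning tensor per layer.
        assert len(encoder_hidden_states_list) == len(self.attn)
        for layer_idx, (attn, to_kv) in enumerate(zip(self.attn, self.to_kv)):
            k, v = to_kv(encoder_hidden_states_list[layer_idx]).chunk(2, dim=-1)
            out, _ = attn(hidden_states, k, v)  # query from image features, key/value from layer-specific text
            hidden_states = hidden_states + out
        return hidden_states


toy = ToyPerLayerCrossAttention()
x = torch.randn(1, 64, 320)                            # stand-in for UNet feature tokens
conds = [torch.randn(1, 77, 768) for _ in range(16)]   # one text encoding per layer
out = toy(x, conds)                                    # each layer attended to its own prompt
```

In the real model each of the 16 cross-attention layers sits at a different resolution of the UNet, which is why the training script learns a separate placeholder embedding per layer and feeds a list of 16 encodings.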