├── .gitignore ├── LICENSE ├── README.md ├── inference-promptsliders-sd.py ├── inference-promptsliders-sdxl.py ├── inference_sd.py ├── prompt_slider_emotions.sh ├── requirements.txt ├── textsliders ├── __init__.py ├── data │ ├── prompts-anime.yaml │ ├── prompts-bald.yaml │ ├── prompts-glasses.yaml │ ├── prompts-hat.yaml │ ├── prompts-smiling.yaml │ ├── prompts-surprised.yaml │ ├── prompts-xl.yaml │ ├── prompts-zombie.yaml │ └── prompts.yaml ├── flush.py ├── prompt_util.py └── train_util.py ├── textual_inversion.py └── textual_inversion_sdxl.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # slider files 7 | *.bin 8 | *.pt 9 | *.pkl 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # poetry 103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 104 | # This is especially recommended for binary packages to ensure reproducibility, and is more 105 | # commonly ignored for libraries. 106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 107 | #poetry.lock 108 | 109 | # pdm 110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 111 | #pdm.lock 112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 113 | # in version control. 114 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 115 | .pdm.toml 116 | .pdm-python 117 | .pdm-build/ 118 | 119 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 120 | __pypackages__/ 121 | 122 | # Celery stuff 123 | celerybeat-schedule 124 | celerybeat.pid 125 | 126 | # SageMath parsed files 127 | *.sage.py 128 | 129 | # Environments 130 | .env 131 | .venv 132 | env/ 133 | venv/ 134 | ENV/ 135 | env.bak/ 136 | venv.bak/ 137 | 138 | # Spyder project settings 139 | .spyderproject 140 | .spyproject 141 | 142 | # Rope project settings 143 | .ropeproject 144 | 145 | # mkdocs documentation 146 | /site 147 | 148 | # mypy 149 | .mypy_cache/ 150 | .dmypy.json 151 | dmypy.json 152 | 153 | # Pyre type checker 154 | .pyre/ 155 | 156 | # pytype static type analyzer 157 | .pytype/ 158 | 159 | # Cython debug symbols 160 | cython_debug/ 161 | 162 | # PyCharm 163 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 164 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 165 | # and can be added to the global gitignore or merged into this file. For a more nuclear 166 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 167 | #.idea/ 168 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Prompt Sliders for Fine-Grained Control, Editing and Erasing of Concepts in Diffusion Models 2 | We introduce the Prompt Slider method for precise manipulation, editing, and erasure of concepts in diffusion models. 
[Project Page](https://deepaksridhar.github.io/promptsliders.github.io/) 3 | 4 | 5 | ### Installing the dependencies 6 | 7 | Before running the scripts, make sure to install the library's training dependencies: 8 | 9 | You can install diffusers directly from pip, or build it from source to get the latest version. To do this, execute one of the following steps in a new virtual environment: 10 | 11 | Install with pip 12 | ```bash 13 | pip install diffusers==0.27 14 | ``` 15 | 16 | Install from source 17 | ```bash 18 | git clone https://github.com/huggingface/diffusers 19 | cd diffusers 20 | pip install . 21 | ``` 22 | 23 | Then cd into the promptsliders folder (you can also copy it into the examples folder in diffusers) and run: 24 | ```bash 25 | pip install -r requirements.txt 26 | ``` 27 | 28 | And initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with: 29 | 30 | ```bash 31 | accelerate config 32 | ``` 33 | 34 | Now we can launch the training using: 35 | 36 | ```bash 37 | export MODEL_NAME="runwayml/stable-diffusion-v1-5" 38 | export EMOTION="smiling" 39 | 40 | accelerate launch textual_inversion.py \ 41 | --pretrained_model_name_or_path=$MODEL_NAME \ 42 | --learnable_property="object" \ 43 | --placeholder_token="<$EMOTION-lora>" \ 44 | --initializer_token="$EMOTION" \ 45 | --mixed_precision="no" \ 46 | --resolution=512 \ 47 | --train_batch_size=1 \ 48 | --gradient_accumulation_steps=1 \ 49 | --max_train_steps=2000 \ 50 | --learning_rate=5.0e-04 \ 51 | --scale_lr \ 52 | --lr_scheduler="constant" \ 53 | --lr_warmup_steps=0 \ 54 | --save_as_full_pipeline \ 55 | --output_dir=outputs/$EMOTION-promptslider/ \ 56 | --prompts_file="textsliders/data/prompts-$EMOTION.yaml" 57 | ``` 58 | 59 | Alternatively, you can run with the default settings: 60 | 61 | ```bash 62 | bash prompt_slider_emotions.sh 63 | ``` 64 | 65 | A full training run takes ~1-2 hours on one A10 GPU. 66 | 67 | ### Inference 68 | 69 | If running the code raises `TypeError: unsupported operand type(s) for +: 'int' and 'NoneType'`, install an earlier version of diffusers: 70 | 71 | ```bash 72 | pip install diffusers==0.20.2 73 | pip install huggingface-hub==0.23.2 74 | ``` 75 | 76 | Once you have trained a model using the above command, inference can be done with the `StableDiffusionPipeline` or `StableDiffusionXLPipeline` using the following script. Make sure to change the concept name to your concept and that the learned embedding is available at `output/age-slider_prompt/learned_embeds.safetensors`. A minimal sketch of how the slider scale is applied at inference time is included at the end of this README. 77 | 78 | ```bash 79 | python inference-promptsliders-sdxl.py age 80 | ``` 81 | 82 | To run inference with SD at the default scale: 83 | 84 | ```bash 85 | python inference_sd.py $path_to_the_saved_embedding $token_name 86 | ``` 87 | ## Acknowledgements 88 | 89 | Thanks to [diffusers](https://github.com/huggingface/diffusers) and [Concept Sliders](https://github.com/rohitgandikota/sliders)!
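### Minimal inference sketch

The bundled inference scripts apply the slider by scaling the text-encoder output at the slider token's position during denoising (after an initial unscaled phase controlled by `start_noise`). The sketch below illustrates the same idea in its simplest form by scaling the learned token's input embedding instead; the token name, embedding path, and device are assumptions — adjust them to match your training run.

```python
# Illustrative sketch only: the token name, embedding path, and device are assumptions.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", safety_checker=None
).to("cuda")

# Load the trained slider embedding as a textual-inversion token.
token = "<smiling-lora>"
pipe.load_textual_inversion(
    "outputs/smiling-promptslider/learned_embeds.safetensors", token=token
)

# Approximate the slider by scaling the learned token's input embedding.
# (inference-promptsliders-sd.py / -sdxl.py instead scale the token's
# text-encoder output at each denoising step, which preserves structure better.)
token_id = pipe.tokenizer.convert_tokens_to_ids(token)
embeds = pipe.text_encoder.get_input_embeddings().weight
base = embeds.data[token_id].clone()

for scale in [0.0, 0.5, 1.0]:
    embeds.data[token_id] = scale * base
    # Fix the seed so images at different scales are directly comparable.
    generator = torch.Generator("cuda").manual_seed(0)
    image = pipe(
        f"a photo of a man, {token}", num_inference_steps=50, generator=generator
    ).images[0]
    image.save(f"slider_scale_{scale}.png")
```

For the full method, including the timestep-dependent switch via `start_noise`, see `inference-promptsliders-sd.py` and `inference-promptsliders-sdxl.py`.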
90 | -------------------------------------------------------------------------------- /inference-promptsliders-sd.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | #Author: Deepak Sridhar 4 | 5 | 6 | import torch 7 | from PIL import Image 8 | import argparse 9 | import os, json, random, sys 10 | import pandas as pd 11 | import matplotlib.pyplot as plt 12 | import glob, re 13 | import warnings 14 | warnings.filterwarnings("ignore") 15 | 16 | 17 | from tqdm import tqdm 18 | import numpy as np 19 | 20 | from safetensors.torch import load_file 21 | import matplotlib.image as mpimg 22 | import copy 23 | import gc 24 | from transformers import CLIPTextModel, CLIPTokenizer 25 | 26 | import diffusers 27 | from diffusers import DiffusionPipeline 28 | from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel, LMSDiscreteScheduler 29 | from diffusers.loaders import AttnProcsLayers 30 | from diffusers.models.attention_processor import LoRAAttnProcessor, AttentionProcessor 31 | from typing import Any, Dict, List, Optional, Tuple, Union 32 | 33 | import safetensors.torch 34 | 35 | def flush(): 36 | torch.cuda.empty_cache() 37 | gc.collect() 38 | flush() 39 | pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5" 40 | 41 | revision = None 42 | device = 'cuda:0' 43 | concept = sys.argv[1] 44 | mconcept = "iid-1" 45 | weight_dtype = torch.float32 46 | 47 | # Load scheduler, tokenizer and models. 48 | noise_scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000) 49 | tokenizer = CLIPTokenizer.from_pretrained( 50 | pretrained_model_name_or_path, subfolder="tokenizer", revision=revision 51 | ) 52 | text_encoder = CLIPTextModel.from_pretrained( 53 | pretrained_model_name_or_path, subfolder="text_encoder", revision=revision 54 | ) 55 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae", revision=revision) 56 | unet = UNet2DConditionModel.from_pretrained( 57 | pretrained_model_name_or_path, subfolder="unet", revision=revision 58 | ) 59 | 60 | new_token = f"<{mconcept}>" 61 | learned_embeds_path = f"output/{concept}-slider_prompt/learned_embeds.safetensors" 62 | loaded_embeds = safetensors.torch.load_file(learned_embeds_path) 63 | 64 | # Check if the token already exists in the vocabulary 65 | if new_token not in tokenizer.get_vocab(): 66 | # Add the token to the tokenizer 67 | tokenizer.add_tokens([new_token]) 68 | new_token_id = tokenizer.convert_tokens_to_ids(new_token) 69 | # Resize the model’s token embeddings to accommodate the new token 70 | text_encoder.resize_token_embeddings(len(tokenizer)) 71 | 72 | keyy = list(loaded_embeds.keys())[0] 73 | new_token_embed = loaded_embeds[keyy] 74 | 75 | with torch.no_grad(): 76 | text_encoder.get_input_embeddings().weight.data[new_token_id] = new_token_embed.clone() 77 | 78 | 79 | # freeze parameters of models to save more memory 80 | unet.requires_grad_(False) 81 | unet.to(device, dtype=weight_dtype) 82 | vae.requires_grad_(False) 83 | vae.to(device, dtype=weight_dtype) 84 | text_encoder.requires_grad_(False) 85 | text_encoder.to(device, dtype=weight_dtype) 86 | 87 | # prompts to try 88 | prompts = [ 89 | f"a photo of a man, <{mconcept}>", 90 | ] 91 | # scale to test 92 | scales = [0, 0.5, 1.0, 1.25] 93 | 94 | # timestep during inference when we switch to scale>0 (this is done to ensure structure in the images) 95 | start_noise = 800 96 | 
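# Note: the scheduler's timesteps run from high to low, so in the denoising loop below
# any timestep above `start_noise` belongs to the early, structure-defining phase. There
# the slider token's embedding is zeroed out (when scale > 0), and the requested `scale`
# is applied to the token only once t <= start_noise.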
97 | 98 | #number of images per prompt 99 | num_images_per_prompt = 1 100 | 101 | torch_device = device 102 | negative_prompt = None 103 | batch_size = 1 104 | height = 512 105 | width = 512 106 | ddim_steps = 50 107 | guidance_scale = 7.5 108 | unet = UNet2DConditionModel.from_pretrained( 109 | pretrained_model_name_or_path, subfolder="unet", revision=revision 110 | ) 111 | # freeze parameters of models to save more memory 112 | unet.requires_grad_(False) 113 | unet.to(device, dtype=weight_dtype) 114 | 115 | for prompt in prompts: 116 | # for different seeds on same prompt 117 | for _ in range(num_images_per_prompt): 118 | seed = random.randint(0, 5000) 119 | 120 | 121 | 122 | images_list = [] 123 | 124 | print(prompt, seed) 125 | 126 | for scale in scales: 127 | generator = torch.manual_seed(seed) 128 | text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt") 129 | idx = text_input.input_ids.argmax(-1) 130 | 131 | 132 | 133 | max_length = text_input.input_ids.shape[-1] 134 | batch_indices = torch.arange(len(text_input.input_ids)) 135 | if negative_prompt is None: 136 | uncond_input = tokenizer( 137 | [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" 138 | ) 139 | else: 140 | uncond_input = tokenizer( 141 | [negative_prompt] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" 142 | ) 143 | uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0] 144 | 145 | 146 | 147 | latents = torch.randn( 148 | (batch_size, unet.in_channels, height // 8, width // 8), 149 | generator=generator, 150 | ) 151 | latents = latents.to(torch_device) 152 | 153 | noise_scheduler.set_timesteps(ddim_steps) 154 | 155 | latents = latents * noise_scheduler.init_noise_sigma 156 | latents = latents.to(weight_dtype) 157 | latent_model_input = torch.cat([latents] * 2) 158 | 159 | for t in tqdm(noise_scheduler.timesteps): 160 | text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0] 161 | if t>start_noise and scale > 0.0: 162 | text_embeddings[batch_indices, idx, :] = 0.0 * text_embeddings[batch_indices, idx, :] 163 | else: 164 | text_embeddings[batch_indices, idx, :] = scale * text_embeddings[batch_indices, idx, :] 165 | 166 | concat_text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 167 | # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. 
168 | latent_model_input = torch.cat([latents] * 2) 169 | 170 | latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t) 171 | # predict the noise residual 172 | 173 | with torch.no_grad(): 174 | noise_pred = unet(latent_model_input, t, encoder_hidden_states=concat_text_embeddings).sample 175 | # perform guidance 176 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 177 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 178 | 179 | # compute the previous noisy sample x_t -> x_t-1 180 | latents = noise_scheduler.step(noise_pred, t, latents).prev_sample 181 | 182 | # scale and decode the image latents with vae 183 | latents = 1 / 0.18215 * latents 184 | with torch.no_grad(): 185 | image = vae.decode(latents).sample 186 | image = (image / 2 + 0.5).clamp(0, 1) 187 | image = image.detach().cpu().permute(0, 2, 3, 1).numpy() 188 | images = (image * 255).round().astype("uint8") 189 | pil_images = [Image.fromarray(image) for image in images] 190 | images_list.append(pil_images[0]) 191 | 192 | fig, ax = plt.subplots(1, len(images_list), figsize=(20,4)) 193 | for i, a in enumerate(ax): 194 | a.imshow(images_list[i]) 195 | a.set_title(f"{scales[i]}",fontsize=15) 196 | a.axis('off') 197 | 198 | plt.show() 199 | plt.savefig(f'{prompt}.jpg') 200 | plt.close() 201 | -------------------------------------------------------------------------------- /inference-promptsliders-sdxl.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import torch 5 | # from PIL import Image 6 | # import argparse 7 | import os, json, random, sys 8 | import matplotlib.pyplot as plt 9 | 10 | from safetensors.torch import load_file 11 | import gc 12 | from transformers import CLIPTextModel, CLIPTokenizer 13 | import safetensors.torch 14 | 15 | import diffusers 16 | from typing import Any, Dict, List, Optional, Tuple, Union 17 | from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput 18 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 19 | from diffusers.pipelines import StableDiffusionXLPipeline 20 | import random 21 | 22 | from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer 23 | 24 | 25 | SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection] 26 | 27 | 28 | def text_encode_xl( 29 | text_encoder: SDXL_TEXT_ENCODER_TYPE, 30 | tokens: torch.FloatTensor, 31 | num_images_per_prompt: int = 1, 32 | ): 33 | prompt_embeds = text_encoder( 34 | tokens.to(text_encoder.device), output_hidden_states=True 35 | ) 36 | pooled_prompt_embeds = prompt_embeds[0] 37 | prompt_embeds = prompt_embeds.hidden_states[-2] # always penultimate layer 38 | 39 | bs_embed, seq_len, _ = prompt_embeds.shape 40 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 41 | prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) 42 | 43 | return prompt_embeds, pooled_prompt_embeds 44 | 45 | 46 | def text_tokenize( 47 | tokenizer: CLIPTokenizer, 48 | prompts: List[str], 49 | ): 50 | return tokenizer( 51 | prompts, 52 | padding="max_length", 53 | max_length=tokenizer.model_max_length, 54 | truncation=True, 55 | return_tensors="pt", 56 | ).input_ids 57 | 58 | 59 | def encode_prompts_xl( 60 | tokenizers: List[CLIPTokenizer], 61 | text_encoders: List[SDXL_TEXT_ENCODER_TYPE], 62 | prompts: List[str], 63 | num_images_per_prompt: int = 1, 64 | sc: float = 1.0, 65 | ) -> 
Tuple[torch.FloatTensor, torch.FloatTensor]: 66 | # text_encoder and text_encoder_2's penultimate layer's output 67 | text_embeds_list = [] 68 | pooled_text_embeds = None # always text_encoder_2's pool 69 | k = 0 70 | num_k = 1 71 | for tokenizer, text_encoder in zip(tokenizers, text_encoders): 72 | 73 | text_tokens_input_ids = text_tokenize(tokenizer, prompts) 74 | # Get idx of new token 75 | idx = text_tokens_input_ids.argmax(-1) 76 | 77 | text_embeds, pooled_text_embeds = text_encode_xl( 78 | text_encoder, text_tokens_input_ids, num_images_per_prompt 79 | ) 80 | batch_indices = torch.arange(len(text_tokens_input_ids)) 81 | if k == 0: 82 | # Adjust the scale 83 | text_embeds[batch_indices, idx, :] = sc * text_embeds[batch_indices, idx, :] 84 | 85 | text_embeds_list.append(text_embeds) 86 | k += 1 87 | 88 | bs_embed = pooled_text_embeds.shape[0] 89 | pooled_text_embeds = pooled_text_embeds.repeat(1, num_images_per_prompt).view( 90 | bs_embed * num_images_per_prompt, -1 91 | ) 92 | 93 | return torch.concat(text_embeds_list, dim=-1), pooled_text_embeds.view(1, -1) 94 | 95 | 96 | def flush(): 97 | torch.cuda.empty_cache() 98 | gc.collect() 99 | 100 | @torch.no_grad() 101 | def call( 102 | self, 103 | prompt: Union[str, List[str]] = None, 104 | prompt_2: Optional[Union[str, List[str]]] = None, 105 | height: Optional[int] = None, 106 | width: Optional[int] = None, 107 | num_inference_steps: int = 50, 108 | denoising_end: Optional[float] = None, 109 | guidance_scale: float = 5.0, 110 | negative_prompt: Optional[Union[str, List[str]]] = None, 111 | negative_prompt_2: Optional[Union[str, List[str]]] = None, 112 | num_images_per_prompt: Optional[int] = 1, 113 | eta: float = 0.0, 114 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 115 | latents: Optional[torch.FloatTensor] = None, 116 | prompt_embeds: Optional[torch.FloatTensor] = None, 117 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 118 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 119 | negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 120 | output_type: Optional[str] = "pil", 121 | return_dict: bool = True, 122 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 123 | callback_steps: int = 1, 124 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 125 | guidance_rescale: float = 0.0, 126 | original_size: Optional[Tuple[int, int]] = None, 127 | crops_coords_top_left: Tuple[int, int] = (0, 0), 128 | target_size: Optional[Tuple[int, int]] = None, 129 | negative_original_size: Optional[Tuple[int, int]] = None, 130 | negative_crops_coords_top_left: Tuple[int, int] = (0, 0), 131 | negative_target_size: Optional[Tuple[int, int]] = None, 132 | pipe=None, 133 | network=None, 134 | start_noise=None, 135 | scale=None, 136 | unet=None, 137 | ): 138 | r""" 139 | Function invoked when calling the pipeline for generation. 140 | 141 | Args: 142 | prompt (`str` or `List[str]`, *optional*): 143 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. 144 | instead. 145 | prompt_2 (`str` or `List[str]`, *optional*): 146 | The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is 147 | used in both text-encoders 148 | height (`int`, *optional*, defaults to unet.config.sample_size * self.vae_scale_factor): 149 | The height in pixels of the generated image. This is set to 1024 by default for the best results. 
150 | Anything below 512 pixels won't work well for 151 | [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) 152 | and checkpoints that are not specifically fine-tuned on low resolutions. 153 | width (`int`, *optional*, defaults to unet.config.sample_size * self.vae_scale_factor): 154 | The width in pixels of the generated image. This is set to 1024 by default for the best results. 155 | Anything below 512 pixels won't work well for 156 | [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) 157 | and checkpoints that are not specifically fine-tuned on low resolutions. 158 | num_inference_steps (`int`, *optional*, defaults to 50): 159 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 160 | expense of slower inference. 161 | denoising_end (`float`, *optional*): 162 | When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be 163 | completed before it is intentionally prematurely terminated. As a result, the returned sample will 164 | still retain a substantial amount of noise as determined by the discrete timesteps selected by the 165 | scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a 166 | "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image 167 | Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) 168 | guidance_scale (`float`, *optional*, defaults to 5.0): 169 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 170 | `guidance_scale` is defined as `w` of equation 2. of [Imagen 171 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 172 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 173 | usually at the expense of lower image quality. 174 | negative_prompt (`str` or `List[str]`, *optional*): 175 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 176 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is 177 | less than `1`). 178 | negative_prompt_2 (`str` or `List[str]`, *optional*): 179 | The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and 180 | `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders 181 | num_images_per_prompt (`int`, *optional*, defaults to 1): 182 | The number of images to generate per prompt. 183 | eta (`float`, *optional*, defaults to 0.0): 184 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to 185 | [`schedulers.DDIMScheduler`], will be ignored for others. 186 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 187 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) 188 | to make generation deterministic. 189 | latents (`torch.FloatTensor`, *optional*): 190 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image 191 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 192 | tensor will ge generated by sampling using the supplied random `generator`. 
193 | prompt_embeds (`torch.FloatTensor`, *optional*): 194 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 195 | provided, text embeddings will be generated from `prompt` input argument. 196 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 197 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 198 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 199 | argument. 200 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*): 201 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 202 | If not provided, pooled text embeddings will be generated from `prompt` input argument. 203 | negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): 204 | Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 205 | weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` 206 | input argument. 207 | output_type (`str`, *optional*, defaults to `"pil"`): 208 | The output format of the generate image. Choose between 209 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 210 | return_dict (`bool`, *optional*, defaults to `True`): 211 | Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead 212 | of a plain tuple. 213 | callback (`Callable`, *optional*): 214 | A function that will be called every `callback_steps` steps during inference. The function will be 215 | called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. 216 | callback_steps (`int`, *optional*, defaults to 1): 217 | The frequency at which the `callback` function will be called. If not specified, the callback will be 218 | called at every step. 219 | cross_attention_kwargs (`dict`, *optional*): 220 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 221 | `self.processor` in 222 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 223 | guidance_rescale (`float`, *optional*, defaults to 0.7): 224 | Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are 225 | Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of 226 | [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 227 | Guidance rescale factor should fix overexposure when using zero terminal SNR. 228 | original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): 229 | If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. 230 | `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as 231 | explained in section 2.2 of 232 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 233 | crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): 234 | `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position 235 | `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting 236 | `crops_coords_top_left` to (0, 0). 
Part of SDXL's micro-conditioning as explained in section 2.2 of 237 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 238 | target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): 239 | For most cases, `target_size` should be set to the desired height and width of the generated image. If 240 | not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in 241 | section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 242 | negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): 243 | To negatively condition the generation process based on a specific image resolution. Part of SDXL's 244 | micro-conditioning as explained in section 2.2 of 245 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more 246 | information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. 247 | negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): 248 | To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's 249 | micro-conditioning as explained in section 2.2 of 250 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more 251 | information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. 252 | negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): 253 | To negatively condition the generation process based on a target image resolution. It should be as same 254 | as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of 255 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more 256 | information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. 257 | 258 | Examples: 259 | 260 | Returns: 261 | [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: 262 | [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a 263 | `tuple`. When returning a tuple, the first element is a list with the generated images. 264 | """ 265 | # 0. Default height and width to unet 266 | height = height or self.default_sample_size * self.vae_scale_factor 267 | width = width or self.default_sample_size * self.vae_scale_factor 268 | 269 | original_size = original_size or (height, width) 270 | target_size = target_size or (height, width) 271 | 272 | # 1. Check inputs. Raise error if not correct 273 | self.check_inputs( 274 | prompt, 275 | prompt_2, 276 | height, 277 | width, 278 | callback_steps, 279 | negative_prompt, 280 | negative_prompt_2, 281 | prompt_embeds, 282 | negative_prompt_embeds, 283 | pooled_prompt_embeds, 284 | negative_pooled_prompt_embeds, 285 | ) 286 | 287 | # 2. Define call parameters 288 | if prompt is not None and isinstance(prompt, str): 289 | batch_size = 1 290 | elif prompt is not None and isinstance(prompt, list): 291 | batch_size = len(prompt) 292 | else: 293 | batch_size = prompt_embeds.shape[0] 294 | 295 | device = self._execution_device 296 | 297 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 298 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 299 | # corresponds to doing no classifier free guidance. 300 | do_classifier_free_guidance = guidance_scale > 1.0 301 | 302 | # 3. 
Encode input prompt 303 | text_encoder_lora_scale = ( 304 | cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None 305 | ) 306 | ( 307 | embs, 308 | negative_prompt_embeds, 309 | pool_embs, 310 | negative_pooled_prompt_embeds, 311 | ) = self.encode_prompt( 312 | prompt=prompt, 313 | prompt_2=prompt_2, 314 | device=device, 315 | num_images_per_prompt=num_images_per_prompt, 316 | do_classifier_free_guidance=do_classifier_free_guidance, 317 | negative_prompt=negative_prompt, 318 | negative_prompt_2=negative_prompt_2, 319 | prompt_embeds=prompt_embeds, 320 | negative_prompt_embeds=negative_prompt_embeds, 321 | pooled_prompt_embeds=pooled_prompt_embeds, 322 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 323 | lora_scale=text_encoder_lora_scale, 324 | ) 325 | 326 | prompt_embeds, pooled_prompt_embeds = encode_prompts_xl([pipe.tokenizer, pipe.tokenizer_2], [pipe.text_encoder, pipe.text_encoder_2], [prompt], sc=scale) 327 | 328 | prompt_embeds2, pooled_prompt_embeds2 = encode_prompts_xl([pipe.tokenizer, pipe.tokenizer_2], [pipe.text_encoder, pipe.text_encoder_2], [prompt], sc=0.0) 329 | 330 | # 4. Prepare timesteps 331 | self.scheduler.set_timesteps(num_inference_steps, device=device) 332 | 333 | timesteps = self.scheduler.timesteps 334 | 335 | # 5. Prepare latent variables 336 | num_channels_latents = unet.config.in_channels 337 | latents = self.prepare_latents( 338 | batch_size * num_images_per_prompt, 339 | num_channels_latents, 340 | height, 341 | width, 342 | prompt_embeds.dtype, 343 | device, 344 | generator, 345 | latents, 346 | ) 347 | 348 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline 349 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 350 | 351 | # 7. Prepare added time ids & embeddings 352 | add_text_embeds = pooled_prompt_embeds 353 | add_time_ids = self._get_add_time_ids( 354 | original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype 355 | ) 356 | if negative_original_size is not None and negative_target_size is not None: 357 | negative_add_time_ids = self._get_add_time_ids( 358 | negative_original_size, 359 | negative_crops_coords_top_left, 360 | negative_target_size, 361 | dtype=prompt_embeds.dtype, 362 | ) 363 | else: 364 | negative_add_time_ids = add_time_ids 365 | 366 | if do_classifier_free_guidance: 367 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) 368 | prompt_embeds2 = torch.cat([negative_prompt_embeds, prompt_embeds2], dim=0) 369 | add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) 370 | add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) 371 | 372 | prompt_embeds = prompt_embeds.to(device) 373 | prompt_embeds2 = prompt_embeds2.to(device) 374 | add_text_embeds = add_text_embeds.to(device) 375 | add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) 376 | 377 | # 8. 
Denoising loop 378 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) 379 | 380 | # 7.1 Apply denoising_end 381 | if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: 382 | discrete_timestep_cutoff = int( 383 | round( 384 | self.scheduler.config.num_train_timesteps 385 | - (denoising_end * self.scheduler.config.num_train_timesteps) 386 | ) 387 | ) 388 | num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) 389 | timesteps = timesteps[:num_inference_steps] 390 | latents = latents.to(unet.dtype) 391 | with self.progress_bar(total=num_inference_steps) as progress_bar: 392 | for i, t in enumerate(timesteps): 393 | if t>start_noise and scale > 0.0: 394 | prompt_embed = prompt_embeds2 395 | else: 396 | prompt_embed = prompt_embeds 397 | # expand the latents if we are doing classifier free guidance 398 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 399 | 400 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 401 | 402 | # predict the noise residual 403 | added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} 404 | 405 | noise_pred = unet( 406 | latent_model_input, 407 | t, 408 | encoder_hidden_states=prompt_embed, 409 | cross_attention_kwargs=cross_attention_kwargs, 410 | added_cond_kwargs=added_cond_kwargs, 411 | return_dict=False, 412 | )[0] 413 | 414 | # perform guidance 415 | if do_classifier_free_guidance: 416 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 417 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 418 | 419 | if do_classifier_free_guidance and guidance_rescale > 0.0: 420 | # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf 421 | noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) 422 | 423 | # compute the previous noisy sample x_t -> x_t-1 424 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] 425 | 426 | # call the callback, if provided 427 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 428 | progress_bar.update() 429 | if callback is not None and i % callback_steps == 0: 430 | callback(i, t, latents) 431 | 432 | if not output_type == "latent": 433 | # make sure the VAE is in float32 mode, as it overflows in float16 434 | needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast 435 | 436 | if needs_upcasting: 437 | self.upcast_vae() 438 | latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) 439 | 440 | image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] 441 | 442 | # cast back to fp16 if needed 443 | if needs_upcasting: 444 | self.vae.to(dtype=torch.float16) 445 | else: 446 | image = latents 447 | 448 | if not output_type == "latent": 449 | # apply watermark if available 450 | if self.watermark is not None: 451 | image = self.watermark.apply_watermark(image) 452 | 453 | image = self.image_processor.postprocess(image, output_type=output_type) 454 | 455 | if not return_dict: 456 | return (image,) 457 | 458 | return StableDiffusionXLPipelineOutput(images=image) 459 | 460 | concept = sys.argv[1] 461 | cname = "iid-1" 462 | 463 | new_token = f"<{cname}>" 464 | learned_embeds_path = f"output/{concept}-slider_prompt/learned_embeds.safetensors" #path to the trained sliders 465 | device = 'cuda:0' 466 | dtype = torch.bfloat16 467 | StableDiffusionXLPipeline.__call__ = call 468 | pipe = StableDiffusionXLPipeline.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=dtype) 469 | 470 | # pipe.__call__ = call 471 | pipe = pipe.to(device) 472 | loaded_embeds = safetensors.torch.load_file(learned_embeds_path) 473 | 474 | # Check if the token already exists in the vocabulary 475 | if new_token not in pipe.tokenizer.get_vocab(): 476 | # Add the token to the tokenizer 477 | pipe.tokenizer.add_tokens([new_token]) 478 | new_token_id = pipe.tokenizer.convert_tokens_to_ids(new_token) 479 | # Resize the model’s token embeddings to accommodate the new token 480 | pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer)) 481 | 482 | keyy = list(loaded_embeds.keys())[0] 483 | new_token_embed = loaded_embeds[keyy] 484 | 485 | with torch.no_grad(): 486 | pipe.text_encoder.get_input_embeddings().weight.data[new_token_id] = new_token_embed.clone() 487 | 488 | prompts = [ 489 | # f'Image of a person, realistic, {new_token}, 8k', 490 | f'A photo of a girl, {new_token}', 491 | # 'A realistic photograph of a person, bokeh, blurred background, {new_token}, 8k', 492 | f'Professional headshot of a person, {new_token}', 493 | ] 494 | 495 | 496 | start_noise = 800 497 | num_images_per_prompt = 1 498 | scales = [0, 0.5, 1, 1.5, 2] 499 | for prompt in prompts: 500 | for _ in range(num_images_per_prompt): 501 | seed = random.randint(0,2**15) #22373 502 | print(prompt, seed) 503 | 504 | image_list = [] 505 | for scale in scales: 506 | generator = torch.manual_seed(seed) 507 | images = pipe(prompt, target_size=(512, 512), num_images_per_prompt=1, num_inference_steps=50, generator=generator, start_noise=start_noise, scale=scale, unet=pipe.unet, 
pipe=pipe).images[0] 508 | image_list.append(images) 509 | 510 | fig, ax = plt.subplots(1, len(image_list), figsize=(20,4)) 511 | for i, a in enumerate(ax): 512 | a.imshow(image_list[i]) 513 | a.set_title(f"{scales[i]}",fontsize=15) 514 | a.axis('off') 515 | 516 | plt.suptitle(f'{concept}', fontsize=20) 517 | # plt.tight_layout() 518 | plt.show() 519 | plt.savefig(f'sdxl-{prompt}-{seed}.png') 520 | 521 | plt.close() 522 | -------------------------------------------------------------------------------- /inference_sd.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | import torch 5 | import os 6 | import sys 7 | import numpy as np 8 | from PIL import Image 9 | from tqdm import tqdm 10 | import requests 11 | from typing import Union 12 | import PIL 13 | from diffusers import StableDiffusionPipeline 14 | 15 | model = 'runwayml/stable-diffusion-v1-5' 16 | 17 | device = 'cuda' 18 | 19 | pipe = StableDiffusionPipeline.from_pretrained(model,safety_checker = None,) 20 | 21 | concept = sys.argv[1] 22 | token = sys.argv[2] 23 | textual_inversion_embeds_path = f"{concept}" 24 | pipe.load_textual_inversion(textual_inversion_embeds_path, token=f"{token}") 25 | pipe.to(device) 26 | image = pipe(f"An image of a male, {token}", num_inference_steps=50).images[0] 27 | token_name = token.replace('<','').replace('>','') 28 | image.save(f"outputs/{token_name}-promptslider/output/ti_{token_name}.png") 29 | -------------------------------------------------------------------------------- /prompt_slider_emotions.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Define the model name and data directory as environment variables 4 | export MODEL_NAME="runwayml/stable-diffusion-v1-5" 5 | 6 | # List of emotions 7 | # emotions=("sad" "disgusted" "confused" "fear" "surprised" "angry") 8 | emotions=("smiling" "surprised") 9 | 10 | # Loop through each emotion and run the script 11 | for EMOTION in "${emotions[@]}"; do 12 | echo "Running script for emotion: $EMOTION" 13 | 14 | # Run the script with the current emotion 15 | accelerate launch textual_inversion.py \ 16 | --pretrained_model_name_or_path=$MODEL_NAME \ 17 | --learnable_property="object" \ 18 | --placeholder_token="<$EMOTION-lora>" \ 19 | --initializer_token="$EMOTION" \ 20 | --mixed_precision="no" \ 21 | --resolution=512 \ 22 | --train_batch_size=1 \ 23 | --gradient_accumulation_steps=1 \ 24 | --max_train_steps=2000 \ 25 | --learning_rate=5.0e-04 \ 26 | --scale_lr \ 27 | --lr_scheduler="constant" \ 28 | --lr_warmup_steps=0 \ 29 | --save_as_full_pipeline \ 30 | --output_dir=outputs/$EMOTION-promptslider/ \ 31 | --prompts_file="textsliders/data/prompts-$EMOTION.yaml" 32 | 33 | echo "Completed: $EMOTION" 34 | echo "-----------------------------------" 35 | done 36 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate>=0.16.0 2 | torchvision 3 | transformers>=4.25.1 4 | ftfy 5 | tensorboard 6 | Jinja2 7 | bitsandbytes==0.41.1 8 | dadaptation==3.1 9 | ipython==8.7.0 10 | lion_pytorch==0.1.2 11 | lpips==0.1.4 12 | matplotlib==3.6.2 13 | numpy==1.23.5 14 | opencv_python==4.5.5.64 15 | opencv_python_headless==4.7.0.68 16 | pandas==1.5.2 17 | Pillow==10.1.0 18 | prodigyopt==1.0 19 | pydantic==2.6.3 20 | PyYAML==6.0.1 21 | Requests==2.31.0 22 | safetensors==0.3.1 23 | torch==2.0.1 24 | torchvision==0.15.2 25 
| tqdm==4.64.1 26 | transformers==4.27.4 27 | wandb==0.12.21 28 | xformers==0.0.21 -------------------------------------------------------------------------------- /textsliders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepakSridhar/promptsliders/42ded17000fea65ff3573c1caac41dc0bad0b8e3/textsliders/__init__.py -------------------------------------------------------------------------------- /textsliders/data/prompts-anime.yaml: -------------------------------------------------------------------------------- 1 | - target: "male person" # what word for erasing the positive concept from 2 | positive: "male person, anime style" # concept to erase 3 | unconditional: "male person, realistic, natural" # word to take the difference from the positive concept 4 | neutral: "male person" # starting point for conditioning the target 5 | action: "enhance" # erase or enhance 6 | guidance_scale: 4 7 | resolution: 512 8 | dynamic_resolution: false 9 | batch_size: 1 10 | - target: "female person" # what word for erasing the positive concept from 11 | positive: "female person, anime style" # concept to erase 12 | unconditional: "female person, realistic, natural" # word to take the difference from the positive concept 13 | neutral: "female person" # starting point for conditioning the target 14 | action: "enhance" # erase or enhance 15 | guidance_scale: 4 16 | resolution: 512 17 | dynamic_resolution: false 18 | batch_size: 1 19 | ####################################################################################################### -------------------------------------------------------------------------------- /textsliders/data/prompts-bald.yaml: -------------------------------------------------------------------------------- 1 | - target: "male person" # what word for erasing the positive concept from 2 | positive: "male person, bald" # concept to erase 3 | unconditional: "male person, realistic, natural" # word to take the difference from the positive concept 4 | neutral: "male person" # starting point for conditioning the target 5 | action: "enhance" # erase or enhance 6 | guidance_scale: 4 7 | resolution: 512 8 | dynamic_resolution: false 9 | batch_size: 1 10 | - target: "female person" # what word for erasing the positive concept from 11 | positive: "female person, bald" # concept to erase 12 | unconditional: "female person, realistic, natural" # word to take the difference from the positive concept 13 | neutral: "female person" # starting point for conditioning the target 14 | action: "enhance" # erase or enhance 15 | guidance_scale: 4 16 | resolution: 512 17 | dynamic_resolution: false 18 | batch_size: 1 19 | ####################################################################################################### -------------------------------------------------------------------------------- /textsliders/data/prompts-glasses.yaml: -------------------------------------------------------------------------------- 1 | - target: "male person" # what word for erasing the positive concept from 2 | positive: "male person, wearing sunglasses" # concept to erase 3 | unconditional: "male person" # word to take the difference from the positive concept 4 | neutral: "male person" # starting point for conditioning the target 5 | action: "enhance" # erase or enhance 6 | guidance_scale: 4 7 | resolution: 512 8 | dynamic_resolution: false 9 | batch_size: 1 10 | - target: "female person" # what word for erasing the positive concept from 
11 | positive: "female person, wearing sunglasses" # concept to erase 12 | unconditional: "female person" # word to take the difference from the positive concept 13 | neutral: "female person" # starting point for conditioning the target 14 | action: "enhance" # erase or enhance 15 | guidance_scale: 4 16 | resolution: 512 17 | dynamic_resolution: false 18 | batch_size: 1 19 | ####################################################################################################### -------------------------------------------------------------------------------- /textsliders/data/prompts-hat.yaml: -------------------------------------------------------------------------------- 1 | - target: "male person" # what word for erasing the positive concept from 2 | positive: "male person, wearing hat" # concept to erase 3 | unconditional: "male person" # word to take the difference from the positive concept 4 | neutral: "male person" # starting point for conditioning the target 5 | action: "enhance" # erase or enhance 6 | guidance_scale: 4 7 | resolution: 512 8 | dynamic_resolution: false 9 | batch_size: 1 10 | - target: "female person" # what word for erasing the positive concept from 11 | positive: "female person, wearing hat" # concept to erase 12 | unconditional: "female person" # word to take the difference from the positive concept 13 | neutral: "female person" # starting point for conditioning the target 14 | action: "enhance" # erase or enhance 15 | guidance_scale: 4 16 | resolution: 512 17 | dynamic_resolution: false 18 | batch_size: 1 19 | ####################################################################################################### -------------------------------------------------------------------------------- /textsliders/data/prompts-smiling.yaml: -------------------------------------------------------------------------------- 1 | - target: male white person 2 | positive: male white person, smiling, happy face, big smile 3 | unconditional: male white person, frowning, grumpy, sad 4 | neutral: male white person 5 | guidance: 4 6 | rank: 4 7 | action: enhance 8 | resolution: 512 9 | dynamic_resolution: false 10 | batch_size: 1 11 | - target: male black person 12 | positive: male black person, smiling, happy face, big smile 13 | unconditional: male black person, frowning, grumpy, sad 14 | neutral: male black person 15 | guidance: 4 16 | rank: 4 17 | action: enhance 18 | resolution: 512 19 | dynamic_resolution: false 20 | batch_size: 1 21 | - target: male indian person 22 | positive: male indian person, smiling, happy face, big smile 23 | unconditional: male indian person, frowning, grumpy, sad 24 | neutral: male indian person 25 | guidance: 4 26 | rank: 4 27 | action: enhance 28 | resolution: 512 29 | dynamic_resolution: false 30 | batch_size: 1 31 | - target: male asian person 32 | positive: male asian person, smiling, happy face, big smile 33 | unconditional: male asian person, frowning, grumpy, sad 34 | neutral: male asian person 35 | guidance: 4 36 | rank: 4 37 | action: enhance 38 | resolution: 512 39 | dynamic_resolution: false 40 | batch_size: 1 41 | - target: male hispanic person 42 | positive: male hispanic person, smiling, happy face, big smile 43 | unconditional: male hispanic person, frowning, grumpy, sad 44 | neutral: male hispanic person 45 | guidance: 4 46 | rank: 4 47 | action: enhance 48 | resolution: 512 49 | dynamic_resolution: false 50 | batch_size: 1 51 | - target: female white person 52 | positive: female white person, smiling, happy face, big smile 53 | 
unconditional: female white person, frowning, grumpy, sad 54 | neutral: female white person 55 | guidance: 4 56 | rank: 4 57 | action: enhance 58 | resolution: 512 59 | dynamic_resolution: false 60 | batch_size: 1 61 | - target: female black person 62 | positive: female black person, smiling, happy face, big smile 63 | unconditional: female black person, frowning, grumpy, sad 64 | neutral: female black person 65 | guidance: 4 66 | rank: 4 67 | action: enhance 68 | resolution: 512 69 | dynamic_resolution: false 70 | batch_size: 1 71 | - target: female indian person 72 | positive: female indian person, smiling, happy face, big smile 73 | unconditional: female indian person, frowning, grumpy, sad 74 | neutral: female indian person 75 | guidance: 4 76 | rank: 4 77 | action: enhance 78 | resolution: 512 79 | dynamic_resolution: false 80 | batch_size: 1 81 | - target: female asian person 82 | positive: female asian person, smiling, happy face, big smile 83 | unconditional: female asian person, frowning, grumpy, sad 84 | neutral: female asian person 85 | guidance: 4 86 | rank: 4 87 | action: enhance 88 | resolution: 512 89 | dynamic_resolution: false 90 | batch_size: 1 91 | - target: female hispanic person 92 | positive: female hispanic person, smiling, happy face, big smile 93 | unconditional: female hispanic person, frowning, grumpy, sad 94 | neutral: female hispanic person 95 | guidance: 4 96 | rank: 4 97 | action: enhance 98 | resolution: 512 99 | dynamic_resolution: false 100 | batch_size: 1 101 | -------------------------------------------------------------------------------- /textsliders/data/prompts-surprised.yaml: -------------------------------------------------------------------------------- 1 | - target: male white person 2 | positive: male white person, surprised 3 | unconditional: male white person, calm, relaxed face, neutral expression, at ease, composed 4 | neutral: male white person 5 | guidance: 4 6 | rank: 4 7 | action: enhance 8 | resolution: 512 9 | dynamic_resolution: false 10 | batch_size: 1 11 | - target: male black person 12 | positive: male black person, surprised 13 | unconditional: male black person, calm, relaxed face, neutral expression, at ease, composed 14 | neutral: male black person 15 | guidance: 4 16 | rank: 4 17 | action: enhance 18 | resolution: 512 19 | dynamic_resolution: false 20 | batch_size: 1 21 | - target: male indian person 22 | positive: male indian person, surprised 23 | unconditional: male indian person, calm, relaxed face, neutral expression, at ease, composed 24 | neutral: male indian person 25 | guidance: 4 26 | rank: 4 27 | action: enhance 28 | resolution: 512 29 | dynamic_resolution: false 30 | batch_size: 1 31 | - target: male asian person 32 | positive: male asian person, surprised 33 | unconditional: male asian person, calm, relaxed face, neutral expression, at ease, composed 34 | neutral: male asian person 35 | guidance: 4 36 | rank: 4 37 | action: enhance 38 | resolution: 512 39 | dynamic_resolution: false 40 | batch_size: 1 41 | - target: male hispanic person 42 | positive: male hispanic person, surprised 43 | unconditional: male hispanic person, calm, relaxed face, neutral expression, at ease, composed 44 | neutral: male hispanic person 45 | guidance: 4 46 | rank: 4 47 | action: enhance 48 | resolution: 512 49 | dynamic_resolution: false 50 | batch_size: 1 51 | - target: female white person 52 | positive: female white person, surprised 53 | unconditional: female white person, calm, relaxed face, neutral expression, at ease, composed 54 
| neutral: female white person 55 | guidance: 4 56 | rank: 4 57 | action: enhance 58 | resolution: 512 59 | dynamic_resolution: false 60 | batch_size: 1 61 | - target: female black person 62 | positive: female black person, surprised 63 | unconditional: female black person, calm, relaxed face, neutral expression, at ease, composed 64 | neutral: female black person 65 | guidance: 4 66 | rank: 4 67 | action: enhance 68 | resolution: 512 69 | dynamic_resolution: false 70 | batch_size: 1 71 | - target: female indian person 72 | positive: female indian person, surprised 73 | unconditional: female indian person, calm, relaxed face, neutral expression, at ease, composed 74 | neutral: female indian person 75 | guidance: 4 76 | rank: 4 77 | action: enhance 78 | resolution: 512 79 | dynamic_resolution: false 80 | batch_size: 1 81 | - target: female asian person 82 | positive: female asian person, surprised 83 | unconditional: female asian person, calm, relaxed face, neutral expression, at ease, composed 84 | neutral: female asian person 85 | guidance: 4 86 | rank: 4 87 | action: enhance 88 | resolution: 512 89 | dynamic_resolution: false 90 | batch_size: 1 91 | - target: female hispanic person 92 | positive: female hispanic person, surprised 93 | unconditional: female hispanic person, calm, relaxed face, neutral expression, at ease, composed 94 | neutral: female hispanic person 95 | guidance: 4 96 | rank: 4 97 | action: enhance 98 | resolution: 512 99 | dynamic_resolution: false 100 | batch_size: 1 101 | # wide eyes, raised eyebrows, open mouth, -------------------------------------------------------------------------------- /textsliders/data/prompts-xl.yaml: -------------------------------------------------------------------------------- 1 | # ####################################################################################################### AGE SLIDER 2 | # - target: "male person" # what word for erasing the positive concept from 3 | # positive: "male person, very old" # concept to erase 4 | # unconditional: "male person, very young" # word to take the difference from the positive concept 5 | # neutral: "male person" # starting point for conditioning the target 6 | # action: "enhance" # erase or enhance 7 | # guidance_scale: 4 8 | # resolution: 768 9 | # dynamic_resolution: false 10 | # batch_size: 1 11 | # - target: "female person" # what word for erasing the positive concept from 12 | # positive: "female person, very old" # concept to erase 13 | # unconditional: "female person, very young" # word to take the difference from the positive concept 14 | # neutral: "female person" # starting point for conditioning the target 15 | # action: "enhance" # erase or enhance 16 | # guidance_scale: 4 17 | # resolution: 768 18 | # dynamic_resolution: false 19 | # batch_size: 1 20 | ####################################################################################################### AGE SLIDER 21 | - target: "male person" # what word for erasing the positive concept from 22 | positive: "male person, cyborg" # concept to erase 23 | unconditional: "male person, realistic, natural" # word to take the difference from the positive concept 24 | neutral: "male person" # starting point for conditioning the target 25 | action: "enhance" # erase or enhance 26 | guidance_scale: 4 27 | resolution: 768 28 | dynamic_resolution: false 29 | batch_size: 1 30 | - target: "female person" # what word for erasing the positive concept from 31 | positive: "female person, cyborg" # concept to erase 32 | unconditional: "female 
person, realistic, natural" # word to take the difference from the positive concept 33 | neutral: "female person" # starting point for conditioning the target 34 | action: "enhance" # erase or enhance 35 | guidance_scale: 4 36 | resolution: 768 37 | dynamic_resolution: false 38 | batch_size: 1 39 | ####################################################################################################### MUSCULAR SLIDER 40 | # - target: "male person" # what word for erasing the positive concept from 41 | # positive: "male person, muscular, strong, biceps, greek god physique, body builder" # concept to erase 42 | # unconditional: "male person, lean, thin, weak, slender, skinny, scrawny" # word to take the difference from the positive concept 43 | # neutral: "male person" # starting point for conditioning the target 44 | # action: "enhance" # erase or enhance 45 | # guidance_scale: 4 46 | # resolution: 512 47 | # dynamic_resolution: false 48 | # batch_size: 1 49 | # - target: "female person" # what word for erasing the positive concept from 50 | # positive: "female person, muscular, strong, biceps, greek god physique, body builder" # concept to erase 51 | # unconditional: "female person, lean, thin, weak, slender, skinny, scrawny" # word to take the difference from the positive concept 52 | # neutral: "female person" # starting point for conditioning the target 53 | # action: "enhance" # erase or enhance 54 | # guidance_scale: 4 55 | # resolution: 512 56 | # dynamic_resolution: false 57 | # batch_size: 1 58 | ####################################################################################################### CURLY HAIR SLIDER 59 | # - target: "male person" # what word for erasing the positive concept from 60 | # positive: "male person, curly hair, wavy hair" # concept to erase 61 | # unconditional: "male person, straight hair" # word to take the difference from the positive concept 62 | # neutral: "male person" # starting point for conditioning the target 63 | # action: "enhance" # erase or enhance 64 | # guidance_scale: 4 65 | # resolution: 512 66 | # dynamic_resolution: false 67 | # batch_size: 1 68 | # - target: "female person" # what word for erasing the positive concept from 69 | # positive: "female person, curly hair, wavy hair" # concept to erase 70 | # unconditional: "female person, straight hair" # word to take the difference from the positive concept 71 | # neutral: "female person" # starting point for conditioning the target 72 | # action: "enhance" # erase or enhance 73 | # guidance_scale: 4 74 | # resolution: 512 75 | # dynamic_resolution: false 76 | # batch_size: 1 77 | ####################################################################################################### BEARD SLIDER 78 | # - target: "male person" # what word for erasing the positive concept from 79 | # positive: "male person, with beard" # concept to erase 80 | # unconditional: "male person, clean shaven" # word to take the difference from the positive concept 81 | # neutral: "male person" # starting point for conditioning the target 82 | # action: "enhance" # erase or enhance 83 | # guidance_scale: 4 84 | # resolution: 512 85 | # dynamic_resolution: false 86 | # batch_size: 1 87 | # - target: "female person" # what word for erasing the positive concept from 88 | # positive: "female person, with beard, lipstick and feminine" # concept to erase 89 | # unconditional: "female person, clean shaven" # word to take the difference from the positive concept 90 | # neutral: "female person" # starting point for 
conditioning the target 91 | # action: "enhance" # erase or enhance 92 | # guidance_scale: 4 93 | # resolution: 512 94 | # dynamic_resolution: false 95 | # batch_size: 1 96 | ####################################################################################################### MAKEUP SLIDER 97 | # - target: "male person" # what word for erasing the positive concept from 98 | # positive: "male person, with makeup, cosmetic, concealer, mascara" # concept to erase 99 | # unconditional: "male person, barefaced, ugly" # word to take the difference from the positive concept 100 | # neutral: "male person" # starting point for conditioning the target 101 | # action: "enhance" # erase or enhance 102 | # guidance_scale: 4 103 | # resolution: 512 104 | # dynamic_resolution: false 105 | # batch_size: 1 106 | # - target: "female person" # what word for erasing the positive concept from 107 | # positive: "female person, with makeup, cosmetic, concealer, mascara, lipstick" # concept to erase 108 | # unconditional: "female person, barefaced, ugly" # word to take the difference from the positive concept 109 | # neutral: "female person" # starting point for conditioning the target 110 | # action: "enhance" # erase or enhance 111 | # guidance_scale: 4 112 | # resolution: 512 113 | # dynamic_resolution: false 114 | # batch_size: 1 115 | ####################################################################################################### SURPRISED SLIDER 116 | # - target: "male person" # what word for erasing the positive concept from 117 | # positive: "male person, with shocked look, surprised, stunned, amazed" # concept to erase 118 | # unconditional: "male person, dull, uninterested, bored, incurious" # word to take the difference from the positive concept 119 | # neutral: "male person" # starting point for conditioning the target 120 | # action: "enhance" # erase or enhance 121 | # guidance_scale: 4 122 | # resolution: 512 123 | # dynamic_resolution: false 124 | # batch_size: 1 125 | # - target: "female person" # what word for erasing the positive concept from 126 | # positive: "female person, with shocked look, surprised, stunned, amazed" # concept to erase 127 | # unconditional: "female person, dull, uninterested, bored, incurious" # word to take the difference from the positive concept 128 | # neutral: "female person" # starting point for conditioning the target 129 | # action: "enhance" # erase or enhance 130 | # guidance_scale: 4 131 | # resolution: 512 132 | # dynamic_resolution: false 133 | # batch_size: 1 134 | ####################################################################################################### OBESE SLIDER 135 | # - target: "male person" # what word for erasing the positive concept from 136 | # positive: "male person, fat, chubby, overweight, obese" # concept to erase 137 | # unconditional: "male person, lean, fit, slim, slender" # word to take the difference from the positive concept 138 | # neutral: "male person" # starting point for conditioning the target 139 | # action: "enhance" # erase or enhance 140 | # guidance_scale: 4 141 | # resolution: 512 142 | # dynamic_resolution: false 143 | # batch_size: 1 144 | # - target: "female person" # what word for erasing the positive concept from 145 | # positive: "female person, fat, chubby, overweight, obese" # concept to erase 146 | # unconditional: "female person, lean, fit, slim, slender" # word to take the difference from the positive concept 147 | # neutral: "female person" # starting point for conditioning the target 148 | # 
action: "enhance" # erase or enhance 149 | # guidance_scale: 4 150 | # resolution: 512 151 | # dynamic_resolution: false 152 | # batch_size: 1 153 | ####################################################################################################### PROFESSIONAL SLIDER 154 | # - target: "male person" # what word for erasing the positive concept from 155 | # positive: "male person, professionally dressed, stylised hair, clean face" # concept to erase 156 | # unconditional: "male person, casually dressed, messy hair, unkempt face" # word to take the difference from the positive concept 157 | # neutral: "male person" # starting point for conditioning the target 158 | # action: "enhance" # erase or enhance 159 | # guidance_scale: 4 160 | # resolution: 512 161 | # dynamic_resolution: false 162 | # batch_size: 1 163 | # - target: "female person" # what word for erasing the positive concept from 164 | # positive: "female person, professionally dressed, stylised hair, clean face" # concept to erase 165 | # unconditional: "female person, casually dressed, messy hair, unkempt face" # word to take the difference from the positive concept 166 | # neutral: "female person" # starting point for conditioning the target 167 | # action: "enhance" # erase or enhance 168 | # guidance_scale: 4 169 | # resolution: 512 170 | # dynamic_resolution: false 171 | # batch_size: 1 172 | ####################################################################################################### GLASSES SLIDER 173 | # - target: "male person" # what word for erasing the positive concept from 174 | # positive: "male person, wearing glasses" # concept to erase 175 | # unconditional: "male person" # word to take the difference from the positive concept 176 | # neutral: "male person" # starting point for conditioning the target 177 | # action: "enhance" # erase or enhance 178 | # guidance_scale: 4 179 | # resolution: 512 180 | # dynamic_resolution: false 181 | # batch_size: 1 182 | # - target: "female person" # what word for erasing the positive concept from 183 | # positive: "female person, wearing glasses" # concept to erase 184 | # unconditional: "female person" # word to take the difference from the positive concept 185 | # neutral: "female person" # starting point for conditioning the target 186 | # action: "enhance" # erase or enhance 187 | # guidance_scale: 4 188 | # resolution: 512 189 | # dynamic_resolution: false 190 | # batch_size: 1 191 | ####################################################################################################### ASTRONAUGHT SLIDER 192 | # - target: "astronaught" # what word for erasing the positive concept from 193 | # positive: "astronaught, with orange colored spacesuit" # concept to erase 194 | # unconditional: "astronaught" # word to take the difference from the positive concept 195 | # neutral: "astronaught" # starting point for conditioning the target 196 | # action: "enhance" # erase or enhance 197 | # guidance_scale: 4 198 | # resolution: 512 199 | # dynamic_resolution: false 200 | # batch_size: 1 201 | ####################################################################################################### SMILING SLIDER 202 | # - target: "male person" # what word for erasing the positive concept from 203 | # positive: "male person, smiling" # concept to erase 204 | # unconditional: "male person, frowning" # word to take the difference from the positive concept 205 | # neutral: "male person" # starting point for conditioning the target 206 | # action: "enhance" # erase or enhance 
207 | # guidance_scale: 4 208 | # resolution: 512 209 | # dynamic_resolution: false 210 | # batch_size: 1 211 | # - target: "female person" # what word for erasing the positive concept from 212 | # positive: "female person, smiling" # concept to erase 213 | # unconditional: "female person, frowning" # word to take the difference from the positive concept 214 | # neutral: "female person" # starting point for conditioning the target 215 | # action: "enhance" # erase or enhance 216 | # guidance_scale: 4 217 | # resolution: 512 218 | # dynamic_resolution: false 219 | # batch_size: 1 220 | ####################################################################################################### CAR COLOR SLIDER 221 | # - target: "car" # what word for erasing the positive concept from 222 | # positive: "car, white color" # concept to erase 223 | # unconditional: "car, black color" # word to take the difference from the positive concept 224 | # neutral: "car" # starting point for conditioning the target 225 | # action: "enhance" # erase or enhance 226 | # guidance_scale: 4 227 | # resolution: 512 228 | # dynamic_resolution: false 229 | # batch_size: 1 230 | ####################################################################################################### DETAILS SLIDER 231 | # - target: "" # what word for erasing the positive concept from 232 | # positive: "high detail, 8k, intricate, detailed, high resolution, high res, high quality, hyper realistic" # concept to erase 233 | # unconditional: "blurry, boring, fuzzy, low detail, low resolution, low res, low quality" # word to take the difference from the positive concept 234 | # neutral: "" # starting point for conditioning the target 235 | # action: "enhance" # erase or enhance 236 | # guidance_scale: 4 237 | # resolution: 512 238 | # dynamic_resolution: false 239 | # batch_size: 1 240 | ####################################################################################################### CARTOON SLIDER 241 | # - target: "male person" # what word for erasing the positive concept from 242 | # positive: "male person, cartoon style, pixar style, animated style" # concept to erase 243 | # unconditional: "male person, realistic, hyper realistic" # word to take the difference from the positive concept 244 | # neutral: "male person" # starting point for conditioning the target 245 | # action: "enhance" # erase or enhance 246 | # guidance_scale: 4 247 | # resolution: 512 248 | # dynamic_resolution: false 249 | # batch_size: 1 250 | # - target: "female person" # what word for erasing the positive concept from 251 | # positive: "female person, cartoon style, pixar style, animated style" # concept to erase 252 | # unconditional: "female person, realistic, hyper realistic" # word to take the difference from the positive concept 253 | # neutral: "female person" # starting point for conditioning the target 254 | # action: "enhance" # erase or enhance 255 | # guidance_scale: 4 256 | # resolution: 512 257 | # dynamic_resolution: false 258 | # batch_size: 1 259 | ####################################################################################################### CLAY SLIDER 260 | # - target: "male person" # what word for erasing the positive concept from 261 | # positive: "male person, clay style, made out of clay, clay sculpture" # concept to erase 262 | # unconditional: "male person, realistic, hyper realistic" # word to take the difference from the positive concept 263 | # neutral: "male person" # starting point for conditioning the target 264 | # 
action: "enhance" # erase or enhance 265 | # guidance_scale: 4 266 | # resolution: 512 267 | # dynamic_resolution: false 268 | # batch_size: 1 269 | # - target: "female person" # what word for erasing the positive concept from 270 | # positive: "female person, clay style, made out of clay, clay sculpture" # concept to erase 271 | # unconditional: "female person, realistic, hyper realistic" # word to take the difference from the positive concept 272 | # neutral: "female person" # starting point for conditioning the target 273 | # action: "enhance" # erase or enhance 274 | # guidance_scale: 4 275 | # resolution: 512 276 | # dynamic_resolution: false 277 | # batch_size: 1 278 | ####################################################################################################### SCULPTURE SLIDER 279 | # - target: "male person" # what word for erasing the positive concept from 280 | # positive: "male person, cement sculpture, cement greek statue style" # concept to erase 281 | # unconditional: "male person, realistic, hyper realistic" # word to take the difference from the positive concept 282 | # neutral: "male person" # starting point for conditioning the target 283 | # action: "enhance" # erase or enhance 284 | # guidance_scale: 4 285 | # resolution: 512 286 | # dynamic_resolution: false 287 | # batch_size: 1 288 | # - target: "female person" # what word for erasing the positive concept from 289 | # positive: "female person, cement sculpture, cement greek statue style" # concept to erase 290 | # unconditional: "female person, realistic, hyper realistic" # word to take the difference from the positive concept 291 | # neutral: "female person" # starting point for conditioning the target 292 | # action: "enhance" # erase or enhance 293 | # guidance_scale: 4 294 | # resolution: 512 295 | # dynamic_resolution: false 296 | # batch_size: 1 297 | ####################################################################################################### METAL SLIDER 298 | # - target: "" # what word for erasing the positive concept from 299 | # positive: "made out of metal, metallic style, iron, copper, platinum metal," # concept to erase 300 | # unconditional: "wooden style, made out of wood" # word to take the difference from the positive concept 301 | # neutral: "" # starting point for conditioning the target 302 | # action: "enhance" # erase or enhance 303 | # guidance_scale: 4 304 | # resolution: 512 305 | # dynamic_resolution: false 306 | # batch_size: 1 307 | ####################################################################################################### FESTIVE SLIDER 308 | # - target: "" # what word for erasing the positive concept from 309 | # positive: "festive, colorful banners, confetti, indian festival decorations, chinese festival decorations, fireworks, parade, cherry, gala, happy, celebrations" # concept to erase 310 | # unconditional: "dull, dark, sad, desserted, empty, alone" # word to take the difference from the positive concept 311 | # neutral: "" # starting point for conditioning the target 312 | # action: "enhance" # erase or enhance 313 | # guidance_scale: 4 314 | # resolution: 512 315 | # dynamic_resolution: false 316 | # batch_size: 1 317 | ####################################################################################################### TROPICAL SLIDER 318 | # - target: "" # what word for erasing the positive concept from 319 | # positive: "tropical, beach, sunny, hot" # concept to erase 320 | # unconditional: "arctic, winter, snow, ice, iceburg, snowfall" # word 
to take the difference from the positive concept 321 | # neutral: "" # starting point for conditioning the target 322 | # action: "enhance" # erase or enhance 323 | # guidance_scale: 4 324 | # resolution: 512 325 | # dynamic_resolution: false 326 | # batch_size: 1 327 | ####################################################################################################### MODERN SLIDER 328 | # - target: "" # what word for erasing the positive concept from 329 | # positive: "modern, futuristic style, trendy, stylish, swank" # concept to erase 330 | # unconditional: "ancient, classic style, regal, vintage" # word to take the difference from the positive concept 331 | # neutral: "" # starting point for conditioning the target 332 | # action: "enhance" # erase or enhance 333 | # guidance_scale: 4 334 | # resolution: 512 335 | # dynamic_resolution: false 336 | # batch_size: 1 337 | ####################################################################################################### BOKEH SLIDER 338 | # - target: "" # what word for erasing the positive concept from 339 | # positive: "blurred background, narrow DOF, bokeh effect" # concept to erase 340 | # # unconditional: "high detail background, 8k, intricate, detailed, high resolution background, high res, high quality background" # word to take the difference from the positive concept 341 | # unconditional: "" 342 | # neutral: "" # starting point for conditioning the target 343 | # action: "enhance" # erase or enhance 344 | # guidance_scale: 4 345 | # resolution: 512 346 | # dynamic_resolution: false 347 | # batch_size: 1 348 | ####################################################################################################### LONG HAIR SLIDER 349 | # - target: "male person" # what word for erasing the positive concept from 350 | # positive: "male person, with long hair" # concept to erase 351 | # unconditional: "male person, with short hair" # word to take the difference from the positive concept 352 | # neutral: "male person" # starting point for conditioning the target 353 | # action: "enhance" # erase or enhance 354 | # guidance_scale: 4 355 | # resolution: 512 356 | # dynamic_resolution: false 357 | # batch_size: 1 358 | # - target: "female person" # what word for erasing the positive concept from 359 | # positive: "female person, with long hair" # concept to erase 360 | # unconditional: "female person, with short hair" # word to take the difference from the positive concept 361 | # neutral: "female person" # starting point for conditioning the target 362 | # action: "enhance" # erase or enhance 363 | # guidance_scale: 4 364 | # resolution: 512 365 | # dynamic_resolution: false 366 | # batch_size: 1 367 | ####################################################################################################### NEGPROMPT SLIDER 368 | # - target: "" # what word for erasing the positive concept from 369 | # positive: "cartoon, cgi, render, illustration, painting, drawing, bad quality, grainy, low resolution" # concept to erase 370 | # unconditional: "" 371 | # neutral: "" # starting point for conditioning the target 372 | # action: "erase" # erase or enhance 373 | # guidance_scale: 4 374 | # resolution: 512 375 | # dynamic_resolution: false 376 | # batch_size: 1 377 | ####################################################################################################### EXPENSIVE FOOD SLIDER 378 | # - target: "food" # what word for erasing the positive concept from 379 | # positive: "food, expensive and fine dining" # concept to erase 380 
| # unconditional: "food, cheap and low quality" # word to take the difference from the positive concept 381 | # neutral: "food" # starting point for conditioning the target 382 | # action: "enhance" # erase or enhance 383 | # guidance_scale: 4 384 | # resolution: 512 385 | # dynamic_resolution: false 386 | # batch_size: 1 387 | ####################################################################################################### COOKED FOOD SLIDER 388 | # - target: "food" # what word for erasing the positive concept from 389 | # positive: "food, cooked, baked, roasted, fried" # concept to erase 390 | # unconditional: "food, raw, uncooked, fresh, undone" # word to take the difference from the positive concept 391 | # neutral: "food" # starting point for conditioning the target 392 | # action: "enhance" # erase or enhance 393 | # guidance_scale: 4 394 | # resolution: 512 395 | # dynamic_resolution: false 396 | # batch_size: 1 397 | ####################################################################################################### MEAT FOOD SLIDER 398 | # - target: "food" # what word for erasing the positive concept from 399 | # positive: "food, meat, steak, fish, non-vegetrian, beef, lamb, pork, chicken, salmon" # concept to erase 400 | # unconditional: "food, vegetables, fruits, leafy-vegetables, greens, vegetarian, vegan, tomatoes, onions, carrots" # word to take the difference from the positive concept 401 | # neutral: "food" # starting point for conditioning the target 402 | # action: "enhance" # erase or enhance 403 | # guidance_scale: 4 404 | # resolution: 512 405 | # dynamic_resolution: false 406 | # batch_size: 1 407 | ####################################################################################################### WEATHER SLIDER 408 | # - target: "" # what word for erasing the positive concept from 409 | # positive: "snowy, winter, cold, ice, snowfall, white" # concept to erase 410 | # unconditional: "hot, summer, bright, sunny" # word to take the difference from the positive concept 411 | # neutral: "" # starting point for conditioning the target 412 | # action: "enhance" # erase or enhance 413 | # guidance_scale: 4 414 | # resolution: 512 415 | # dynamic_resolution: false 416 | # batch_size: 1 417 | ####################################################################################################### NIGHT/DAY SLIDER 418 | # - target: "" # what word for erasing the positive concept from 419 | # positive: "night time, dark, darkness, pitch black, nighttime" # concept to erase 420 | # unconditional: "day time, bright, sunny, daytime, sunlight" # word to take the difference from the positive concept 421 | # neutral: "" # starting point for conditioning the target 422 | # action: "enhance" # erase or enhance 423 | # guidance_scale: 4 424 | # resolution: 512 425 | # dynamic_resolution: false 426 | # batch_size: 1 427 | ####################################################################################################### INDOOR/OUTDOOR SLIDER 428 | # - target: "" # what word for erasing the positive concept from 429 | # positive: "indoor, inside a room, inside, interior" # concept to erase 430 | # unconditional: "outdoor, outside, open air, exterior" # word to take the difference from the positive concept 431 | # neutral: "" # starting point for conditioning the target 432 | # action: "enhance" # erase or enhance 433 | # guidance_scale: 4 434 | # resolution: 512 435 | # dynamic_resolution: false 436 | # batch_size: 1 437 | 
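####################################################################################################### HOW THESE ENTRIES ARE USED
# Note: a brief, non-authoritative summary of how a slider entry is consumed, based on the
# PromptSettings / PromptEmbedsPair code in textsliders/prompt_util.py further below; all field
# names here come from that code, nothing new is introduced.
# Each entry is parsed into a PromptSettings object and its four prompts are encoded and combined
# according to `action`:
#   enhance: the `target` prediction is pushed toward  neutral + guidance_scale * (positive - unconditional)
#   erase:   the `target` prediction is pushed toward  neutral - guidance_scale * (positive - unconditional)
# `resolution`, `dynamic_resolution` and `batch_size` are per-entry sampling settings.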
####################################################################################################### GOODHANDS SLIDER 438 | # - target: "" # what word for erasing the positive concept from 439 | # positive: "realistic hands, realistic limbs, perfect limbs, perfect hands, 5 fingers, five fingers, hyper realisitc hands" # concept to erase 440 | # unconditional: "poorly drawn limbs, distorted limbs, poorly rendered hands,bad anatomy, disfigured, mutated body parts, bad composition" # word to take the difference from the positive concept 441 | # neutral: "" # starting point for conditioning the target 442 | # action: "enhance" # erase or enhance 443 | # guidance_scale: 4 444 | # resolution: 512 445 | # dynamic_resolution: false 446 | # batch_size: 1 447 | ####################################################################################################### RUSTY CAR SLIDER 448 | # - target: "car" # what word for erasing the positive concept from 449 | # positive: "car, rusty conditioned" # concept to erase 450 | # unconditional: "car, mint condition, brand new, shiny" # word to take the difference from the positive concept 451 | # neutral: "car" # starting point for conditioning the target 452 | # action: "enhance" # erase or enhance 453 | # guidance_scale: 4 454 | # resolution: 512 455 | # dynamic_resolution: false 456 | # batch_size: 1 457 | ####################################################################################################### RUSTY CAR SLIDER 458 | # - target: "car" # what word for erasing the positive concept from 459 | # positive: "car, damaged, broken headlights, dented car, with scrapped paintwork" # concept to erase 460 | # unconditional: "car, mint condition, brand new, shiny" # word to take the difference from the positive concept 461 | # neutral: "car" # starting point for conditioning the target 462 | # action: "enhance" # erase or enhance 463 | # guidance_scale: 4 464 | # resolution: 512 465 | # dynamic_resolution: false 466 | # batch_size: 1 467 | ####################################################################################################### CLUTTERED ROOM SLIDER 468 | # - target: "room" # what word for erasing the positive concept from 469 | # positive: "room, cluttered, disorganized, dirty, jumbled, scattered" # concept to erase 470 | # unconditional: "room, super organized, clean, ordered, neat, tidy" # word to take the difference from the positive concept 471 | # neutral: "room" # starting point for conditioning the target 472 | # action: "enhance" # erase or enhance 473 | # guidance_scale: 4 474 | # resolution: 512 475 | # dynamic_resolution: false 476 | # batch_size: 1 477 | ####################################################################################################### HANDS SLIDER 478 | # - target: "hands" # what word for erasing the positive concept from 479 | # positive: "realistic hands, five fingers, 8k hyper realistic hands" # concept to erase 480 | # unconditional: "poorly drawn hands, distorted hands, amputed fingers" # word to take the difference from the positive concept 481 | # neutral: "hands" # starting point for conditioning the target 482 | # action: "enhance" # erase or enhance 483 | # guidance_scale: 4 484 | # resolution: 512 485 | # dynamic_resolution: false 486 | # batch_size: 1 487 | ####################################################################################################### HANDS SLIDER 488 | # - target: "female person" # what word for erasing the positive concept from 489 | # positive: "female person, 
with a surprised look" # concept to erase 490 | # unconditional: "female person, with a disinterested look" # word to take the difference from the positive concept 491 | # neutral: "female person" # starting point for conditioning the target 492 | # action: "enhance" # erase or enhance 493 | # guidance_scale: 4 494 | # resolution: 512 495 | # dynamic_resolution: false 496 | # batch_size: 1 -------------------------------------------------------------------------------- /textsliders/data/prompts-zombie.yaml: -------------------------------------------------------------------------------- 1 | - target: "male person" # what word for erasing the positive concept from 2 | positive: "male person, zombie" # concept to erase 3 | unconditional: "male person, realistic, natural" # word to take the difference from the positive concept 4 | neutral: "male person" # starting point for conditioning the target 5 | action: "enhance" # erase or enhance 6 | guidance_scale: 4 7 | resolution: 512 8 | dynamic_resolution: false 9 | batch_size: 1 10 | - target: "female person" # what word for erasing the positive concept from 11 | positive: "female person, zombie" # concept to erase 12 | unconditional: "female person, realistic, natural" # word to take the difference from the positive concept 13 | neutral: "female person" # starting point for conditioning the target 14 | action: "enhance" # erase or enhance 15 | guidance_scale: 4 16 | resolution: 512 17 | dynamic_resolution: false 18 | batch_size: 1 19 | ####################################################################################################### -------------------------------------------------------------------------------- /textsliders/data/prompts.yaml: -------------------------------------------------------------------------------- 1 | - target: "male person" # what word for erasing the positive concept from 2 | positive: "male person, very old" # concept to erase 3 | unconditional: "male person, very young" # word to take the difference from the positive concept 4 | neutral: "male person" # starting point for conditioning the target 5 | action: "enhance" # erase or enhance 6 | guidance_scale: 4 7 | resolution: 512 8 | dynamic_resolution: false 9 | batch_size: 1 10 | - target: "female person" # what word for erasing the positive concept from 11 | positive: "female person, very old" # concept to erase 12 | unconditional: "female person, very young" # word to take the difference from the positive concept 13 | neutral: "female person" # starting point for conditioning the target 14 | action: "enhance" # erase or enhance 15 | guidance_scale: 4 16 | resolution: 512 17 | dynamic_resolution: false 18 | batch_size: 1 19 | ####################################################################################################### MUSCULAR SLIDER 20 | # - target: "male person" # what word for erasing the positive concept from 21 | # positive: "male person, cyborg" # concept to erase 22 | # unconditional: "male person, realistic, natural" # word to take the difference from the positive concept 23 | # neutral: "male person" # starting point for conditioning the target 24 | # action: "enhance" # erase or enhance 25 | # guidance_scale: 4 26 | # resolution: 512 27 | # dynamic_resolution: false 28 | # batch_size: 1 29 | # - target: "female person" # what word for erasing the positive concept from 30 | # positive: "female person, cyborg" # concept to erase 31 | # unconditional: "female person, realistic, natural" # word to take the difference from the positive concept 
32 | # neutral: "female person" # starting point for conditioning the target 33 | # action: "enhance" # erase or enhance 34 | # guidance_scale: 4 35 | # resolution: 512 36 | # dynamic_resolution: false 37 | # batch_size: 1 38 | ####################################################################################################### MUSCULAR SLIDER 39 | # - target: "" # what word for erasing the positive concept from 40 | # positive: "a group of people" # concept to erase 41 | # unconditional: "a person" # word to take the difference from the positive concept 42 | # neutral: "" # starting point for conditioning the target 43 | # action: "enhance" # erase or enhance 44 | # guidance_scale: 4 45 | # resolution: 512 46 | # dynamic_resolution: false 47 | # batch_size: 1 48 | # - target: "" # what word for erasing the positive concept from 49 | # positive: "high detail, 8k, intricate, detailed, high resolution, high res, high quality" # concept to erase 50 | # unconditional: "blurry, boring, fuzzy, low detail, low resolution, low res, low quality" # word to take the difference from the positive concept 51 | # neutral: "" # starting point for conditioning the target 52 | # action: "enhance" # erase or enhance 53 | # guidance_scale: 4 54 | # resolution: 512 55 | # dynamic_resolution: false 56 | # batch_size: 1 57 | # - target: "" # what word for erasing the positive concept from 58 | # positive: "blurred background, narrow DOF, bokeh effect" # concept to erase 59 | # # unconditional: "high detail background, 8k, intricate, detailed, high resolution background, high res, high quality background" # word to take the difference from the positive concept 60 | # unconditional: "" 61 | # neutral: "" # starting point for conditioning the target 62 | # action: "enhance" # erase or enhance 63 | # guidance_scale: 4 64 | # resolution: 512 65 | # dynamic_resolution: false 66 | # batch_size: 1 67 | # - target: "food" # what word for erasing the positive concept from 68 | # positive: "food, expensive and fine dining" # concept to erase 69 | # unconditional: "food, cheap and low quality" # word to take the difference from the positive concept 70 | # neutral: "food" # starting point for conditioning the target 71 | # action: "enhance" # erase or enhance 72 | # guidance_scale: 4 73 | # resolution: 512 74 | # dynamic_resolution: false 75 | # batch_size: 1 76 | # - target: "room" # what word for erasing the positive concept from 77 | # positive: "room, dirty disorganised and cluttered" # concept to erase 78 | # unconditional: "room, neat organised and clean" # word to take the difference from the positive concept 79 | # neutral: "room" # starting point for conditioning the target 80 | # action: "enhance" # erase or enhance 81 | # guidance_scale: 4 82 | # resolution: 512 83 | # dynamic_resolution: false 84 | # batch_size: 1 85 | # - target: "male person" # what word for erasing the positive concept from 86 | # positive: "male person, with a surprised look" # concept to erase 87 | # unconditional: "male person, with a disinterested look" # word to take the difference from the positive concept 88 | # neutral: "male person" # starting point for conditioning the target 89 | # action: "enhance" # erase or enhance 90 | # guidance_scale: 4 91 | # resolution: 512 92 | # dynamic_resolution: false 93 | # batch_size: 1 94 | # - target: "female person" # what word for erasing the positive concept from 95 | # positive: "female person, with a surprised look" # concept to erase 96 | # unconditional: "female person, with a disinterested 
look" # word to take the difference from the positive concept 97 | # neutral: "female person" # starting point for conditioning the target 98 | # action: "enhance" # erase or enhance 99 | # guidance_scale: 4 100 | # resolution: 512 101 | # dynamic_resolution: false 102 | # batch_size: 1 103 | # - target: "sky" # what word for erasing the positive concept from 104 | # positive: "peaceful sky" # concept to erase 105 | # unconditional: "sky" # word to take the difference from the positive concept 106 | # neutral: "sky" # starting point for conditioning the target 107 | # action: "enhance" # erase or enhance 108 | # guidance_scale: 4 109 | # resolution: 512 110 | # dynamic_resolution: false 111 | # batch_size: 1 112 | # - target: "sky" # what word for erasing the positive concept from 113 | # positive: "chaotic dark sky" # concept to erase 114 | # unconditional: "sky" # word to take the difference from the positive concept 115 | # neutral: "sky" # starting point for conditioning the target 116 | # action: "erase" # erase or enhance 117 | # guidance_scale: 4 118 | # resolution: 512 119 | # dynamic_resolution: false 120 | # batch_size: 1 121 | # - target: "person" # what word for erasing the positive concept from 122 | # positive: "person, very young" # concept to erase 123 | # unconditional: "person" # word to take the difference from the positive concept 124 | # neutral: "person" # starting point for conditioning the target 125 | # action: "erase" # erase or enhance 126 | # guidance_scale: 4 127 | # resolution: 512 128 | # dynamic_resolution: false 129 | # batch_size: 1 130 | # overweight 131 | # - target: "art" # what word for erasing the positive concept from 132 | # positive: "realistic art" # concept to erase 133 | # unconditional: "art" # word to take the difference from the positive concept 134 | # neutral: "art" # starting point for conditioning the target 135 | # action: "enhance" # erase or enhance 136 | # guidance_scale: 4 137 | # resolution: 512 138 | # dynamic_resolution: false 139 | # batch_size: 1 140 | # - target: "art" # what word for erasing the positive concept from 141 | # positive: "abstract art" # concept to erase 142 | # unconditional: "art" # word to take the difference from the positive concept 143 | # neutral: "art" # starting point for conditioning the target 144 | # action: "erase" # erase or enhance 145 | # guidance_scale: 4 146 | # resolution: 512 147 | # dynamic_resolution: false 148 | # batch_size: 1 149 | # sky 150 | # - target: "weather" # what word for erasing the positive concept from 151 | # positive: "bright pleasant weather" # concept to erase 152 | # unconditional: "weather" # word to take the difference from the positive concept 153 | # neutral: "weather" # starting point for conditioning the target 154 | # action: "enhance" # erase or enhance 155 | # guidance_scale: 4 156 | # resolution: 512 157 | # dynamic_resolution: false 158 | # batch_size: 1 159 | # - target: "weather" # what word for erasing the positive concept from 160 | # positive: "dark gloomy weather" # concept to erase 161 | # unconditional: "weather" # word to take the difference from the positive concept 162 | # neutral: "weather" # starting point for conditioning the target 163 | # action: "erase" # erase or enhance 164 | # guidance_scale: 4 165 | # resolution: 512 166 | # dynamic_resolution: false 167 | # batch_size: 1 168 | # hair 169 | # - target: "person" # what word for erasing the positive concept from 170 | # positive: "person with long hair" # concept to erase 171 | # 
unconditional: "person" # word to take the difference from the positive concept 172 | # neutral: "person" # starting point for conditioning the target 173 | # action: "enhance" # erase or enhance 174 | # guidance_scale: 4 175 | # resolution: 512 176 | # dynamic_resolution: false 177 | # batch_size: 1 178 | # - target: "person" # what word for erasing the positive concept from 179 | # positive: "person with short hair" # concept to erase 180 | # unconditional: "person" # word to take the difference from the positive concept 181 | # neutral: "person" # starting point for conditioning the target 182 | # action: "erase" # erase or enhance 183 | # guidance_scale: 4 184 | # resolution: 512 185 | # dynamic_resolution: false 186 | # batch_size: 1 187 | # - target: "girl" # what word for erasing the positive concept from 188 | # positive: "baby girl" # concept to erase 189 | # unconditional: "girl" # word to take the difference from the positive concept 190 | # neutral: "girl" # starting point for conditioning the target 191 | # action: "enhance" # erase or enhance 192 | # guidance_scale: -4 193 | # resolution: 512 194 | # dynamic_resolution: false 195 | # batch_size: 1 196 | # - target: "boy" # what word for erasing the positive concept from 197 | # positive: "old man" # concept to erase 198 | # unconditional: "boy" # word to take the difference from the positive concept 199 | # neutral: "boy" # starting point for conditioning the target 200 | # action: "enhance" # erase or enhance 201 | # guidance_scale: 4 202 | # resolution: 512 203 | # dynamic_resolution: false 204 | # batch_size: 1 205 | # - target: "boy" # what word for erasing the positive concept from 206 | # positive: "baby boy" # concept to erase 207 | # unconditional: "boy" # word to take the difference from the positive concept 208 | # neutral: "boy" # starting point for conditioning the target 209 | # action: "enhance" # erase or enhance 210 | # guidance_scale: -4 211 | # resolution: 512 212 | # dynamic_resolution: false 213 | # batch_size: 1 -------------------------------------------------------------------------------- /textsliders/flush.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import gc 3 | 4 | torch.cuda.empty_cache() 5 | gc.collect() 6 | -------------------------------------------------------------------------------- /textsliders/prompt_util.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Optional, Union, List, Dict 2 | 3 | import yaml 4 | from pathlib import Path 5 | 6 | 7 | from pydantic import BaseModel, root_validator 8 | import torch 9 | import copy 10 | 11 | ACTION_TYPES = Literal[ 12 | "erase", 13 | "enhance", 14 | ] 15 | 16 | 17 | # XL は二種類必要なので 18 | class PromptEmbedsXL: 19 | text_embeds: torch.FloatTensor 20 | pooled_embeds: torch.FloatTensor 21 | 22 | def __init__(self, *args) -> None: 23 | self.text_embeds = args[0] 24 | self.pooled_embeds = args[1] 25 | 26 | 27 | # SDv1.x, SDv2.x は FloatTensor、XL は PromptEmbedsXL 28 | PROMPT_EMBEDDING = Union[torch.FloatTensor, PromptEmbedsXL] 29 | 30 | 31 | class PromptEmbedsCache: 32 | prompts: Dict[str, PROMPT_EMBEDDING] = {} 33 | 34 | def __setitem__(self, __name: str, __value: PROMPT_EMBEDDING) -> None: 35 | self.prompts[__name] = __value 36 | 37 | def __getitem__(self, __name: str) -> Optional[PROMPT_EMBEDDING]: 38 | if __name in self.prompts: 39 | return self.prompts[__name] 40 | else: 41 | return None 42 | 43 | 44 | class 
PromptSettings(BaseModel): # schema for the entries in the yaml config files 45 | target: str 46 | positive: str = None # if None, target will be used 47 | unconditional: str = "" # default is "" 48 | neutral: str = None # if None, unconditional will be used 49 | action: ACTION_TYPES = "erase" # default is "erase" 50 | guidance_scale: float = 1.0 # default is 1.0 51 | resolution: int = 512 # default is 512 52 | dynamic_resolution: bool = False # default is False 53 | batch_size: int = 1 # default is 1 54 | dynamic_crops: bool = False # default is False. only used when model is XL 55 | 56 | @root_validator(pre=True) 57 | def fill_prompts(cls, values): 58 | keys = values.keys() 59 | if "target" not in keys: 60 | raise ValueError("target must be specified") 61 | if "positive" not in keys: 62 | values["positive"] = values["target"] 63 | if "unconditional" not in keys: 64 | values["unconditional"] = "" 65 | if "neutral" not in keys: 66 | values["neutral"] = values["unconditional"] 67 | 68 | return values 69 | 70 | 71 | class PromptEmbedsPair: 72 | target: PROMPT_EMBEDDING # the concept we do not want to generate 73 | positive: PROMPT_EMBEDDING # the concept to generate 74 | unconditional: PROMPT_EMBEDDING # unconditional prompt (default should be empty) 75 | neutral: PROMPT_EMBEDDING # base condition (default should be empty) 76 | 77 | guidance_scale: float 78 | resolution: int 79 | dynamic_resolution: bool 80 | batch_size: int 81 | dynamic_crops: bool 82 | 83 | loss_fn: torch.nn.Module 84 | action: ACTION_TYPES 85 | 86 | def __init__( 87 | self, 88 | loss_fn: torch.nn.Module, 89 | target: PROMPT_EMBEDDING, 90 | positive: PROMPT_EMBEDDING, 91 | unconditional: PROMPT_EMBEDDING, 92 | neutral: PROMPT_EMBEDDING, 93 | settings: PromptSettings, 94 | ) -> None: 95 | self.loss_fn = loss_fn 96 | self.target = target 97 | self.positive = positive 98 | self.unconditional = unconditional 99 | self.neutral = neutral 100 | 101 | self.guidance_scale = settings.guidance_scale 102 | self.resolution = settings.resolution 103 | self.dynamic_resolution = settings.dynamic_resolution 104 | self.batch_size = settings.batch_size 105 | self.dynamic_crops = settings.dynamic_crops 106 | self.action = settings.action 107 | self.settings = settings 108 | 109 | def _erase( 110 | self, 111 | target_latents: torch.FloatTensor, # "van gogh" 112 | positive_latents: torch.FloatTensor, # "van gogh" 113 | unconditional_latents: torch.FloatTensor, # "" 114 | neutral_latents: torch.FloatTensor, # "" 115 | scale: float = 1.0, 116 | ) -> torch.FloatTensor: 117 | """Target latents are pushed away from the positive concept.""" 118 | return self.loss_fn( 119 | target_latents, 120 | neutral_latents 121 | - scale * self.guidance_scale * (positive_latents - unconditional_latents) 122 | ) 123 | 124 | 125 | def _enhance( 126 | self, 127 | target_latents: torch.FloatTensor, # "van gogh" 128 | positive_latents: torch.FloatTensor, # "van gogh" 129 | unconditional_latents: torch.FloatTensor, # "" 130 | neutral_latents: torch.FloatTensor, # "" 131 | scale: float = 1.0, 132 | ): 133 | """Target latents are pushed toward the positive concept.""" 134 | return self.loss_fn( 135 | target_latents, 136 | neutral_latents 137 | + scale * self.guidance_scale * (positive_latents - unconditional_latents) 138 | ) 139 | 140 | def loss( 141 | self, 142 | **kwargs, 143 | ): 144 | if self.action == "erase": 145 | return self._erase(**kwargs) 146 | 147 | elif self.action == "enhance": 148 | return self._enhance(**kwargs) 149 | 150 | else: 151 | raise ValueError("action must be erase or enhance") 152 | 153 | 154 | def
load_prompts_from_yaml(path, attributes = []): 155 | with open(path, "r") as f: 156 | prompts = yaml.safe_load(f) 157 | print(prompts) 158 | if len(prompts) == 0: 159 | raise ValueError("prompts file is empty") 160 | if len(attributes)!=0: 161 | newprompts = [] 162 | for i in range(len(prompts)): 163 | for att in attributes: 164 | copy_ = copy.deepcopy(prompts[i]) 165 | copy_['target'] = att + ' ' + copy_['target'] 166 | copy_['positive'] = att + ' ' + copy_['positive'] 167 | copy_['neutral'] = att + ' ' + copy_['neutral'] 168 | copy_['unconditional'] = att + ' ' + copy_['unconditional'] 169 | newprompts.append(copy_) 170 | else: 171 | newprompts = copy.deepcopy(prompts) 172 | 173 | print(newprompts) 174 | print(len(prompts), len(newprompts)) 175 | prompt_settings = [PromptSettings(**prompt) for prompt in newprompts] 176 | 177 | return prompt_settings 178 | -------------------------------------------------------------------------------- /textsliders/train_util.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, Literal, List, Tuple 2 | 3 | import torch 4 | 5 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection 6 | from diffusers import UNet2DConditionModel, SchedulerMixin 7 | 8 | from tqdm import tqdm 9 | 10 | 11 | TOKENIZER_V1_MODEL_NAME = "CompVis/stable-diffusion-v1-4" 12 | TOKENIZER_V2_MODEL_NAME = "stabilityai/stable-diffusion-2-1" 13 | 14 | AVAILABLE_SCHEDULERS = Literal["ddim", "ddpm", "lms", "euler_a"] 15 | 16 | SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection] 17 | 18 | DIFFUSERS_CACHE_DIR = None # if you want to change the cache dir, change this 19 | 20 | 21 | UNET_IN_CHANNELS = 4 # Stable Diffusion's UNet in_channels is fixed at 4; the same holds for XL. 22 | VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8 23 | 24 | UNET_ATTENTION_TIME_EMBED_DIM = 256 # XL 25 | TEXT_ENCODER_2_PROJECTION_DIM = 1280 26 | UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM = 2816 27 | 28 | 29 | def get_random_noise( 30 | batch_size: int, height: int, width: int, generator: torch.Generator = None 31 | ) -> torch.Tensor: 32 | return torch.randn( 33 | ( 34 | batch_size, 35 | UNET_IN_CHANNELS, 36 | height // VAE_SCALE_FACTOR, # not certain the height/width order is right here, but either way it causes no real problem 37 | width // VAE_SCALE_FACTOR, 38 | ), 39 | generator=generator, 40 | device="cpu", 41 | ) 42 | 43 | 44 | # https://www.crosslabs.org/blog/diffusion-with-offset-noise 45 | def apply_noise_offset(latents: torch.FloatTensor, noise_offset: float): 46 | latents = latents + noise_offset * torch.randn( 47 | (latents.shape[0], latents.shape[1], 1, 1), device=latents.device 48 | ) 49 | return latents 50 | 51 | 52 | def get_initial_latents( 53 | scheduler: SchedulerMixin, 54 | n_imgs: int, 55 | height: int, 56 | width: int, 57 | n_prompts: int, 58 | generator=None, 59 | ) -> torch.Tensor: 60 | noise = get_random_noise(n_imgs, height, width, generator=generator).repeat( 61 | n_prompts, 1, 1, 1 62 | ) 63 | 64 | latents = noise * scheduler.init_noise_sigma 65 | 66 | return latents 67 | 68 | 69 | def text_tokenize( 70 | tokenizer: CLIPTokenizer, # normally a single tokenizer; XL uses two!
71 |     prompts: List[str],
72 | ):
73 |     return tokenizer(
74 |         prompts,
75 |         padding="max_length",
76 |         max_length=tokenizer.model_max_length,
77 |         truncation=True,
78 |         return_tensors="pt",
79 |     ).input_ids
80 | 
81 | 
82 | def text_encode(text_encoder: CLIPTextModel, tokens):
83 |     return text_encoder(tokens.to(text_encoder.device))[0]
84 | 
85 | 
86 | def encode_prompts(
87 |     tokenizer: CLIPTokenizer,
88 |     text_encoder: CLIPTextModel,
89 |     prompts: List[str],
90 | ):
91 | 
92 |     text_tokens = text_tokenize(tokenizer, prompts)
93 |     text_embeddings = text_encode(text_encoder, text_tokens)
94 | 
95 | 
96 | 
97 |     return text_embeddings
98 | 
99 | 
100 | def encode_prompts_slider(
101 |     tokenizer: CLIPTokenizer,
102 |     text_encoder: CLIPTextModel,
103 |     prompts: List[str],
104 |     num_images_per_prompt: int = 1,
105 |     sc: float = 1.0,
106 | ):
107 | 
108 |     text_tokens = text_tokenize(tokenizer, prompts)
109 |     idx = text_tokens.argmax(-1)  # position of the highest token id: the newly added slider token if present, otherwise the end-of-text token
110 |     text_embeddings = text_encode(text_encoder, text_tokens)
111 |     batch_indices = torch.arange(len(text_tokens))
112 |     text_embeddings[batch_indices, idx, :] = sc * text_embeddings[batch_indices, idx, :]  # scale that token's output embedding by the slider strength sc
113 | 
114 | 
115 | 
116 |     return text_embeddings
117 | 
118 | 
119 | # https://github.com/huggingface/diffusers/blob/78922ed7c7e66c20aa95159c7b7a6057ba7d590d/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L334-L348
120 | def text_encode_xl(
121 |     text_encoder: SDXL_TEXT_ENCODER_TYPE,
122 |     tokens: torch.FloatTensor,
123 |     num_images_per_prompt: int = 1,
124 | ):
125 |     prompt_embeds = text_encoder(
126 |         tokens.to(text_encoder.device), output_hidden_states=True
127 |     )
128 |     pooled_prompt_embeds = prompt_embeds[0]
129 |     prompt_embeds = prompt_embeds.hidden_states[-2]  # always penultimate layer
130 | 
131 |     bs_embed, seq_len, _ = prompt_embeds.shape
132 |     prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
133 |     prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
134 | 
135 |     return prompt_embeds, pooled_prompt_embeds
136 | 
137 | 
138 | def encode_prompts_xl(
139 |     tokenizers: List[CLIPTokenizer],
140 |     text_encoders: List[SDXL_TEXT_ENCODER_TYPE],
141 |     prompts: List[str],
142 |     num_images_per_prompt: int = 1,
143 | ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
144 |     # text_encoder and text_encoder_2's penultimate layer outputs
145 |     text_embeds_list = []
146 |     pooled_text_embeds = None  # always text_encoder_2's pool
147 | 
148 |     for tokenizer, text_encoder in zip(tokenizers, text_encoders):
149 |         text_tokens_input_ids = text_tokenize(tokenizer, prompts)
150 |         text_embeds, pooled_text_embeds = text_encode_xl(
151 |             text_encoder, text_tokens_input_ids, num_images_per_prompt
152 |         )
153 | 
154 |         text_embeds_list.append(text_embeds)
155 | 
156 |     bs_embed = pooled_text_embeds.shape[0]
157 |     pooled_text_embeds = pooled_text_embeds.repeat(1, num_images_per_prompt).view(
158 |         bs_embed * num_images_per_prompt, -1
159 |     )
160 | 
161 |     return torch.concat(text_embeds_list, dim=-1), pooled_text_embeds
162 | 
163 | 
164 | def encode_prompts_xl_slider(
165 |     tokenizers: List[CLIPTokenizer],
166 |     text_encoders: List[SDXL_TEXT_ENCODER_TYPE],
167 |     prompts: List[str],
168 |     num_images_per_prompt: int = 1,
169 |     sc: float = 1.0,
170 | ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
171 |     # text_encoder and text_encoder_2's penultimate layer outputs
172 |     text_embeds_list = []
173 |     pooled_text_embeds = None  # always text_encoder_2's pool
174 |     k = 0
175 |     for tokenizer, text_encoder in zip(tokenizers, text_encoders):
176 |         text_tokens_input_ids = text_tokenize(tokenizer, prompts)
177 | 
178 |         idx = text_tokens_input_ids.argmax(-1)
179 |         text_embeds, pooled_text_embeds = text_encode_xl(
180 |             text_encoder, text_tokens_input_ids, num_images_per_prompt
181 |         )
182 |         batch_indices = torch.arange(len(text_tokens_input_ids))
183 |         if k == 0:  # apply the slider scale only to the first text encoder's embeddings
184 |             text_embeds[batch_indices, idx, :] = sc * text_embeds[batch_indices, idx, :]
185 | 
186 |         text_embeds_list.append(text_embeds)
187 |         k += 1
188 | 
189 |     bs_embed = pooled_text_embeds.shape[0]
190 |     pooled_text_embeds = pooled_text_embeds.repeat(1, num_images_per_prompt).view(
191 |         bs_embed * num_images_per_prompt, -1
192 |     )
193 | 
194 |     return torch.concat(text_embeds_list, dim=-1), pooled_text_embeds
195 | 
196 | 
197 | def concat_embeddings(
198 |     unconditional: torch.FloatTensor,
199 |     conditional: torch.FloatTensor,
200 |     n_imgs: int,
201 | ):
202 |     return torch.cat([unconditional, conditional]).repeat_interleave(n_imgs, dim=0)
203 | 
204 | 
205 | # ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L721
206 | def predict_noise(
207 |     unet: UNet2DConditionModel,
208 |     scheduler: SchedulerMixin,
209 |     timestep: int,  # current timestep
210 |     latents: torch.FloatTensor,
211 |     text_embeddings: torch.FloatTensor,  # unconditional and conditional text embeddings, concatenated
212 |     guidance_scale=7.5,
213 | ) -> torch.FloatTensor:
214 |     # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
215 |     latent_model_input = torch.cat([latents] * 2)
216 | 
217 |     latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)
218 | 
219 |     # predict the noise residual
220 |     noise_pred = unet(
221 |         latent_model_input,
222 |         timestep,
223 |         encoder_hidden_states=text_embeddings,
224 |     ).sample
225 | 
226 |     # perform guidance
227 |     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
228 |     guided_target = noise_pred_uncond + guidance_scale * (
229 |         noise_pred_text - noise_pred_uncond
230 |     )
231 | 
232 |     return guided_target
233 | 
234 | 
235 | # ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746
236 | @torch.no_grad()
237 | def diffusion(
238 |     unet: UNet2DConditionModel,
239 |     scheduler: SchedulerMixin,
240 |     latents: torch.FloatTensor,  # pure-noise latents
241 |     text_embeddings: torch.FloatTensor,
242 |     total_timesteps: int = 1000,
243 |     start_timesteps=0,
244 |     **kwargs,
245 | ):
246 |     # latents_steps = []
247 | 
248 |     for timestep in tqdm(scheduler.timesteps[start_timesteps:total_timesteps]):
249 |         noise_pred = predict_noise(
250 |             unet, scheduler, timestep, latents, text_embeddings, **kwargs
251 |         )
252 | 
253 |         # compute the previous noisy sample x_t -> x_t-1
254 |         latents = scheduler.step(noise_pred, timestep, latents).prev_sample
255 | 
256 |     # return latents_steps
257 |     return latents
258 | 
259 | 
260 | def rescale_noise_cfg(
261 |     noise_cfg: torch.FloatTensor, noise_pred_text, guidance_rescale=0.0
262 | ):
263 |     """
264 |     Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
265 |     Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
266 |     """
267 |     std_text = noise_pred_text.std(
268 |         dim=list(range(1, noise_pred_text.ndim)), keepdim=True
269 |     )
270 |     std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
271 |     # rescale the results from guidance (fixes overexposure)
272 |     noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
273 |     # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
274 |     noise_cfg = (
275 |         guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
276 |     )
277 | 
278 |     return noise_cfg
279 | 
280 | 
281 | def predict_noise_xl(
282 |     unet: UNet2DConditionModel,
283 |     scheduler: SchedulerMixin,
284 |     timestep: int,  # current timestep
285 |     latents: torch.FloatTensor,
286 |     text_embeddings: torch.FloatTensor,  # unconditional and conditional text embeddings, concatenated
287 |     add_text_embeddings: torch.FloatTensor,  # pooled text embeddings
288 |     add_time_ids: torch.FloatTensor,
289 |     guidance_scale=7.5,
290 |     guidance_rescale=0.7,
291 | ) -> torch.FloatTensor:
292 |     # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
293 |     latent_model_input = torch.cat([latents] * 2)
294 | 
295 |     latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)
296 | 
297 |     added_cond_kwargs = {
298 |         "text_embeds": add_text_embeddings,
299 |         "time_ids": add_time_ids,
300 |     }
301 | 
302 |     # predict the noise residual
303 |     noise_pred = unet(
304 |         latent_model_input,
305 |         timestep,
306 |         encoder_hidden_states=text_embeddings,
307 |         added_cond_kwargs=added_cond_kwargs,
308 |     ).sample
309 | 
310 |     # perform guidance
311 |     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
312 |     guided_target = noise_pred_uncond + guidance_scale * (
313 |         noise_pred_text - noise_pred_uncond
314 |     )
315 | 
316 |     # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
317 |     noise_pred = rescale_noise_cfg(
318 |         noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
319 |     )
320 | 
321 |     return guided_target  # NOTE: the rescaled prediction above is computed but not returned, so guidance_rescale currently has no effect here
322 | 
323 | 
324 | @torch.no_grad()
325 | def diffusion_xl(
326 |     unet: UNet2DConditionModel,
327 |     scheduler: SchedulerMixin,
328 |     latents: torch.FloatTensor,  # pure-noise latents
329 |     text_embeddings: Tuple[torch.FloatTensor, torch.FloatTensor],
330 |     add_text_embeddings: torch.FloatTensor,  # pooled text embeddings
331 |     add_time_ids: torch.FloatTensor,
332 |     guidance_scale: float = 1.0,
333 |     total_timesteps: int = 1000,
334 |     start_timesteps=0,
335 | ):
336 |     # latents_steps = []
337 | 
338 |     for timestep in tqdm(scheduler.timesteps[start_timesteps:total_timesteps]):
339 |         noise_pred = predict_noise_xl(
340 |             unet,
341 |             scheduler,
342 |             timestep,
343 |             latents,
344 |             text_embeddings,
345 |             add_text_embeddings,
346 |             add_time_ids,
347 |             guidance_scale=guidance_scale,
348 |             guidance_rescale=0.7,
349 |         )
350 | 
351 |         # compute the previous noisy sample x_t -> x_t-1
352 |         latents = scheduler.step(noise_pred, timestep, latents).prev_sample
353 | 
354 |     # return latents_steps
355 |     return latents
356 | 
357 | 
358 | # for XL
359 | def get_add_time_ids(
360 |     height: int,
361 |     width: int,
362 |     dynamic_crops: bool = False,
363 |     dtype: torch.dtype = torch.float32,
364 | ):
365 |     if dynamic_crops:
366 |         # random float scale between 1 and 3
367 |         random_scale = torch.rand(1).item() * 2 + 1
368 |         original_size = (int(height * random_scale), int(width * random_scale))
369 |         # random position
370 | 
crops_coords_top_left = ( 371 | torch.randint(0, original_size[0] - height, (1,)).item(), 372 | torch.randint(0, original_size[1] - width, (1,)).item(), 373 | ) 374 | target_size = (height, width) 375 | else: 376 | original_size = (height, width) 377 | crops_coords_top_left = (0, 0) 378 | target_size = (height, width) 379 | 380 | # this is expected as 6 381 | add_time_ids = list(original_size + crops_coords_top_left + target_size) 382 | 383 | # this is expected as 2816 384 | passed_add_embed_dim = ( 385 | UNET_ATTENTION_TIME_EMBED_DIM * len(add_time_ids) # 256 * 6 386 | + TEXT_ENCODER_2_PROJECTION_DIM # + 1280 387 | ) 388 | if passed_add_embed_dim != UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM: 389 | raise ValueError( 390 | f"Model expects an added time embedding vector of length {UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 391 | ) 392 | 393 | add_time_ids = torch.tensor([add_time_ids], dtype=dtype) 394 | return add_time_ids 395 | 396 | 397 | def get_optimizer(name: str): 398 | name = name.lower() 399 | 400 | if name.startswith("dadapt"): 401 | import dadaptation 402 | 403 | if name == "dadaptadam": 404 | return dadaptation.DAdaptAdam 405 | elif name == "dadaptlion": 406 | return dadaptation.DAdaptLion 407 | else: 408 | raise ValueError("DAdapt optimizer must be dadaptadam or dadaptlion") 409 | 410 | elif name.endswith("8bit"): # 検証してない 411 | import bitsandbytes as bnb 412 | 413 | if name == "adam8bit": 414 | return bnb.optim.Adam8bit 415 | elif name == "lion8bit": 416 | return bnb.optim.Lion8bit 417 | else: 418 | raise ValueError("8bit optimizer must be adam8bit or lion8bit") 419 | 420 | else: 421 | if name == "adam": 422 | return torch.optim.Adam 423 | elif name == "adamw": 424 | return torch.optim.AdamW 425 | elif name == "lion": 426 | from lion_pytorch import Lion 427 | 428 | return Lion 429 | elif name == "prodigy": 430 | import prodigyopt 431 | 432 | return prodigyopt.Prodigy 433 | else: 434 | raise ValueError("Optimizer must be adam, adamw, lion or Prodigy") 435 | 436 | 437 | def get_lr_scheduler( 438 | name: Optional[str], 439 | optimizer: torch.optim.Optimizer, 440 | max_iterations: Optional[int], 441 | lr_min: Optional[float], 442 | **kwargs, 443 | ): 444 | if name == "cosine": 445 | return torch.optim.lr_scheduler.CosineAnnealingLR( 446 | optimizer, T_max=max_iterations, eta_min=lr_min, **kwargs 447 | ) 448 | elif name == "cosine_with_restarts": 449 | return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( 450 | optimizer, T_0=max_iterations // 10, T_mult=2, eta_min=lr_min, **kwargs 451 | ) 452 | elif name == "step": 453 | return torch.optim.lr_scheduler.StepLR( 454 | optimizer, step_size=max_iterations // 100, gamma=0.999, **kwargs 455 | ) 456 | elif name == "constant": 457 | return torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1, **kwargs) 458 | elif name == "linear": 459 | return torch.optim.lr_scheduler.LinearLR( 460 | optimizer, factor=0.5, total_iters=max_iterations // 100, **kwargs 461 | ) 462 | else: 463 | raise ValueError( 464 | "Scheduler must be cosine, cosine_with_restarts, step, linear or constant" 465 | ) 466 | 467 | 468 | def get_random_resolution_in_bucket(bucket_resolution: int = 512) -> Tuple[int, int]: 469 | max_resolution = bucket_resolution 470 | min_resolution = bucket_resolution // 2 471 | 472 | step = 64 473 | 474 | min_step = min_resolution // step 475 | max_step = 
max_resolution // step 476 | 477 | height = torch.randint(min_step, max_step, (1,)).item() * step 478 | width = torch.randint(min_step, max_step, (1,)).item() * step 479 | 480 | return height, width 481 | -------------------------------------------------------------------------------- /textual_inversion.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | 16 | import argparse 17 | import logging 18 | import math 19 | import os 20 | import random 21 | import shutil 22 | import warnings 23 | from pathlib import Path 24 | from typing import List, Optional 25 | 26 | import numpy as np 27 | import PIL 28 | import safetensors 29 | import torch 30 | import torch.nn.functional as F 31 | import torch.utils.checkpoint 32 | import transformers 33 | from accelerate import Accelerator 34 | from accelerate.logging import get_logger 35 | from accelerate.utils import ProjectConfiguration, set_seed 36 | from huggingface_hub import create_repo, upload_folder 37 | 38 | # TODO: remove and import from diffusers.utils when the new version of diffusers is released 39 | from packaging import version 40 | from PIL import Image 41 | from torch.utils.data import Dataset 42 | from torchvision import transforms 43 | from tqdm.auto import tqdm 44 | from transformers import CLIPTextModel, CLIPTokenizer 45 | from textsliders import prompt_util 46 | from textsliders import train_util 47 | from textsliders.prompt_util import PromptEmbedsCache, PromptEmbedsPair, PromptSettings 48 | 49 | import diffusers 50 | from diffusers import ( 51 | AutoencoderKL, 52 | DDPMScheduler, 53 | DiffusionPipeline, 54 | DPMSolverMultistepScheduler, 55 | StableDiffusionPipeline, 56 | UNet2DConditionModel, 57 | ) 58 | from diffusers.optimization import get_scheduler 59 | from diffusers.utils import check_min_version, is_wandb_available 60 | from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card 61 | from diffusers.utils.import_utils import is_xformers_available 62 | 63 | 64 | if is_wandb_available(): 65 | import wandb 66 | 67 | if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): 68 | PIL_INTERPOLATION = { 69 | "linear": PIL.Image.Resampling.BILINEAR, 70 | "bilinear": PIL.Image.Resampling.BILINEAR, 71 | "bicubic": PIL.Image.Resampling.BICUBIC, 72 | "lanczos": PIL.Image.Resampling.LANCZOS, 73 | "nearest": PIL.Image.Resampling.NEAREST, 74 | } 75 | else: 76 | PIL_INTERPOLATION = { 77 | "linear": PIL.Image.LINEAR, 78 | "bilinear": PIL.Image.BILINEAR, 79 | "bicubic": PIL.Image.BICUBIC, 80 | "lanczos": PIL.Image.LANCZOS, 81 | "nearest": PIL.Image.NEAREST, 82 | } 83 | # ------------------------------------------------------------------------------ 84 | 85 | 86 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
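The helpers in textsliders/train_util.py above contain the core slider mechanic: `encode_prompts_slider` finds the position of the highest token id (the newly added slider token when it is present, otherwise the end-of-text token) and multiplies its output embedding by the strength `sc`. A minimal sketch of exercising that helper in isolation, assuming an SD v1.x checkpoint; the slider token name and the scale value are illustrative and not taken from the repository scripts:

    import torch
    from transformers import CLIPTextModel, CLIPTokenizer

    from textsliders import train_util

    model_id = "CompVis/stable-diffusion-v1-4"  # illustrative base model
    tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")

    tokenizer.add_tokens(["<smile-slider>"])  # hypothetical slider token
    text_encoder.resize_token_embeddings(len(tokenizer))

    with torch.no_grad():
        # sc=0.0 mutes the slider token, sc=2.0 doubles its embedding (stronger attribute)
        embeds = train_util.encode_prompts_slider(
            tokenizer, text_encoder, ["picture of a person, <smile-slider>"], sc=2.0
        )
    print(embeds.shape)  # torch.Size([1, 77, 768]) for SD v1.x
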
87 | check_min_version("0.27.0.dev0") 88 | 89 | logger = get_logger(__name__) 90 | 91 | 92 | def save_model_card(repo_id: str, images: list = None, base_model: str = None, repo_folder: str = None): 93 | img_str = "" 94 | if images is not None: 95 | for i, image in enumerate(images): 96 | image.save(os.path.join(repo_folder, f"image_{i}.png")) 97 | img_str += f"![img_{i}](./image_{i}.png)\n" 98 | model_description = f""" 99 | # Textual inversion text2image fine-tuning - {repo_id} 100 | These are textual inversion adaption weights for {base_model}. You can find some example images in the following. \n 101 | {img_str} 102 | """ 103 | model_card = load_or_create_model_card( 104 | repo_id_or_path=repo_id, 105 | from_training=True, 106 | license="creativeml-openrail-m", 107 | base_model=base_model, 108 | model_description=model_description, 109 | inference=True, 110 | ) 111 | 112 | tags = ["stable-diffusion", "stable-diffusion-diffusers", "text-to-image", "diffusers", "textual_inversion"] 113 | model_card = populate_model_card(model_card, tags=tags) 114 | 115 | model_card.save(os.path.join(repo_folder, "README.md")) 116 | 117 | 118 | def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch): 119 | logger.info( 120 | f"Running validation... \n Generating {args.num_validation_images} images with prompt:" 121 | f" {args.validation_prompt}." 122 | ) 123 | # create pipeline (note: unet and vae are loaded again in float32) 124 | pipeline = DiffusionPipeline.from_pretrained( 125 | args.pretrained_model_name_or_path, 126 | text_encoder=accelerator.unwrap_model(text_encoder), 127 | tokenizer=tokenizer, 128 | unet=unet, 129 | vae=vae, 130 | safety_checker=None, 131 | revision=args.revision, 132 | variant=args.variant, 133 | torch_dtype=weight_dtype, 134 | ) 135 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) 136 | pipeline = pipeline.to(accelerator.device) 137 | pipeline.set_progress_bar_config(disable=True) 138 | 139 | # run inference 140 | generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed) 141 | images = [] 142 | for _ in range(args.num_validation_images): 143 | with torch.autocast("cuda"): 144 | image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] 145 | images.append(image) 146 | 147 | for tracker in accelerator.trackers: 148 | if tracker.name == "tensorboard": 149 | np_images = np.stack([np.asarray(img) for img in images]) 150 | tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") 151 | if tracker.name == "wandb": 152 | tracker.log( 153 | { 154 | "validation": [ 155 | wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) 156 | ] 157 | } 158 | ) 159 | 160 | del pipeline 161 | torch.cuda.empty_cache() 162 | return images 163 | 164 | 165 | def save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path, safe_serialization=True): 166 | logger.info("Saving embeddings") 167 | learned_embeds = ( 168 | accelerator.unwrap_model(text_encoder) 169 | .get_input_embeddings() 170 | .weight[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] 171 | ) 172 | learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()} 173 | 174 | if safe_serialization: 175 | safetensors.torch.save_file(learned_embeds_dict, save_path, metadata={"format": "pt"}) 176 | else: 177 | torch.save(learned_embeds_dict, save_path) 178 | 179 | 180 | def 
parse_args(): 181 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 182 | parser.add_argument( 183 | "--save_steps", 184 | type=int, 185 | default=500, 186 | help="Save learned_embeds.bin every X updates steps.", 187 | ) 188 | parser.add_argument( 189 | "--save_as_full_pipeline", 190 | action="store_true", 191 | help="Save the complete stable diffusion pipeline.", 192 | ) 193 | parser.add_argument( 194 | "--num_vectors", 195 | type=int, 196 | default=1, 197 | help="How many textual inversion vectors shall be used to learn the concept.", 198 | ) 199 | parser.add_argument( 200 | "--pretrained_model_name_or_path", 201 | type=str, 202 | default=None, 203 | required=True, 204 | help="Path to pretrained model or model identifier from huggingface.co/models.", 205 | ) 206 | parser.add_argument( 207 | "--revision", 208 | type=str, 209 | default=None, 210 | required=False, 211 | help="Revision of pretrained model identifier from huggingface.co/models.", 212 | ) 213 | parser.add_argument( 214 | "--variant", 215 | type=str, 216 | default=None, 217 | help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", 218 | ) 219 | parser.add_argument( 220 | "--tokenizer_name", 221 | type=str, 222 | default=None, 223 | help="Pretrained tokenizer name or path if not the same as model_name", 224 | ) 225 | parser.add_argument( 226 | "--train_data_dir", type=str, default='tmp', required=False, help="A folder containing the training data." 227 | ) 228 | parser.add_argument( 229 | "--placeholder_token", 230 | type=str, 231 | default=None, 232 | required=True, 233 | help="A token to use as a placeholder for the concept.", 234 | ) 235 | parser.add_argument( 236 | "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word." 237 | ) 238 | parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'") 239 | parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.") 240 | parser.add_argument( 241 | "--output_dir", 242 | type=str, 243 | default="text-inversion-model", 244 | help="The output directory where the model predictions and checkpoints will be written.", 245 | ) 246 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 247 | parser.add_argument( 248 | "--resolution", 249 | type=int, 250 | default=512, 251 | help=( 252 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 253 | " resolution" 254 | ), 255 | ) 256 | parser.add_argument( 257 | "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution." 258 | ) 259 | parser.add_argument( 260 | "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." 261 | ) 262 | parser.add_argument("--num_train_epochs", type=int, default=100) 263 | parser.add_argument( 264 | "--max_train_steps", 265 | type=int, 266 | default=5000, 267 | help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", 268 | ) 269 | parser.add_argument( 270 | "--gradient_accumulation_steps", 271 | type=int, 272 | default=1, 273 | help="Number of updates steps to accumulate before performing a backward/update pass.", 274 | ) 275 | parser.add_argument( 276 | "--gradient_checkpointing", 277 | action="store_true", 278 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 279 | ) 280 | parser.add_argument( 281 | "--learning_rate", 282 | type=float, 283 | default=1e-4, 284 | help="Initial learning rate (after the potential warmup period) to use.", 285 | ) 286 | parser.add_argument( 287 | "--scale_lr", 288 | action="store_true", 289 | default=False, 290 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 291 | ) 292 | parser.add_argument( 293 | "--lr_scheduler", 294 | type=str, 295 | default="constant", 296 | help=( 297 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 298 | ' "constant", "constant_with_warmup"]' 299 | ), 300 | ) 301 | parser.add_argument( 302 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 303 | ) 304 | parser.add_argument( 305 | "--lr_num_cycles", 306 | type=int, 307 | default=1, 308 | help="Number of hard resets of the lr in cosine_with_restarts scheduler.", 309 | ) 310 | parser.add_argument( 311 | "--dataloader_num_workers", 312 | type=int, 313 | default=0, 314 | help=( 315 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 316 | ), 317 | ) 318 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 319 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 320 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 321 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 322 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 323 | parser.add_argument("--prompts_file", type=str, default=None, help="prompt file.") 324 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 325 | parser.add_argument( 326 | "--hub_model_id", 327 | type=str, 328 | default=None, 329 | help="The name of the repository to keep in sync with the local `output_dir`.", 330 | ) 331 | parser.add_argument( 332 | "--logging_dir", 333 | type=str, 334 | default="logs", 335 | help=( 336 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 337 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 338 | ), 339 | ) 340 | parser.add_argument( 341 | "--mixed_precision", 342 | type=str, 343 | default="no", 344 | choices=["no", "fp16", "bf16"], 345 | help=( 346 | "Whether to use mixed precision. Choose" 347 | "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." 348 | "and Nvidia Ampere GPU or Intel Gen 4 Xeon (and later) ." 349 | ), 350 | ) 351 | parser.add_argument( 352 | "--allow_tf32", 353 | action="store_true", 354 | help=( 355 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see" 356 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 357 | ), 358 | ) 359 | parser.add_argument( 360 | "--report_to", 361 | type=str, 362 | default="tensorboard", 363 | help=( 364 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 365 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 366 | ), 367 | ) 368 | parser.add_argument( 369 | "--validation_prompt", 370 | type=str, 371 | default=None, 372 | help="A prompt that is used during validation to verify that the model is learning.", 373 | ) 374 | parser.add_argument( 375 | "--num_validation_images", 376 | type=int, 377 | default=4, 378 | help="Number of images that should be generated during validation with `validation_prompt`.", 379 | ) 380 | parser.add_argument( 381 | "--validation_steps", 382 | type=int, 383 | default=100, 384 | help=( 385 | "Run validation every X steps. Validation consists of running the prompt" 386 | " `args.validation_prompt` multiple times: `args.num_validation_images`" 387 | " and logging the images." 388 | ), 389 | ) 390 | parser.add_argument( 391 | "--validation_epochs", 392 | type=int, 393 | default=None, 394 | help=( 395 | "Deprecated in favor of validation_steps. Run validation every X epochs. Validation consists of running the prompt" 396 | " `args.validation_prompt` multiple times: `args.num_validation_images`" 397 | " and logging the images." 398 | ), 399 | ) 400 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 401 | parser.add_argument( 402 | "--checkpointing_steps", 403 | type=int, 404 | default=500, 405 | help=( 406 | "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" 407 | " training using `--resume_from_checkpoint`." 408 | ), 409 | ) 410 | parser.add_argument( 411 | "--checkpoints_total_limit", 412 | type=int, 413 | default=None, 414 | help=("Max number of checkpoints to store."), 415 | ) 416 | parser.add_argument( 417 | "--resume_from_checkpoint", 418 | type=str, 419 | default=None, 420 | help=( 421 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 422 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 423 | ), 424 | ) 425 | parser.add_argument( 426 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 
427 | ) 428 | parser.add_argument( 429 | "--no_safe_serialization", 430 | action="store_true", 431 | help="If specified save the checkpoint not in `safetensors` format, but in original PyTorch format instead.", 432 | ) 433 | 434 | args = parser.parse_args() 435 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 436 | if env_local_rank != -1 and env_local_rank != args.local_rank: 437 | args.local_rank = env_local_rank 438 | 439 | if args.train_data_dir is None: 440 | raise ValueError("You must specify a train data directory when using visual prompt sliders.") 441 | 442 | return args 443 | 444 | 445 | imagenet_templates_small = [ 446 | "a photo of a {}", 447 | "a rendering of a {}", 448 | "a cropped photo of the {}", 449 | "the photo of a {}", 450 | "a photo of a clean {}", 451 | "a photo of a dirty {}", 452 | "a dark photo of the {}", 453 | "a photo of my {}", 454 | "a photo of the cool {}", 455 | "a close-up photo of a {}", 456 | "a bright photo of the {}", 457 | "a cropped photo of a {}", 458 | "a photo of the {}", 459 | "a good photo of the {}", 460 | "a photo of one {}", 461 | "a close-up photo of the {}", 462 | "a rendition of the {}", 463 | "a photo of the clean {}", 464 | "a rendition of a {}", 465 | "a photo of a nice {}", 466 | "a good photo of a {}", 467 | "a photo of the nice {}", 468 | "a photo of the small {}", 469 | "a photo of the weird {}", 470 | "a photo of the large {}", 471 | "a photo of a cool {}", 472 | "a photo of a small {}", 473 | ] 474 | 475 | imagenet_style_templates_small = [ 476 | "a painting in the style of {}", 477 | "a rendering in the style of {}", 478 | "a cropped painting in the style of {}", 479 | "the painting in the style of {}", 480 | "a clean painting in the style of {}", 481 | "a dirty painting in the style of {}", 482 | "a dark painting in the style of {}", 483 | "a picture in the style of {}", 484 | "a cool painting in the style of {}", 485 | "a close-up painting in the style of {}", 486 | "a bright painting in the style of {}", 487 | "a cropped painting in the style of {}", 488 | "a good painting in the style of {}", 489 | "a close-up painting in the style of {}", 490 | "a rendition in the style of {}", 491 | "a nice painting in the style of {}", 492 | "a small painting in the style of {}", 493 | "a weird painting in the style of {}", 494 | "a large painting in the style of {}", 495 | ] 496 | 497 | 498 | class TextualInversionDataset(Dataset): 499 | def __init__( 500 | self, 501 | data_root, 502 | tokenizer, 503 | learnable_property="object", # [object, style] 504 | size=512, 505 | repeats=100, 506 | interpolation="bicubic", 507 | flip_p=0.5, 508 | set="train", 509 | placeholder_token="*", 510 | center_crop=False, 511 | ): 512 | self.data_root = data_root 513 | self.tokenizer = tokenizer 514 | self.learnable_property = learnable_property 515 | self.size = size 516 | self.placeholder_token = placeholder_token 517 | self.center_crop = center_crop 518 | self.flip_p = flip_p 519 | 520 | self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)] 521 | 522 | self.num_images = len(self.image_paths) 523 | self._length = self.num_images 524 | 525 | if set == "train": 526 | self._length = self.num_images * repeats 527 | 528 | self.interpolation = { 529 | "linear": PIL_INTERPOLATION["linear"], 530 | "bilinear": PIL_INTERPOLATION["bilinear"], 531 | "bicubic": PIL_INTERPOLATION["bicubic"], 532 | "lanczos": PIL_INTERPOLATION["lanczos"], 533 | }[interpolation] 534 | 535 | self.templates = 
imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small 536 | self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p) 537 | 538 | def __len__(self): 539 | return self._length 540 | 541 | def __getitem__(self, i): 542 | example = {} 543 | image = Image.open(self.image_paths[i % self.num_images]) 544 | 545 | if not image.mode == "RGB": 546 | image = image.convert("RGB") 547 | 548 | placeholder_string = self.placeholder_token 549 | text = random.choice(self.templates).format(placeholder_string) 550 | 551 | example["input_ids"] = self.tokenizer( 552 | text, 553 | padding="max_length", 554 | truncation=True, 555 | max_length=self.tokenizer.model_max_length, 556 | return_tensors="pt", 557 | ).input_ids[0] 558 | 559 | # default to score-sde preprocessing 560 | img = np.array(image).astype(np.uint8) 561 | 562 | if self.center_crop: 563 | crop = min(img.shape[0], img.shape[1]) 564 | ( 565 | h, 566 | w, 567 | ) = ( 568 | img.shape[0], 569 | img.shape[1], 570 | ) 571 | img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2] 572 | 573 | image = Image.fromarray(img) 574 | image = image.resize((self.size, self.size), resample=self.interpolation) 575 | 576 | image = self.flip_transform(image) 577 | image = np.array(image).astype(np.uint8) 578 | image = (image / 127.5 - 1.0).astype(np.float32) 579 | 580 | example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1) 581 | return example 582 | 583 | 584 | def main(): 585 | args = parse_args() 586 | if args.report_to == "wandb" and args.hub_token is not None: 587 | raise ValueError( 588 | "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." 589 | " Please use `huggingface-cli login` to authenticate with the Hub." 590 | ) 591 | 592 | logging_dir = os.path.join(args.output_dir, args.logging_dir) 593 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) 594 | accelerator = Accelerator( 595 | gradient_accumulation_steps=args.gradient_accumulation_steps, 596 | mixed_precision=args.mixed_precision, 597 | log_with=args.report_to, 598 | project_config=accelerator_project_config, 599 | ) 600 | 601 | if args.report_to == "wandb": 602 | if not is_wandb_available(): 603 | raise ImportError("Make sure to install wandb if you want to use it for logging during training.") 604 | 605 | # Make one log on every process with the configuration for debugging. 606 | logging.basicConfig( 607 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 608 | datefmt="%m/%d/%Y %H:%M:%S", 609 | level=logging.INFO, 610 | ) 611 | logger.info(accelerator.state, main_process_only=False) 612 | if accelerator.is_local_main_process: 613 | transformers.utils.logging.set_verbosity_warning() 614 | diffusers.utils.logging.set_verbosity_info() 615 | else: 616 | transformers.utils.logging.set_verbosity_error() 617 | diffusers.utils.logging.set_verbosity_error() 618 | 619 | # If passed along, set the training seed now. 
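The `TextualInversionDataset` defined above is the stock image dataset from the diffusers textual-inversion example; the training loop further below is driven purely by the prompt pairs, and `main()` never instantiates the dataset. If image supervision were wanted, it could be wired up roughly as follows (a sketch only; `tokenizer` here stands for the CLIP tokenizer loaded later in `main()`):

    from torch.utils.data import DataLoader

    train_dataset = TextualInversionDataset(
        data_root=args.train_data_dir,
        tokenizer=tokenizer,
        size=args.resolution,
        placeholder_token=args.placeholder_token,
        repeats=args.repeats,
        learnable_property=args.learnable_property,
        center_crop=args.center_crop,
        set="train",
    )
    train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)
    batch = next(iter(train_dataloader))
    print(batch["input_ids"].shape, batch["pixel_values"].shape)
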
620 | if args.seed is not None: 621 | set_seed(args.seed) 622 | 623 | # Handle the repository creation 624 | if accelerator.is_main_process: 625 | if args.output_dir is not None: 626 | os.makedirs(args.output_dir, exist_ok=True) 627 | 628 | if args.push_to_hub: 629 | repo_id = create_repo( 630 | repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token 631 | ).repo_id 632 | 633 | # Load tokenizer 634 | if args.tokenizer_name: 635 | tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) 636 | elif args.pretrained_model_name_or_path: 637 | tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") 638 | 639 | # Load scheduler and models 640 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") 641 | text_encoder = CLIPTextModel.from_pretrained( 642 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision 643 | ) 644 | vae = AutoencoderKL.from_pretrained( 645 | args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant 646 | ) 647 | unet = UNet2DConditionModel.from_pretrained( 648 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant 649 | ) 650 | 651 | # Add the placeholder token in tokenizer 652 | placeholder_tokens = [args.placeholder_token] 653 | 654 | if args.num_vectors < 1: 655 | raise ValueError(f"--num_vectors has to be larger or equal to 1, but is {args.num_vectors}") 656 | 657 | # add dummy tokens for multi-vector 658 | additional_tokens = [] 659 | for i in range(1, args.num_vectors): 660 | additional_tokens.append(f"{args.placeholder_token}_{i}") 661 | placeholder_tokens += additional_tokens 662 | 663 | num_added_tokens = tokenizer.add_tokens(placeholder_tokens) 664 | if num_added_tokens != args.num_vectors: 665 | raise ValueError( 666 | f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different" 667 | " `placeholder_token` that is not already in the tokenizer." 
668 | ) 669 | 670 | # Convert the initializer_token, placeholder_token to ids 671 | token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) 672 | # Check if initializer_token is a single token or a sequence of tokens 673 | if len(token_ids) > 1: 674 | raise ValueError("The initializer token must be a single token.") 675 | 676 | initializer_token_id = token_ids[0] 677 | placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens) 678 | 679 | # Resize the token embeddings as we are adding new special tokens to the tokenizer 680 | text_encoder.resize_token_embeddings(len(tokenizer)) 681 | 682 | # Initialise the newly added placeholder token with the embeddings of the initializer token 683 | token_embeds = text_encoder.get_input_embeddings().weight.data 684 | with torch.no_grad(): 685 | for token_id in placeholder_token_ids: 686 | token_embeds[token_id] = token_embeds[initializer_token_id].clone() 687 | 688 | # Freeze vae and unet 689 | vae.requires_grad_(False) 690 | unet.requires_grad_(False) 691 | # Freeze all parameters except for the token embeddings in text encoder 692 | text_encoder.text_model.encoder.requires_grad_(False) 693 | text_encoder.text_model.final_layer_norm.requires_grad_(False) 694 | text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) 695 | 696 | if args.gradient_checkpointing: 697 | # Keep unet in train mode if we are using gradient checkpointing to save memory. 698 | # The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode. 699 | unet.train() 700 | text_encoder.gradient_checkpointing_enable() 701 | unet.enable_gradient_checkpointing() 702 | 703 | if args.enable_xformers_memory_efficient_attention: 704 | if is_xformers_available(): 705 | import xformers 706 | 707 | xformers_version = version.parse(xformers.__version__) 708 | if xformers_version == version.parse("0.0.16"): 709 | logger.warn( 710 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 711 | ) 712 | unet.enable_xformers_memory_efficient_attention() 713 | else: 714 | raise ValueError("xformers is not available. 
Make sure it is installed correctly") 715 | 716 | # Enable TF32 for faster training on Ampere GPUs, 717 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 718 | if args.allow_tf32: 719 | torch.backends.cuda.matmul.allow_tf32 = True 720 | 721 | if args.scale_lr: 722 | args.learning_rate = ( 723 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 724 | ) 725 | 726 | # Initialize the optimizer 727 | optimizer = torch.optim.AdamW( 728 | text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings 729 | lr=args.learning_rate, 730 | betas=(args.adam_beta1, args.adam_beta2), 731 | weight_decay=args.adam_weight_decay, 732 | eps=args.adam_epsilon, 733 | ) 734 | attributes = [] 735 | prompts = prompt_util.load_prompts_from_yaml(args.prompts_file, attributes) 736 | placeholder_token=(" ".join(tokenizer.convert_ids_to_tokens(placeholder_token_ids))) 737 | criteria = torch.nn.MSELoss() 738 | 739 | cache = PromptEmbedsCache() 740 | prompt_pairs: list[PromptEmbedsPair] = [] 741 | 742 | with torch.no_grad(): 743 | for settings in prompts: 744 | print(settings) 745 | for prompt in [ 746 | settings.target, 747 | settings.positive, 748 | settings.neutral, 749 | settings.unconditional, 750 | ]: 751 | 752 | print(prompt) 753 | if isinstance(prompt, list): 754 | if prompt == settings.positive: 755 | key_setting = 'positive' 756 | else: 757 | key_setting = 'attributes' 758 | if len(prompt) == 0: 759 | cache[key_setting] = [] 760 | else: 761 | if cache[key_setting] is None: 762 | cache[key_setting] = train_util.encode_prompts( 763 | tokenizer, text_encoder, prompt 764 | ) 765 | else: 766 | if cache[prompt] == None: 767 | cache[prompt] = train_util.encode_prompts( 768 | tokenizer, text_encoder, [prompt] 769 | ) 770 | 771 | prompt_pairs.append( 772 | PromptEmbedsPair( 773 | criteria, 774 | cache[settings.target], 775 | cache[settings.positive], 776 | cache[settings.unconditional], 777 | cache[settings.neutral], 778 | settings, 779 | ) 780 | ) 781 | 782 | # Scheduler and math around the number of training steps. 783 | overrode_max_train_steps = False 784 | num_update_steps_per_epoch = math.ceil(args.max_train_steps / args.gradient_accumulation_steps) 785 | if args.max_train_steps is None: 786 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 787 | overrode_max_train_steps = True 788 | 789 | lr_scheduler = get_scheduler( 790 | args.lr_scheduler, 791 | optimizer=optimizer, 792 | num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, 793 | num_training_steps=args.max_train_steps * accelerator.num_processes, 794 | num_cycles=args.lr_num_cycles, 795 | ) 796 | device = torch.device("cuda:0") 797 | 798 | text_encoder.train() 799 | # Prepare everything with our `accelerator`. 800 | text_encoder, optimizer, lr_scheduler = accelerator.prepare( 801 | text_encoder, optimizer, lr_scheduler 802 | ) 803 | 804 | # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision 805 | # as these weights are only used for inference, keeping weights in full precision is not required. 
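Each entry of the YAML file passed as `--prompts_file` above is parsed by `load_prompts_from_yaml` into a `PromptSettings` object (textsliders/prompt_util.py), and the resulting `PromptEmbedsPair.loss()` pushes the slider prompt's prediction toward or away from the `positive` concept relative to `neutral`, depending on `action`. The shipped YAML files are not reproduced here; a hypothetical entry, written as the equivalent Python call, might look like this (all field values are illustrative):

    from textsliders.prompt_util import PromptSettings

    example = PromptSettings(
        target="picture of a person",             # prompt the slider token is appended to
        positive="picture of a person, smiling",  # concept the slider should control
        unconditional="",                         # empty prompt
        neutral="picture of a person",            # base condition
        action="enhance",                         # "enhance" adds the concept, "erase" removes it
        guidance_scale=4.0,
        resolution=512,
        batch_size=1,
    )
    print(example.neutral, "->", example.positive)
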
806 | weight_dtype = torch.float32 807 | if accelerator.mixed_precision == "fp16": 808 | weight_dtype = torch.float16 809 | elif accelerator.mixed_precision == "bf16": 810 | weight_dtype = torch.bfloat16 811 | 812 | # Move vae and unet to device and cast to weight_dtype 813 | unet.to(accelerator.device, dtype=weight_dtype) 814 | vae.to(accelerator.device, dtype=weight_dtype) 815 | 816 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 817 | num_update_steps_per_epoch = math.ceil(1 / args.gradient_accumulation_steps) 818 | if overrode_max_train_steps: 819 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 820 | # Afterwards we recalculate our number of training epochs 821 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 822 | 823 | # We need to initialize the trackers we use, and also store our configuration. 824 | # The trackers initializes automatically on the main process. 825 | if accelerator.is_main_process: 826 | accelerator.init_trackers("textual_inversion", config=vars(args)) 827 | 828 | # Train! 829 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 830 | 831 | logger.info("***** Running training *****") 832 | logger.info(f" Num examples = {args.max_train_steps}") 833 | logger.info(f" Num Epochs = {args.num_train_epochs}") 834 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 835 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 836 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 837 | logger.info(f" Total optimization steps = {args.max_train_steps}") 838 | global_step = 0 839 | first_epoch = 0 840 | # Potentially load in the weights and states from a previous save 841 | if args.resume_from_checkpoint: 842 | if args.resume_from_checkpoint != "latest": 843 | path = os.path.basename(args.resume_from_checkpoint) 844 | else: 845 | # Get the most recent checkpoint 846 | dirs = os.listdir(args.output_dir) 847 | dirs = [d for d in dirs if d.startswith("checkpoint")] 848 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 849 | path = dirs[-1] if len(dirs) > 0 else None 850 | 851 | if path is None: 852 | accelerator.print( 853 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 854 | ) 855 | args.resume_from_checkpoint = None 856 | initial_global_step = 0 857 | else: 858 | accelerator.print(f"Resuming from checkpoint {path}") 859 | accelerator.load_state(os.path.join(args.output_dir, path)) 860 | global_step = int(path.split("-")[1]) 861 | 862 | initial_global_step = global_step 863 | first_epoch = global_step // num_update_steps_per_epoch 864 | 865 | else: 866 | initial_global_step = 0 867 | 868 | progress_bar = tqdm( 869 | range(0, args.max_train_steps), 870 | initial=initial_global_step, 871 | desc="Steps", 872 | # Only show the progress bar once on each machine. 
disable=not accelerator.is_local_main_process,
874 |     )
875 | 
876 |     # keep original embeddings as reference
877 |     orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()
878 | 
879 |     for epoch in range(first_epoch, args.num_train_epochs):
880 |         text_encoder.train()
881 | 
882 |         with torch.no_grad():
883 |             noise_scheduler.set_timesteps(
884 |                 50, device=device
885 |             )
886 | 
887 |             optimizer.zero_grad()
888 | 
889 |             prompt_pair: PromptEmbedsPair = prompt_pairs[
890 |                 torch.randint(0, len(prompt_pairs), (1,)).item()
891 |             ]
892 |             tprompt = prompt_pair.settings.target + f', {placeholder_token}'  # append the learnable slider token to the target prompt
893 |             sc = float(random.choice([idx for idx in range(3)]))  # random slider strength in {0.0, 1.0, 2.0}
894 |             ti_prompt_1 = train_util.encode_prompts_slider(
895 |                 tokenizer, text_encoder, [tprompt], sc=sc,
896 |             )
897 | 
898 |             # pick a random number of denoising steps between 1 and 49
899 |             timesteps_to = torch.randint(
900 |                 1, 50, (1,)
901 |             ).item()
902 | 
903 |             height, width = (
904 |                 prompt_pair.resolution,
905 |                 prompt_pair.resolution,
906 |             )
907 |             if prompt_pair.dynamic_resolution:
908 |                 height, width = train_util.get_random_resolution_in_bucket(
909 |                     prompt_pair.resolution
910 |                 )
911 | 
912 |             latents = train_util.get_initial_latents(
913 |                 noise_scheduler, prompt_pair.batch_size, height, width, 1
914 |             ).to(device, dtype=weight_dtype)
915 | 
916 | 
917 |             # run the first few denoising steps; this returns partially denoised latents
918 |             denoised_latents = train_util.diffusion(
919 |                 unet,
920 |                 noise_scheduler,
921 |                 latents,  # pass the pure-noise latents
922 |                 train_util.concat_embeddings(
923 |                     prompt_pair.unconditional.to(device),
924 |                     ti_prompt_1,
925 |                     prompt_pair.batch_size,
926 |                 ),
927 |                 start_timesteps=0,
928 |                 total_timesteps=timesteps_to,
929 |                 guidance_scale=3,
930 |             )
931 | 
932 |             noise_scheduler.set_timesteps(1000)
933 | 
934 |             current_timestep = noise_scheduler.timesteps[
935 |                 int(timesteps_to * 1000 / 50)
936 |             ]
937 | 
938 |             # reference noise predictions for the positive / neutral / unconditional prompts (no slider token, no gradients)
939 |             positive_latents = train_util.predict_noise(
940 |                 unet,
941 |                 noise_scheduler,
942 |                 current_timestep,
943 |                 denoised_latents,
944 |                 train_util.concat_embeddings(
945 |                     prompt_pair.unconditional,
946 |                     prompt_pair.positive,
947 |                     prompt_pair.batch_size,
948 |                 ).to(device),
949 |                 guidance_scale=1,
950 |             ).to(device, dtype=weight_dtype)
951 | 
952 |             neutral_latents = train_util.predict_noise(
953 |                 unet,
954 |                 noise_scheduler,
955 |                 current_timestep,
956 |                 denoised_latents,
957 |                 train_util.concat_embeddings(
958 |                     prompt_pair.unconditional,
959 |                     prompt_pair.neutral,
960 |                     prompt_pair.batch_size,
961 |                 ).to(device),
962 |                 guidance_scale=1,
963 |             ).to(device, dtype=weight_dtype)
964 |             unconditional_latents = train_util.predict_noise(
965 |                 unet,
966 |                 noise_scheduler,
967 |                 current_timestep,
968 |                 denoised_latents,
969 |                 train_util.concat_embeddings(
970 |                     prompt_pair.unconditional,
971 |                     prompt_pair.unconditional,
972 |                     prompt_pair.batch_size,
973 |                 ).to(device),
974 |                 guidance_scale=1,
975 |             ).to(device, dtype=weight_dtype)
976 | 
977 | 
978 |         with accelerator.accumulate(text_encoder):
979 | 
980 |             ti_prompt = train_util.encode_prompts_slider(  # re-encode with gradients enabled so the slider token embedding receives updates
981 |                 tokenizer, text_encoder, [tprompt], sc=sc,
982 |             )
983 |             target_latents = train_util.predict_noise(
984 |                 unet,
985 |                 noise_scheduler,
986 |                 current_timestep,
987 |                 denoised_latents,
988 |                 train_util.concat_embeddings(
989 |                     prompt_pair.unconditional.to(device),
990 |                     ti_prompt,
991 |                     prompt_pair.batch_size,
992 |                 ),
993 |                 guidance_scale=1,
994 |             ).to(device, dtype=weight_dtype)
995 | 
996 |             positive_latents.requires_grad = False
997 |             neutral_latents.requires_grad = False
998 | 
unconditional_latents.requires_grad = False 999 | 1000 | loss = prompt_pair.loss( 1001 | target_latents=target_latents, 1002 | positive_latents=positive_latents, 1003 | neutral_latents=neutral_latents, 1004 | unconditional_latents=unconditional_latents, 1005 | scale=sc, 1006 | ) 1007 | 1008 | accelerator.backward(loss) 1009 | 1010 | optimizer.step() 1011 | lr_scheduler.step() 1012 | optimizer.zero_grad() 1013 | 1014 | # Let's make sure we don't update any embedding weights besides the newly added token 1015 | index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool) 1016 | index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False 1017 | 1018 | with torch.no_grad(): 1019 | accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ 1020 | index_no_updates 1021 | ] = orig_embeds_params[index_no_updates] 1022 | 1023 | # Checks if the accelerator has performed an optimization step behind the scenes 1024 | if accelerator.sync_gradients: 1025 | images = [] 1026 | progress_bar.update(1) 1027 | global_step += 1 1028 | if global_step % args.save_steps == 0: 1029 | weight_name = ( 1030 | f"learned_embeds-steps-{global_step}.bin" 1031 | if args.no_safe_serialization 1032 | else f"learned_embeds-steps-{global_step}.safetensors" 1033 | ) 1034 | save_path = os.path.join(args.output_dir, weight_name) 1035 | save_progress( 1036 | text_encoder, 1037 | placeholder_token_ids, 1038 | accelerator, 1039 | args, 1040 | save_path, 1041 | safe_serialization=not args.no_safe_serialization, 1042 | ) 1043 | 1044 | if accelerator.is_main_process: 1045 | if global_step % args.checkpointing_steps == 0: 1046 | # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` 1047 | if args.checkpoints_total_limit is not None: 1048 | checkpoints = os.listdir(args.output_dir) 1049 | checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] 1050 | checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) 1051 | 1052 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints 1053 | if len(checkpoints) >= args.checkpoints_total_limit: 1054 | num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 1055 | removing_checkpoints = checkpoints[0:num_to_remove] 1056 | 1057 | logger.info( 1058 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" 1059 | ) 1060 | logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") 1061 | 1062 | for removing_checkpoint in removing_checkpoints: 1063 | removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) 1064 | shutil.rmtree(removing_checkpoint) 1065 | 1066 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 1067 | accelerator.save_state(save_path) 1068 | logger.info(f"Saved state to {save_path}") 1069 | 1070 | if args.validation_prompt is not None and global_step % args.validation_steps == 0: 1071 | images = log_validation( 1072 | text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch 1073 | ) 1074 | 1075 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 1076 | progress_bar.set_postfix(**logs) 1077 | accelerator.log(logs, step=global_step) 1078 | 1079 | if global_step >= args.max_train_steps: 1080 | break 1081 | # Create the pipeline using the trained modules and save it. 
1082 | accelerator.wait_for_everyone() 1083 | if accelerator.is_main_process: 1084 | if args.push_to_hub and not args.save_as_full_pipeline: 1085 | logger.warn("Enabling full model saving because --push_to_hub=True was specified.") 1086 | save_full_model = True 1087 | else: 1088 | save_full_model = args.save_as_full_pipeline 1089 | if save_full_model: 1090 | pipeline = StableDiffusionPipeline.from_pretrained( 1091 | args.pretrained_model_name_or_path, 1092 | text_encoder=accelerator.unwrap_model(text_encoder), 1093 | vae=vae, 1094 | unet=unet, 1095 | tokenizer=tokenizer, 1096 | ) 1097 | pipeline.save_pretrained(args.output_dir) 1098 | # Save the newly trained embeddings 1099 | weight_name = "learned_embeds.bin" if args.no_safe_serialization else "learned_embeds.safetensors" 1100 | save_path = os.path.join(args.output_dir, weight_name) 1101 | save_progress( 1102 | text_encoder, 1103 | placeholder_token_ids, 1104 | accelerator, 1105 | args, 1106 | save_path, 1107 | safe_serialization=not args.no_safe_serialization, 1108 | ) 1109 | 1110 | if args.push_to_hub: 1111 | save_model_card( 1112 | repo_id, 1113 | images=images, 1114 | base_model=args.pretrained_model_name_or_path, 1115 | repo_folder=args.output_dir, 1116 | ) 1117 | upload_folder( 1118 | repo_id=repo_id, 1119 | folder_path=args.output_dir, 1120 | commit_message="End of training", 1121 | ignore_patterns=["step_*", "epoch_*"], 1122 | ) 1123 | 1124 | accelerator.end_training() 1125 | 1126 | 1127 | if __name__ == "__main__": 1128 | main() 1129 | --------------------------------------------------------------------------------
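For reference, one way the embedding saved as `learned_embeds.safetensors` by this script can be applied at inference time. The repository's own scripts (inference-promptsliders-sd.py, inference_sd.py) are the reference implementation; the sketch below is only a minimal illustration built on the training utilities above, and the token name, output path and scale value are placeholders:

    import torch
    from diffusers import StableDiffusionPipeline

    from textsliders import train_util

    pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4", safety_checker=None
    ).to("cuda")

    token = "<smile-slider>"  # must match --placeholder_token used during training
    pipe.load_textual_inversion("text-inversion-model/learned_embeds.safetensors", token=token)

    # Encode the prompt with the slider token's output embedding scaled via
    # train_util.encode_prompts_slider; a larger sc gives a stronger attribute.
    with torch.no_grad():
        cond = train_util.encode_prompts_slider(
            pipe.tokenizer, pipe.text_encoder, [f"picture of a person, {token}"], sc=1.5
        )
        uncond = train_util.encode_prompts(pipe.tokenizer, pipe.text_encoder, [""])

    image = pipe(
        prompt_embeds=cond, negative_prompt_embeds=uncond, num_inference_steps=50
    ).images[0]
    image.save("slider_example.png")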