├── .gitignore ├── LICENSE ├── README.md ├── examples ├── LEdits.ipynb ├── images │ ├── cat.jpeg │ ├── cat.jpg │ ├── cat_3.jpeg │ ├── glasses.jpg │ ├── landscape.jpg │ ├── pexels-tennis.jpg │ ├── placeholder.txt │ ├── portrait.png │ ├── vase_01.jpeg │ ├── vulcano.jpg │ └── yann-lecun.jpg └── teaser.png ├── requirements.txt ├── setup.py └── src └── leditspp ├── __init__.py ├── pipeline_if_ledits.py ├── pipeline_stable_diffusion_ledits.py ├── pipeline_stable_diffusion_xl_ledits.py └── scheduling_dpmsolver_multistep_inject.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 ml-research@TUDarmstadt 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LEdits++ 2 | 3 | Official Implementation of the [Paper](https://arxiv.org/abs/2311.16711) **LEDITS++: Limitless Image Editing using Text-to-Image Models**. 
4 | 
5 | 
6 | ## Interactive Demo
7 | An interactive demonstration is available in Colab and on Hugging Face [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/ledits_pp/blob/main/examples/LEdits.ipynb) [![Huggingface Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/editing-images/ledtisplusplus)
8 | 
9 | [Project Page](https://leditsplusplus-project.static.hf.space/index.html)
10 | 
11 | ![Examples](./examples/teaser.png)
12 | 
13 | ## Installation
14 | LEdits++ is fully integrated into the ```diffusers``` library as ```LEditsPPPipelineStableDiffusion``` and ```LEditsPPPipelineStableDiffusionXL```. Just install diffusers to use it:
15 | 
16 | ```cmd
17 | pip install diffusers
18 | ```
19 | 
20 | Notably, the diffusers implementation does NOT guarantee perfect inversion. If that is a required property for your use case, or if you are performing research based on LEdits++, we recommend using the implementation in this repository instead.
21 | 
22 | You can clone this repository and install it locally by running
23 | 
24 | ```cmd
25 | git clone https://github.com/ml-research/ledits_pp.git
26 | cd ./ledits_pp
27 | pip install .
28 | ```
29 | or install it directly from git:
30 | ```cmd
31 | pip install git+https://github.com/ml-research/ledits_pp.git
32 | ```
33 | 
34 | ## Usage
35 | This repository provides three new diffusion pipelines for image editing, built on the [diffusers](https://github.com/huggingface/diffusers) library.
36 | The ```StableDiffusionPipeline_LEDITS```, ```StableDiffusionPipelineXL_LEDITS``` and ```IFDiffusion_LEDITS``` pipelines extend the respective diffusers pipelines and can therefore be loaded from any corresponding pre-trained checkpoint, as shown below.
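For reference, the diffusers-integrated pipeline mentioned in the Installation section exposes the same invert-then-edit workflow. The following is a minimal sketch, assuming a recent diffusers release that ships `LEditsPPPipelineStableDiffusion`; the parameter names mirror the repository pipeline shown further below, but check the diffusers documentation for the exact signatures and defaults:

```python
import torch
from diffusers import LEditsPPPipelineStableDiffusion
from diffusers.utils import load_image

# Diffusers-integrated LEDITS++ (no perfect-inversion guarantee, see the note above).
pipe = LEditsPPPipelineStableDiffusion.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.to("cuda")

image = load_image("https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/cherry_blossom.png")

# Invert the input image first, then apply the semantic edit.
_ = pipe.invert(image=image, num_inversion_steps=50, skip=0.1)
edited_image = pipe(
    editing_prompt=["cherry blossom"],
    edit_guidance_scale=10.0,
    edit_threshold=0.75,
).images[0]
```

The remainder of this section uses this repository's own `StableDiffusionPipeline_LEDITS`: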
37 | 
38 | 
39 | 
40 | 
41 | ```python
42 | import PIL
43 | import requests
44 | import torch
45 | from io import BytesIO
46 | from leditspp.scheduling_dpmsolver_multistep_inject import DPMSolverMultistepSchedulerInject
47 | from leditspp import StableDiffusionPipeline_LEDITS
48 | 
49 | model = 'runwayml/stable-diffusion-v1-5'
50 | #model = '/workspace/StableDiff/models/stable-diffusion-v1-5'
51 | 
52 | device = 'cuda'
53 | 
54 | pipe = StableDiffusionPipeline_LEDITS.from_pretrained(model, safety_checker=None)
55 | pipe.scheduler = DPMSolverMultistepSchedulerInject.from_pretrained(
56 |     model, subfolder="scheduler", algorithm_type="sde-dpmsolver++", solver_order=2)
57 | pipe.to(device)
58 | ```
59 | 
60 | An exemplary usage of the pipeline could look like this:
61 | ```python
62 | def download_image(url):
63 |     response = requests.get(url)
64 |     return PIL.Image.open(BytesIO(response.content)).convert("RGB")
65 | 
66 | gen = torch.Generator(device=device)
67 | gen.manual_seed(21)
68 | img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/cherry_blossom.png"
69 | image = download_image(img_url)
70 | _ = pipe.invert(image=image,
71 |                 num_inversion_steps=50,
72 |                 skip=0.1
73 |                 )
74 | edited_image = pipe(
75 |     editing_prompt=["cherry blossom"],
76 |     edit_guidance_scale=10.0,
77 |     edit_threshold=0.75,
78 | ).images[0]
79 | ```
80 | 
81 | ## Citation
82 | If you like or use our work, please cite us:
83 | ```bibtex
84 | @article{brack2023Sega,
85 |   title={LEDITS++: Limitless Image Editing using Text-to-Image Models},
86 |   author={Manuel Brack and Felix Friedrich and Katharina Kornmeier and Linoy Tsaban and Patrick Schramowski and Kristian Kersting and Apolinário Passos},
87 |   year={2023},
88 |   eprint={2311.16711},
89 |   archivePrefix={arXiv},
90 |   primaryClass={cs.CV}
91 | }
92 | ```
93 | 
94 | 
--------------------------------------------------------------------------------
/examples/images/cat.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-research/ledits_pp/c2ccffb9d65d39da6f696676d85d09b62fdc8bd5/examples/images/cat.jpeg
--------------------------------------------------------------------------------
/examples/images/cat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-research/ledits_pp/c2ccffb9d65d39da6f696676d85d09b62fdc8bd5/examples/images/cat.jpg
--------------------------------------------------------------------------------
/examples/images/cat_3.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-research/ledits_pp/c2ccffb9d65d39da6f696676d85d09b62fdc8bd5/examples/images/cat_3.jpeg
--------------------------------------------------------------------------------
/examples/images/glasses.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-research/ledits_pp/c2ccffb9d65d39da6f696676d85d09b62fdc8bd5/examples/images/glasses.jpg
--------------------------------------------------------------------------------
/examples/images/landscape.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-research/ledits_pp/c2ccffb9d65d39da6f696676d85d09b62fdc8bd5/examples/images/landscape.jpg
--------------------------------------------------------------------------------
/examples/images/pexels-tennis.jpg:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/ledits_pp/c2ccffb9d65d39da6f696676d85d09b62fdc8bd5/examples/images/pexels-tennis.jpg -------------------------------------------------------------------------------- /examples/images/placeholder.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /examples/images/portrait.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/ledits_pp/c2ccffb9d65d39da6f696676d85d09b62fdc8bd5/examples/images/portrait.png -------------------------------------------------------------------------------- /examples/images/vase_01.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/ledits_pp/c2ccffb9d65d39da6f696676d85d09b62fdc8bd5/examples/images/vase_01.jpeg -------------------------------------------------------------------------------- /examples/images/vulcano.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/ledits_pp/c2ccffb9d65d39da6f696676d85d09b62fdc8bd5/examples/images/vulcano.jpg -------------------------------------------------------------------------------- /examples/images/yann-lecun.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/ledits_pp/c2ccffb9d65d39da6f696676d85d09b62fdc8bd5/examples/images/yann-lecun.jpg -------------------------------------------------------------------------------- /examples/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/ledits_pp/c2ccffb9d65d39da6f696676d85d09b62fdc8bd5/examples/teaser.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers==0.20.2 2 | pillow 3 | transformers 4 | accelerate 5 | jupyter 6 | sentencepiece 7 | ftfy -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pkg_resources 4 | from setuptools import setup, find_packages 5 | 6 | setup( 7 | name="leditspp", 8 | version="1.0.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) 9 | description="Semantic Image Editing", 10 | long_description=open("README.md", "r", encoding="utf-8").read(), 11 | long_description_content_type="text/markdown", 12 | keywords="deep learning", 13 | license="MIT", 14 | author="Manuel Brack", 15 | author_email="brack@cs.tu-darmstadt.de", 16 | url="https://github.com/ml-research/ledits_pp", 17 | package_dir={"": "src"}, 18 | packages=find_packages("src"), 19 | install_requires=[ 20 | str(r) 21 | for r in pkg_resources.parse_requirements( 22 | open(os.path.join(os.path.dirname(__file__), "requirements.txt")) 23 | ) 24 | ], 25 | include_package_data=True, 26 | ) 27 | -------------------------------------------------------------------------------- /src/leditspp/__init__.py: -------------------------------------------------------------------------------- 1 | from diffusers.utils import is_torch_available, 
is_transformers_available 2 | 3 | 4 | 5 | if is_transformers_available() and is_torch_available(): 6 | from .pipeline_stable_diffusion_ledits import StableDiffusionPipeline_LEDITS 7 | from .pipeline_stable_diffusion_xl_ledits import StableDiffusionPipelineXL_LEDITS 8 | from .pipeline_if_ledits import IFDiffusion_LEDITS 9 | -------------------------------------------------------------------------------- /src/leditspp/pipeline_if_ledits.py: -------------------------------------------------------------------------------- 1 | import html 2 | import inspect 3 | import re 4 | import urllib.parse as ul 5 | from typing import Any, Callable, Dict, List, Optional, Union 6 | from itertools import repeat 7 | 8 | from diffusers.utils import pt_to_pil 9 | 10 | import numpy as np 11 | import PIL 12 | from tqdm import tqdm 13 | import torch 14 | import torch.nn.functional as F 15 | import math 16 | from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer 17 | 18 | from diffusers.models.attention_processor import AttnProcessor, Attention, AttnAddedKVProcessor 19 | from diffusers.models import UNet2DConditionModel 20 | from diffusers.schedulers import DDIMScheduler 21 | from diffusers.utils import ( 22 | BACKENDS_MAPPING, 23 | PIL_INTERPOLATION, 24 | is_accelerate_available, 25 | is_accelerate_version, 26 | is_bs4_available, 27 | is_ftfy_available, 28 | logging, 29 | randn_tensor, 30 | replace_example_docstring, 31 | ) 32 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 33 | from diffusers.pipelines.deepfloyd_if import IFPipelineOutput 34 | from diffusers.pipelines.deepfloyd_if.safety_checker import IFSafetyChecker 35 | from diffusers.pipelines.deepfloyd_if.watermark import IFWatermarker 36 | 37 | from .scheduling_dpmsolver_multistep_inject import DPMSolverMultistepSchedulerInject 38 | 39 | 40 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 41 | 42 | if is_bs4_available(): 43 | from bs4 import BeautifulSoup 44 | 45 | if is_ftfy_available(): 46 | import ftfy 47 | 48 | def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image: 49 | w, h = images.size 50 | 51 | coef = w / h 52 | 53 | w, h = img_size, img_size 54 | 55 | if coef >= 1: 56 | w = int(round(img_size / 8 * coef) * 8) 57 | else: 58 | h = int(round(img_size / 8 / coef) * 8) 59 | 60 | images = images.resize((w, h), resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None) 61 | 62 | return images 63 | 64 | def reset_dpm(scheduler): 65 | if isinstance(scheduler, DPMSolverMultistepSchedulerInject): 66 | scheduler.model_outputs = [ 67 | None, 68 | ] * scheduler.config.solver_order 69 | scheduler.lower_order_nums = 0 70 | 71 | 72 | EXAMPLE_DOC_STRING = """ 73 | Examples: 74 | ```py 75 | >>> from diffusers import IFPipeline, IFSuperResolutionPipeline, DiffusionPipeline 76 | >>> from diffusers.utils import pt_to_pil 77 | >>> import torch 78 | 79 | >>> pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16) 80 | >>> pipe.enable_model_cpu_offload() 81 | 82 | >>> prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"' 83 | >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt) 84 | 85 | >>> image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt").images 86 | 87 | >>> # save intermediate image 88 | >>> pil_image = pt_to_pil(image) 89 | >>> pil_image[0].save("./if_stage_I.png") 90 | 91 | >>> 
super_res_1_pipe = IFSuperResolutionPipeline.from_pretrained( 92 | ... "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16 93 | ... ) 94 | >>> super_res_1_pipe.enable_model_cpu_offload() 95 | 96 | >>> image = super_res_1_pipe( 97 | ... image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt" 98 | ... ).images 99 | 100 | >>> # save intermediate image 101 | >>> pil_image = pt_to_pil(image) 102 | >>> pil_image[0].save("./if_stage_I.png") 103 | 104 | >>> safety_modules = { 105 | ... "feature_extractor": pipe.feature_extractor, 106 | ... "safety_checker": pipe.safety_checker, 107 | ... "watermarker": pipe.watermarker, 108 | ... } 109 | >>> super_res_2_pipe = DiffusionPipeline.from_pretrained( 110 | ... "stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16 111 | ... ) 112 | >>> super_res_2_pipe.enable_model_cpu_offload() 113 | 114 | >>> image = super_res_2_pipe( 115 | ... prompt=prompt, 116 | ... image=image, 117 | ... ).images 118 | >>> image[0].save("./if_stage_II.png") 119 | ``` 120 | """ 121 | 122 | 123 | class AttentionStore(): 124 | @staticmethod 125 | def get_empty_store(): 126 | return {"down_cross": [], "mid_cross": [], "up_cross": [], 127 | "down_self": [], "mid_self": [], "up_self": []} 128 | 129 | def __call__(self, attn, is_cross: bool, place_in_unet: str, editing_prompts, PnP): 130 | # attn.shape = batch_size * head_size, seq_len query, seq_len_key 131 | bs = 2 + int(PnP) + editing_prompts 132 | source_batch_size = int(attn.shape[0] // bs) 133 | skip = 2 if PnP else 1 # skip PnP & unconditional 134 | self.forward( 135 | attn[skip*source_batch_size:], 136 | is_cross, 137 | place_in_unet) 138 | 139 | def forward(self, attn, is_cross: bool, place_in_unet: str): 140 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" 141 | if attn.shape[1] == 16 ** 2 or attn.shape[1] == 8 ** 2: # avoid memory overhead 142 | self.step_store[key].append(attn) 143 | 144 | def between_steps(self, store_step=True): 145 | if store_step: 146 | if self.average: 147 | if len(self.attention_store) == 0: 148 | self.attention_store = self.step_store 149 | else: 150 | for key in self.attention_store: 151 | for i in range(len(self.attention_store[key])): 152 | self.attention_store[key][i] += self.step_store[key][i] 153 | else: 154 | if len(self.attention_store) == 0: 155 | self.attention_store = [self.step_store] 156 | else: 157 | self.attention_store.append(self.step_store) 158 | 159 | self.cur_step += 1 160 | self.step_store = self.get_empty_store() 161 | 162 | def get_attention(self, step: int): 163 | if self.average: 164 | attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store} 165 | else: 166 | assert(step is not None) 167 | attention = self.attention_store[step] 168 | return attention 169 | 170 | def aggregate_attention(self, attention_maps, prompts, res: int, 171 | from_where: List[str], is_cross: bool, select: int 172 | ): 173 | out = [] 174 | num_pixels = res ** 2 175 | for location in from_where: 176 | for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]: 177 | if item.shape[1] == num_pixels: 178 | cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select] 179 | out.append(cross_maps) 180 | out = torch.cat(out, dim=0) 181 | # average over heads 182 | out = out.sum(0) / out.shape[0] 183 | return out 184 | 185 | def __init__(self, average: bool): 186 | self.step_store = self.get_empty_store() 
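        # step_store collects the attention maps of the current denoising step;
        # between_steps() then merges them into attention_store, either as a running
        # sum that get_attention() averages over cur_step (average=True) or as a
        # per-step list (average=False).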
187 | self.attention_store = [] 188 | self.cur_step = 0 189 | self.average = average 190 | 191 | class CrossAttnProcessor: 192 | 193 | def __init__(self, attention_store, place_in_unet, PnP, editing_prompts): 194 | self.attnstore = attention_store 195 | self.place_in_unet = place_in_unet 196 | self.editing_prompts = editing_prompts 197 | self.PnP = PnP 198 | 199 | def __call__( 200 | self, 201 | attn: Attention, 202 | hidden_states, 203 | encoder_hidden_states=None, 204 | attention_mask=None, 205 | temb=None, 206 | ): 207 | residual = hidden_states 208 | hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) 209 | batch_size, sequence_length, _ = hidden_states.shape 210 | 211 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 212 | 213 | is_cross = True 214 | if encoder_hidden_states is None: 215 | encoder_hidden_states = hidden_states 216 | is_cross = False 217 | elif attn.norm_cross: 218 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 219 | 220 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 221 | 222 | query = attn.to_q(hidden_states) 223 | query = attn.head_to_batch_dim(query) 224 | 225 | encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) 226 | encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) 227 | encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) 228 | encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) 229 | 230 | if not attn.only_cross_attention: 231 | key = attn.to_k(hidden_states) 232 | value = attn.to_v(hidden_states) 233 | key = attn.head_to_batch_dim(key) 234 | value = attn.head_to_batch_dim(value) 235 | key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) 236 | value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) 237 | else: 238 | key = encoder_hidden_states_key_proj 239 | value = encoder_hidden_states_value_proj 240 | 241 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 242 | 243 | if is_cross: 244 | self.attnstore(attention_probs, 245 | is_cross=True, 246 | place_in_unet=self.place_in_unet, 247 | editing_prompts=self.editing_prompts, 248 | PnP=self.PnP) 249 | 250 | hidden_states = torch.bmm(attention_probs, value) 251 | hidden_states = attn.batch_to_head_dim(hidden_states) 252 | 253 | # linear proj 254 | hidden_states = attn.to_out[0](hidden_states) 255 | # dropout 256 | hidden_states = attn.to_out[1](hidden_states) 257 | 258 | hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) 259 | hidden_states = hidden_states + residual 260 | 261 | return hidden_states 262 | 263 | # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionAttendAndExcitePipeline.GaussianSmoothing 264 | class GaussianSmoothing(): 265 | 266 | def __init__(self, device): 267 | kernel_size = [3, 3] 268 | sigma = [0.5, 0.5] 269 | 270 | # The gaussian kernel is the product of the gaussian function of each dimension. 271 | kernel = 1 272 | meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size]) 273 | for size, std, mgrid in zip(kernel_size, sigma, meshgrids): 274 | mean = (size - 1) / 2 275 | kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2)) 276 | 277 | # Make sure sum of values in gaussian kernel equals 1. 
278 |         kernel = kernel / torch.sum(kernel)
279 | 
280 |         # Reshape to depthwise convolutional weight
281 |         kernel = kernel.view(1, 1, *kernel.size())
282 |         kernel = kernel.repeat(1, *[1] * (kernel.dim() - 1))
283 | 
284 |         self.weight = kernel.to(device)
285 | 
286 |     def __call__(self, input):
287 |         """
288 |         Apply gaussian filter to input.
289 |         Arguments:
290 |             input (torch.Tensor): Input to apply gaussian filter on.
291 |         Returns:
292 |             filtered (torch.Tensor): Filtered output.
293 |         """
294 |         return F.conv2d(input, weight=self.weight.to(input.dtype))
295 | 
296 | class IFDiffusion_LEDITS(DiffusionPipeline):
297 |     tokenizer: T5Tokenizer
298 |     text_encoder: T5EncoderModel
299 | 
300 |     unet: UNet2DConditionModel
301 |     scheduler: DDIMScheduler
302 | 
303 |     feature_extractor: Optional[CLIPImageProcessor]
304 |     safety_checker: Optional[IFSafetyChecker]
305 | 
306 |     watermarker: Optional[IFWatermarker]
307 | 
308 |     bad_punct_regex = re.compile(
309 |         r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
310 |     ) # noqa
311 | 
312 |     _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
313 | 
314 |     def __init__(
315 |         self,
316 |         tokenizer: T5Tokenizer,
317 |         text_encoder: T5EncoderModel,
318 |         unet: UNet2DConditionModel,
319 |         scheduler: DDIMScheduler,
320 |         safety_checker: Optional[IFSafetyChecker],
321 |         feature_extractor: Optional[CLIPImageProcessor],
322 |         watermarker: Optional[IFWatermarker],
323 |         requires_safety_checker: bool = True,
324 |     ):
325 |         super().__init__()
326 | 
327 |         if not isinstance(scheduler, (DDIMScheduler, DPMSolverMultistepSchedulerInject)):
328 |             conf = scheduler.config
329 |             conf["clip_sample"] = False
330 |             conf["thresholding"] = False
331 |             conf["variance_type"] = "fixed"
332 | 
333 |             scheduler = DPMSolverMultistepSchedulerInject.from_config(conf, algorithm_type="sde-dpmsolver++", solver_order=2)
334 |             logger.warning("This pipeline only supports DDIMScheduler and DPMSolverMultistepSchedulerInject. "
335 |                            "The scheduler has been changed to DPMSolverMultistepSchedulerInject.")
336 | 
337 | 
338 |         if safety_checker is None and requires_safety_checker:
339 |             logger.warning(
340 |                 f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
341 |                 " that you abide to the conditions of the IF license and do not expose unfiltered"
342 |                 " results in services or applications open to the public. Both the diffusers team and Hugging Face"
343 |                 " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
344 |                 " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
345 |                 " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
346 |             )
347 | 
348 |         if safety_checker is not None and feature_extractor is None:
349 |             raise ValueError(
350 |                 f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
351 |                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
352 | ) 353 | 354 | self.register_modules( 355 | tokenizer=tokenizer, 356 | text_encoder=text_encoder, 357 | unet=unet, 358 | scheduler=scheduler, 359 | safety_checker=safety_checker, 360 | feature_extractor=feature_extractor, 361 | watermarker=watermarker, 362 | ) 363 | self.register_to_config(requires_safety_checker=requires_safety_checker) 364 | 365 | def enable_model_cpu_offload(self, gpu_id=0): 366 | r""" 367 | Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared 368 | to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` 369 | method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with 370 | `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. 371 | """ 372 | if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): 373 | from accelerate import cpu_offload_with_hook 374 | else: 375 | raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") 376 | 377 | device = torch.device(f"cuda:{gpu_id}") 378 | 379 | if self.device.type != "cpu": 380 | self.to("cpu", silence_dtype_warnings=True) 381 | torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) 382 | 383 | hook = None 384 | 385 | if self.text_encoder is not None: 386 | _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook) 387 | 388 | # Accelerate will move the next model to the device _before_ calling the offload hook of the 389 | # previous model. This will cause both models to be present on the device at the same time. 390 | # IF uses T5 for its text encoder which is really large. We can manually call the offload 391 | # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to 392 | # the GPU. 393 | self.text_encoder_offload_hook = hook 394 | 395 | _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook) 396 | 397 | # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet 398 | self.unet_offload_hook = hook 399 | 400 | if self.safety_checker is not None: 401 | _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) 402 | 403 | # We'll offload the last model manually. 404 | self.final_offload_hook = hook 405 | 406 | def remove_all_hooks(self): 407 | if is_accelerate_available(): 408 | from accelerate.hooks import remove_hook_from_module 409 | else: 410 | raise ImportError("Please install accelerate via `pip install accelerate`") 411 | 412 | for model in [self.text_encoder, self.unet, self.safety_checker]: 413 | if model is not None: 414 | remove_hook_from_module(model, recurse=True) 415 | 416 | self.unet_offload_hook = None 417 | self.text_encoder_offload_hook = None 418 | self.final_offload_hook = None 419 | 420 | @torch.no_grad() 421 | def encode_prompt( 422 | self, 423 | prompt, 424 | do_classifier_free_guidance=True, 425 | num_images_per_prompt=1, 426 | device=None, 427 | negative_prompt=None, 428 | editing_prompt=None, 429 | prompt_embeds: Optional[torch.FloatTensor] = None, 430 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 431 | edit_prompt_embeds: Optional[torch.FloatTensor] = None, 432 | clean_caption: bool = False, 433 | ): 434 | r""" 435 | Encodes the prompt into text encoder hidden states. 
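        In addition to the usual prompt and negative-prompt embeddings, this variant also encodes the
        optional `editing_prompt` and returns the edit embeddings together with the number of edit tokens.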
436 | 437 | Args: 438 | prompt (`str` or `List[str]`, *optional*): 439 | prompt to be encoded 440 | device: (`torch.device`, *optional*): 441 | torch device to place the resulting embeddings on 442 | num_images_per_prompt (`int`, *optional*, defaults to 1): 443 | number of images that should be generated per prompt 444 | do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): 445 | whether to use classifier free guidance or not 446 | negative_prompt (`str` or `List[str]`, *optional*): 447 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 448 | `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. 449 | Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). 450 | editing_prompt (`str` or `List[str]`, *optional*): 451 | The prompt used for semantic guidance 452 | prompt_embeds (`torch.FloatTensor`, *optional*): 453 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 454 | provided, text embeddings will be generated from `prompt` input argument. 455 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 456 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 457 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 458 | argument. 459 | """ 460 | if prompt is not None and negative_prompt is not None: 461 | if type(prompt) is not type(negative_prompt): 462 | raise TypeError( 463 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 464 | f" {type(prompt)}." 465 | ) 466 | 467 | if device is None: 468 | device = self._execution_device 469 | 470 | if prompt is not None and isinstance(prompt, str): 471 | batch_size = 1 472 | elif prompt is not None and isinstance(prompt, list): 473 | batch_size = len(prompt) 474 | else: 475 | batch_size = prompt_embeds.shape[0] 476 | 477 | # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF 478 | max_length = 77 479 | 480 | if prompt_embeds is None: 481 | prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) 482 | text_inputs = self.tokenizer( 483 | prompt, 484 | padding="max_length", 485 | max_length=max_length, 486 | truncation=True, 487 | add_special_tokens=True, 488 | return_tensors="pt", 489 | ) 490 | text_input_ids = text_inputs.input_ids 491 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids 492 | 493 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( 494 | text_input_ids, untruncated_ids 495 | ): 496 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) 497 | logger.warning( 498 | "The following part of your input was truncated because CLIP can only handle sequences up to" 499 | f" {max_length} tokens: {removed_text}" 500 | ) 501 | 502 | attention_mask = text_inputs.attention_mask.to(device) 503 | 504 | prompt_embeds = self.text_encoder( 505 | text_input_ids.to(device), 506 | attention_mask=attention_mask, 507 | ) 508 | prompt_embeds = prompt_embeds[0] 509 | 510 | if self.text_encoder is not None: 511 | dtype = self.text_encoder.dtype 512 | elif self.unet is not None: 513 | dtype = self.unet.dtype 514 | else: 515 | dtype = None 516 | 517 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 518 | 519 | bs_embed, seq_len, _ = prompt_embeds.shape 
520 | # duplicate text embeddings for each generation per prompt, using mps friendly method 521 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 522 | prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) 523 | 524 | # get unconditional embeddings for classifier free guidance 525 | if do_classifier_free_guidance and negative_prompt_embeds is None: 526 | uncond_tokens: List[str] 527 | if negative_prompt is None: 528 | uncond_tokens = [""] * batch_size 529 | elif isinstance(negative_prompt, str): 530 | uncond_tokens = [negative_prompt] 531 | elif batch_size != len(negative_prompt): 532 | raise ValueError( 533 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 534 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 535 | " the batch size of `prompt`." 536 | ) 537 | else: 538 | uncond_tokens = negative_prompt 539 | 540 | uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) 541 | max_length = prompt_embeds.shape[1] 542 | uncond_input = self.tokenizer( 543 | uncond_tokens, 544 | padding="max_length", 545 | max_length=max_length, 546 | truncation=True, 547 | return_attention_mask=True, 548 | add_special_tokens=True, 549 | return_tensors="pt", 550 | ) 551 | attention_mask = uncond_input.attention_mask.to(device) 552 | 553 | negative_prompt_embeds = self.text_encoder( 554 | uncond_input.input_ids.to(device), 555 | attention_mask=attention_mask, 556 | ) 557 | negative_prompt_embeds = negative_prompt_embeds[0] 558 | 559 | num_edit_tokens = 0 560 | if do_classifier_free_guidance and editing_prompt is not None and edit_prompt_embeds is None: 561 | edit_tokens: List[str] 562 | if isinstance(editing_prompt, str): 563 | edit_tokens = [editing_prompt] 564 | else: 565 | edit_tokens = editing_prompt 566 | edit_tokens = [x for item in edit_tokens for x in repeat(item, batch_size)] 567 | edit_tokens = self._text_preprocessing(edit_tokens, clean_caption=clean_caption) 568 | 569 | max_length = prompt_embeds.shape[1] 570 | edit_input = self.tokenizer( 571 | edit_tokens, 572 | padding="max_length", 573 | max_length=max_length, 574 | truncation=True, 575 | return_attention_mask=True, 576 | add_special_tokens=True, 577 | return_tensors="pt", 578 | return_length=True 579 | ) 580 | num_edit_tokens = edit_input.length -1 # not counting endoftext (there is no startoftext) 581 | #print(f"num edit tokens: {num_edit_tokens}") 582 | 583 | #edit_tokens = [[word.replace("", "") for word in self.tokenizer.tokenize(item)] for item in editing_prompt] 584 | #print(f"edit_tokens: {edit_tokens}") 585 | 586 | attention_mask = edit_input.attention_mask.to(device) 587 | 588 | edit_prompt_embeds = self.text_encoder( 589 | edit_input.input_ids.to(device), 590 | attention_mask=attention_mask, 591 | ) 592 | edit_prompt_embeds = edit_prompt_embeds[0] 593 | 594 | if do_classifier_free_guidance: 595 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 596 | seq_len = negative_prompt_embeds.shape[1] 597 | 598 | negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) 599 | 600 | negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) 601 | negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 602 | 603 | # For classifier free guidance, we need to do two forward passes. 
604 | # Here we concatenate the unconditional and text embeddings into a single batch 605 | # to avoid doing two forward passes 606 | if editing_prompt is not None: 607 | bs_embed_edit, seq_len_edit, _ = edit_prompt_embeds.shape 608 | edit_prompt_embeds = edit_prompt_embeds.repeat(1, num_images_per_prompt, 1) 609 | edit_prompt_embeds = edit_prompt_embeds.view(bs_embed_edit * num_images_per_prompt, seq_len_edit, -1) 610 | 611 | else: 612 | negative_prompt_embeds = None 613 | edit_prompt_embeds = None 614 | 615 | return prompt_embeds, negative_prompt_embeds, edit_prompt_embeds, num_edit_tokens 616 | 617 | def run_safety_checker(self, image, device, dtype): 618 | if self.safety_checker is not None: 619 | safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) 620 | image, nsfw_detected, watermark_detected = self.safety_checker( 621 | images=image, 622 | clip_input=safety_checker_input.pixel_values.to(dtype=dtype), 623 | ) 624 | else: 625 | nsfw_detected = None 626 | watermark_detected = None 627 | 628 | if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: 629 | self.unet_offload_hook.offload() 630 | 631 | return image, nsfw_detected, watermark_detected 632 | 633 | # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs 634 | def prepare_extra_step_kwargs(self, eta): 635 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 636 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 637 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 638 | # and should be between [0, 1] 639 | 640 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) 641 | extra_step_kwargs = {} 642 | if accepts_eta: 643 | extra_step_kwargs["eta"] = eta 644 | 645 | return extra_step_kwargs 646 | 647 | def check_inputs( 648 | self, 649 | prompt, 650 | callback_steps, 651 | negative_prompt=None, 652 | prompt_embeds=None, 653 | negative_prompt_embeds=None, 654 | ): 655 | if (callback_steps is None) or ( 656 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) 657 | ): 658 | raise ValueError( 659 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" 660 | f" {type(callback_steps)}." 661 | ) 662 | 663 | if prompt is not None and prompt_embeds is not None: 664 | raise ValueError( 665 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 666 | " only forward one of the two." 667 | ) 668 | elif prompt is None and prompt_embeds is None: 669 | raise ValueError( 670 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 671 | ) 672 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): 673 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 674 | 675 | if negative_prompt is not None and negative_prompt_embeds is not None: 676 | raise ValueError( 677 | f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" 678 | f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
679 | ) 680 | 681 | if prompt_embeds is not None and negative_prompt_embeds is not None: 682 | if prompt_embeds.shape != negative_prompt_embeds.shape: 683 | raise ValueError( 684 | "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" 685 | f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" 686 | f" {negative_prompt_embeds.shape}." 687 | ) 688 | 689 | # Modified 690 | def prepare_intermediate_images(self, batch_size, num_channels, height, width, dtype, device, intermediate_images): 691 | shape = (batch_size, num_channels, height, width) 692 | 693 | if intermediate_images.shape != shape: 694 | raise ValueError(f"Unexpected image shape, got {intermediate_images.shape}, expected {shape}") 695 | 696 | intermediate_images = intermediate_images.to(device) 697 | 698 | # scale the initial noise by the standard deviation required by the scheduler 699 | intermediate_images = intermediate_images * self.scheduler.init_noise_sigma 700 | return intermediate_images 701 | 702 | def prepare_unet(self, attention_store, enabled_editing_prompts): 703 | attn_procs = {} 704 | for name in self.unet.attn_processors.keys(): 705 | if name.startswith("mid_block"): 706 | place_in_unet = "mid" 707 | elif name.startswith("up_blocks"): 708 | place_in_unet = "up" 709 | elif name.startswith("down_blocks"): 710 | place_in_unet = "down" 711 | else: 712 | continue 713 | 714 | attn_procs[name] = CrossAttnProcessor( 715 | attention_store=attention_store, 716 | place_in_unet=place_in_unet, 717 | PnP=False, 718 | editing_prompts=enabled_editing_prompts) 719 | 720 | self.unet.set_attn_processor(attn_procs) 721 | 722 | def _text_preprocessing(self, text, clean_caption=False): 723 | if clean_caption and not is_bs4_available(): 724 | logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) 725 | logger.warn("Setting `clean_caption` to False...") 726 | clean_caption = False 727 | 728 | if clean_caption and not is_ftfy_available(): 729 | logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) 730 | logger.warn("Setting `clean_caption` to False...") 731 | clean_caption = False 732 | 733 | if not isinstance(text, (tuple, list)): 734 | text = [text] 735 | 736 | def process(text: str): 737 | if clean_caption: 738 | text = self._clean_caption(text) 739 | text = self._clean_caption(text) 740 | else: 741 | text = text.lower().strip() 742 | return text 743 | 744 | return [process(t) for t in text] 745 | 746 | def _clean_caption(self, caption): 747 | caption = str(caption) 748 | caption = ul.unquote_plus(caption) 749 | caption = caption.strip().lower() 750 | caption = re.sub("", "person", caption) 751 | # urls: 752 | caption = re.sub( 753 | r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa 754 | "", 755 | caption, 756 | ) # regex for urls 757 | caption = re.sub( 758 | r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa 759 | "", 760 | caption, 761 | ) # regex for urls 762 | # html: 763 | caption = BeautifulSoup(caption, features="html.parser").text 764 | 765 | # @ 766 | caption = re.sub(r"@[\w\d]+\b", "", caption) 767 | 768 | # 31C0—31EF CJK Strokes 769 | # 31F0—31FF Katakana Phonetic Extensions 770 | # 3200—32FF Enclosed CJK Letters and Months 771 | # 3300—33FF CJK Compatibility 772 | # 3400—4DBF CJK Unified Ideographs Extension A 773 | # 4DC0—4DFF Yijing Hexagram Symbols 774 | # 4E00—9FFF CJK 
Unified Ideographs 775 | caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) 776 | caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) 777 | caption = re.sub(r"[\u3200-\u32ff]+", "", caption) 778 | caption = re.sub(r"[\u3300-\u33ff]+", "", caption) 779 | caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) 780 | caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) 781 | caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) 782 | ####################################################### 783 | 784 | # все виды тире / all types of dash --> "-" 785 | caption = re.sub( 786 | r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa 787 | "-", 788 | caption, 789 | ) 790 | 791 | # кавычки к одному стандарту 792 | caption = re.sub(r"[`´«»“”¨]", '"', caption) 793 | caption = re.sub(r"[‘’]", "'", caption) 794 | 795 | # " 796 | caption = re.sub(r""?", "", caption) 797 | # & 798 | caption = re.sub(r"&", "", caption) 799 | 800 | # ip adresses: 801 | caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) 802 | 803 | # article ids: 804 | caption = re.sub(r"\d:\d\d\s+$", "", caption) 805 | 806 | # \n 807 | caption = re.sub(r"\\n", " ", caption) 808 | 809 | # "#123" 810 | caption = re.sub(r"#\d{1,3}\b", "", caption) 811 | # "#12345.." 812 | caption = re.sub(r"#\d{5,}\b", "", caption) 813 | # "123456.." 814 | caption = re.sub(r"\b\d{6,}\b", "", caption) 815 | # filenames: 816 | caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) 817 | 818 | # 819 | caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" 820 | caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" 821 | 822 | caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT 823 | caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " 824 | 825 | # this-is-my-cute-cat / this_is_my_cute_cat 826 | regex2 = re.compile(r"(?:\-|\_)") 827 | if len(re.findall(regex2, caption)) > 3: 828 | caption = re.sub(regex2, " ", caption) 829 | 830 | caption = ftfy.fix_text(caption) 831 | caption = html.unescape(html.unescape(caption)) 832 | 833 | caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 834 | caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc 835 | caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 836 | 837 | caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) 838 | caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) 839 | caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) 840 | caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) 841 | caption = re.sub(r"\bpage\s+\d+\b", "", caption) 842 | 843 | caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... 
844 | 845 | caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) 846 | 847 | caption = re.sub(r"\b\s+\:\s+", r": ", caption) 848 | caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) 849 | caption = re.sub(r"\s+", " ", caption) 850 | 851 | caption.strip() 852 | 853 | caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) 854 | caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) 855 | caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) 856 | caption = re.sub(r"^\.\S+$", "", caption) 857 | 858 | return caption.strip() 859 | 860 | def crop(self,image_path, left=0, right=0, top=0, bottom=0, size=64): 861 | if type(image_path) is str: 862 | image = np.array(PIL.Image.open(image_path).convert('RGB'))[:, :, :3] 863 | else: 864 | image = image_path 865 | h, w, c = image.shape 866 | left = min(left, w-1) 867 | right = min(right, w - left - 1) 868 | top = min(top, h - left - 1) 869 | bottom = min(bottom, h - top - 1) 870 | image = image[top:h-bottom, left:w-right] 871 | h, w, c = image.shape 872 | if h < w: 873 | offset = (w - h) // 2 874 | image = image[:, offset:offset + h] 875 | elif w < h: 876 | offset = (h - w) // 2 877 | image = image[offset:offset + w] 878 | image = PIL.Image.fromarray(image).resize((size, size)) 879 | return image 880 | 881 | # Copied from diffusers.pipelines.deepfloyed_if.IFImg2ImgPipeline.preprocess_image 882 | def preprocess_image(self, image: PIL.Image.Image) -> torch.Tensor: 883 | if not isinstance(image, list): 884 | image = [image] 885 | 886 | def numpy_to_pt(images): 887 | if images.ndim == 3: 888 | images = images[..., None] 889 | 890 | images = torch.from_numpy(images.transpose(0, 3, 1, 2)) 891 | return images 892 | 893 | if isinstance(image[0], PIL.Image.Image): 894 | new_image = [] 895 | 896 | for image_ in image: 897 | image_ = image_.convert("RGB") 898 | image_ = resize(image_, self.unet.sample_size) 899 | image_ = np.array(image_) 900 | image_ = image_.astype(np.float32) 901 | image_ = image_ / 127.5 - 1 902 | new_image.append(image_) 903 | 904 | image = new_image 905 | 906 | image = np.stack(image, axis=0) # to np 907 | image = numpy_to_pt(image) # to pt 908 | 909 | elif isinstance(image[0], np.ndarray): 910 | image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0) 911 | image = numpy_to_pt(image) 912 | 913 | elif isinstance(image[0], torch.Tensor): 914 | image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) 915 | 916 | return image 917 | 918 | @torch.no_grad() 919 | @replace_example_docstring(EXAMPLE_DOC_STRING) 920 | def __call__( 921 | self, 922 | prompt: Union[str, List[str]] = None, 923 | #num_inference_steps: int = 100, 924 | #timesteps: List[int] = None, 925 | guidance_scale: float = 7.0, 926 | negative_prompt: Optional[Union[str, List[str]]] = None, 927 | editing_prompt: Optional[Union[str, List[str]]] = None, 928 | #num_images_per_prompt: Optional[int] = 1, 929 | #height: Optional[int] = None, 930 | #width: Optional[int] = None, 931 | #eta: float = 0.0, 932 | #generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 933 | prompt_embeds: Optional[torch.FloatTensor] = None, 934 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 935 | edit_prompt_embeds: Optional[torch.FloatTensor] = None, 936 | output_type: Optional[str] = "pil", 937 | return_dict: bool = True, 938 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 939 | callback_steps: int = 1, 940 | clean_caption: bool = True, 941 | reverse_editing_direction: 
Optional[Union[bool, List[bool]]] = False, 942 | edit_guidance_scale: Optional[Union[float, List[float]]] = 5, 943 | edit_warmup_steps: Optional[Union[int, List[int]]] = 10, 944 | edit_cooldown_steps: Optional[Union[int, List[int]]] = None, 945 | edit_threshold: Optional[Union[float, List[float]]] = 0.9, 946 | edit_momentum_scale: Optional[float] = 0.1, 947 | edit_mom_beta: Optional[float] = 0.4, 948 | edit_weights: Optional[List[float]] = None, 949 | #cross_attention_kwargs: Optional[Dict[str, Any]] = None, 950 | use_cross_attn_mask: bool = False, 951 | use_intersect_mask: bool = False, 952 | # Attention store (just for visualization purposes) 953 | attn_store_steps: Optional[List[int]] = [], 954 | store_averaged_over_steps: bool = True, 955 | ): 956 | """ 957 | Function invoked when calling the pipeline for generation. 958 | 959 | Args: 960 | prompt (`str` or `List[str]`, *optional*): 961 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. 962 | instead. 963 | num_inference_steps (`int`, *optional*, defaults to 50): 964 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 965 | expense of slower inference. 966 | timesteps (`List[int]`, *optional*): 967 | Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` 968 | timesteps are used. Must be in descending order. 969 | guidance_scale (`float`, *optional*, defaults to 7.5): 970 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 971 | `guidance_scale` is defined as `w` of equation 2. of [Imagen 972 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 973 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 974 | usually at the expense of lower image quality. 975 | negative_prompt (`str` or `List[str]`, *optional*): 976 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 977 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is 978 | less than `1`). 979 | num_images_per_prompt (`int`, *optional*, defaults to 1): 980 | The number of images to generate per prompt. 981 | height (`int`, *optional*, defaults to self.unet.config.sample_size): 982 | The height in pixels of the generated image. 983 | width (`int`, *optional*, defaults to self.unet.config.sample_size): 984 | The width in pixels of the generated image. 985 | eta (`float`, *optional*, defaults to 0.0): 986 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to 987 | [`schedulers.DDIMScheduler`], will be ignored for others. 988 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 989 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) 990 | to make generation deterministic. 991 | prompt_embeds (`torch.FloatTensor`, *optional*): 992 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 993 | provided, text embeddings will be generated from `prompt` input argument. 994 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 995 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 996 | weighting. 
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 997 | argument. 998 | output_type (`str`, *optional*, defaults to `"pil"`): 999 | The output format of the generate image. Choose between 1000 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 1001 | return_dict (`bool`, *optional*, defaults to `True`): 1002 | Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. 1003 | callback (`Callable`, *optional*): 1004 | A function that will be called every `callback_steps` steps during inference. The function will be 1005 | called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. 1006 | callback_steps (`int`, *optional*, defaults to 1): 1007 | The frequency at which the `callback` function will be called. If not specified, the callback will be 1008 | called at every step. 1009 | clean_caption (`bool`, *optional*, defaults to `True`): 1010 | Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to 1011 | be installed. If the dependencies are not installed, the embeddings will be created from the raw 1012 | prompt. 1013 | cross_attention_kwargs (`dict`, *optional*): 1014 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 1015 | `self.processor` in 1016 | [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). 1017 | 1018 | Examples: 1019 | 1020 | Returns: 1021 | [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: 1022 | [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When 1023 | returning a tuple, the first element is a list with the generated images, and the second element is a list 1024 | of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) 1025 | or watermarked content, according to the `safety_checker`. 1026 | """ 1027 | eta = self.eta 1028 | num_inference_steps = self.num_inversion_steps 1029 | num_images_per_prompt = 1 1030 | cross_attention_kwargs = None 1031 | intermediate_images = self.init_images 1032 | 1033 | use_ddpm = True 1034 | zs = self.zs 1035 | 1036 | reset_dpm(self.scheduler) 1037 | 1038 | if use_intersect_mask: 1039 | use_cross_attn_mask = True 1040 | 1041 | if use_cross_attn_mask: 1042 | self.smoothing = GaussianSmoothing(self._execution_device) 1043 | 1044 | # 0. Default height and width 1045 | height = self.unet.config.sample_size 1046 | width = self.unet.config.sample_size 1047 | 1048 | # 1. Check inputs. Raise error if not correct 1049 | self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) 1050 | 1051 | # 2. Define call parameters 1052 | height = height or self.unet.config.sample_size 1053 | width = width or self.unet.config.sample_size 1054 | 1055 | if prompt is not None and isinstance(prompt, str): 1056 | batch_size = 1 1057 | elif prompt is not None and isinstance(prompt, list): 1058 | batch_size = len(prompt) 1059 | else: 1060 | batch_size = prompt_embeds.shape[0] 1061 | 1062 | device = self._execution_device 1063 | 1064 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 1065 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 1066 | # corresponds to doing no classifier free guidance. 
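        # For LEDITS++ the classifier-free guidance batch is later expanded to
        # [uncond, cond, edit_1, ..., edit_n], i.e. 2 + enabled_editing_prompts noise
        # predictions per step (see the denoising loop below).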
1067 | do_classifier_free_guidance = guidance_scale > 1.0 1068 | 1069 | if editing_prompt: 1070 | enable_edit_guidance = True 1071 | if isinstance(editing_prompt, str): 1072 | editing_prompt = [editing_prompt] 1073 | enabled_editing_prompts = len(editing_prompt) 1074 | elif edit_prompt_embeds is not None: 1075 | enable_edit_guidance = True 1076 | enabled_editing_prompts = int(edit_prompt_embeds.shape[0] / batch_size) 1077 | else: 1078 | enabled_editing_prompts = 0 1079 | enable_edit_guidance = False 1080 | 1081 | # 3. Encode input prompt 1082 | prompt_embeds, negative_prompt_embeds, edit_prompt_embeds, num_edit_tokens = self.encode_prompt( 1083 | prompt, 1084 | do_classifier_free_guidance, 1085 | num_images_per_prompt=num_images_per_prompt, 1086 | device=device, 1087 | negative_prompt=negative_prompt, 1088 | editing_prompt=editing_prompt, 1089 | prompt_embeds=prompt_embeds, 1090 | negative_prompt_embeds=negative_prompt_embeds, 1091 | edit_prompt_embeds=edit_prompt_embeds, 1092 | clean_caption=clean_caption, 1093 | ) 1094 | 1095 | self.text_cross_attention_maps = [prompt] if isinstance(prompt, str) else prompt 1096 | if do_classifier_free_guidance: 1097 | if enable_edit_guidance: 1098 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, edit_prompt_embeds]) 1099 | self.text_cross_attention_maps += \ 1100 | ([editing_prompt] if isinstance(editing_prompt, str) else editing_prompt) 1101 | else: 1102 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) 1103 | 1104 | # 4. Prepare timesteps 1105 | timesteps = self.scheduler.timesteps 1106 | if use_ddpm: 1107 | t_to_idx = {int(v): k for k, v in enumerate(timesteps)} 1108 | 1109 | if use_cross_attn_mask: 1110 | self.attention_store = AttentionStore(average=store_averaged_over_steps) 1111 | self.prepare_unet(self.attention_store, enabled_editing_prompts) 1112 | 1113 | # 5. Prepare intermediate images 1114 | intermediate_images = self.prepare_intermediate_images( 1115 | batch_size * num_images_per_prompt, 1116 | self.unet.config.in_channels, 1117 | height, 1118 | width, 1119 | prompt_embeds.dtype, 1120 | device, 1121 | intermediate_images 1122 | ) 1123 | 1124 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline 1125 | extra_step_kwargs = self.prepare_extra_step_kwargs(eta) 1126 | 1127 | # HACK: see comment in `enable_model_cpu_offload` 1128 | if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: 1129 | self.text_encoder_offload_hook.offload() 1130 | 1131 | # Initialize edit_momentum to None 1132 | edit_momentum = None 1133 | self.sem_guidance = None 1134 | 1135 | # 7. 
Denoising loop 1136 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 1137 | with self.progress_bar(total=len(timesteps)) as progress_bar: 1138 | for i, t in enumerate(timesteps): 1139 | model_input = ( 1140 | torch.cat([intermediate_images] * (2 + enabled_editing_prompts)) if do_classifier_free_guidance else intermediate_images 1141 | ) 1142 | model_input = self.scheduler.scale_model_input(model_input, t) 1143 | 1144 | # predict the noise residual 1145 | noise_pred = self.unet( 1146 | model_input, 1147 | t, 1148 | encoder_hidden_states=prompt_embeds, 1149 | cross_attention_kwargs=cross_attention_kwargs, 1150 | return_dict=False, 1151 | )[0] 1152 | 1153 | # perform guidance 1154 | if do_classifier_free_guidance: 1155 | noise_pred_out = noise_pred.chunk(2 + enabled_editing_prompts) # [b,4, 64, 64] 1156 | noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1] 1157 | 1158 | 1159 | noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1) 1160 | noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1) 1161 | 1162 | # default text guidance 1163 | noise_guidance = (noise_pred_text - noise_pred_uncond) * guidance_scale 1164 | if edit_momentum is None: 1165 | edit_momentum = torch.zeros_like(noise_guidance) 1166 | 1167 | if self.sem_guidance is None: 1168 | self.sem_guidance = torch.zeros((len(timesteps), *noise_pred_text.shape)) 1169 | 1170 | if enable_edit_guidance: 1171 | noise_pred_edit_concepts = noise_pred_out[2:] 1172 | tmp = noise_pred_edit_concepts[0] 1173 | tmp, _ = tmp.split(model_input.shape[1], dim=1) 1174 | 1175 | concept_weights = torch.zeros( 1176 | (len(tmp), noise_guidance.shape[0]), 1177 | device=edit_momentum.device, 1178 | dtype=noise_guidance.dtype, 1179 | ) 1180 | noise_guidance_edit = torch.zeros( 1181 | (len(tmp), *noise_guidance.shape), 1182 | device=edit_momentum.device, 1183 | dtype=noise_guidance.dtype, 1184 | ) 1185 | # noise_guidance_edit = torch.zeros_like(noise_guidance) 1186 | warmup_inds = [] 1187 | for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts): 1188 | noise_pred_edit_concept, _ = noise_pred_edit_concept.split(model_input.shape[1], dim=1) 1189 | if isinstance(edit_guidance_scale, list): 1190 | edit_guidance_scale_c = edit_guidance_scale[c] 1191 | else: 1192 | edit_guidance_scale_c = edit_guidance_scale 1193 | 1194 | if isinstance(edit_threshold, list): 1195 | edit_threshold_c = edit_threshold[c] 1196 | else: 1197 | edit_threshold_c = edit_threshold 1198 | if isinstance(reverse_editing_direction, list): 1199 | reverse_editing_direction_c = reverse_editing_direction[c] 1200 | else: 1201 | reverse_editing_direction_c = reverse_editing_direction 1202 | if edit_weights: 1203 | edit_weight_c = edit_weights[c] 1204 | else: 1205 | edit_weight_c = 1.0 1206 | if isinstance(edit_warmup_steps, list): 1207 | edit_warmup_steps_c = edit_warmup_steps[c] 1208 | else: 1209 | edit_warmup_steps_c = edit_warmup_steps 1210 | 1211 | if isinstance(edit_cooldown_steps, list): 1212 | edit_cooldown_steps_c = edit_cooldown_steps[c] 1213 | elif edit_cooldown_steps is None: 1214 | edit_cooldown_steps_c = i + 1 1215 | else: 1216 | edit_cooldown_steps_c = edit_cooldown_steps 1217 | if i >= edit_warmup_steps_c: 1218 | warmup_inds.append(c) 1219 | if i >= edit_cooldown_steps_c: 1220 | noise_guidance_edit[c, :, :, :, :] = torch.zeros_like(noise_pred_edit_concept) 1221 | continue 1222 | 1223 | noise_guidance_edit_tmp = noise_pred_edit_concept - noise_pred_uncond 1224 | # 
tmp_weights = (noise_pred_text - noise_pred_edit_concept).sum(dim=(1, 2, 3)) 1225 | tmp_weights = (noise_guidance - noise_pred_edit_concept).sum(dim=(1, 2, 3)) 1226 | 1227 | tmp_weights = torch.full_like(tmp_weights, edit_weight_c) # * (1 / enabled_editing_prompts) 1228 | if reverse_editing_direction_c: 1229 | noise_guidance_edit_tmp = noise_guidance_edit_tmp * -1 1230 | concept_weights[c, :] = tmp_weights 1231 | 1232 | noise_guidance_edit_tmp = noise_guidance_edit_tmp * edit_guidance_scale_c 1233 | 1234 | if use_cross_attn_mask: 1235 | out = self.attention_store.aggregate_attention( 1236 | attention_maps=self.attention_store.step_store, 1237 | prompts=self.text_cross_attention_maps, 1238 | res=8, 1239 | from_where=["up","down"], 1240 | is_cross=True, 1241 | select=self.text_cross_attention_maps.index(editing_prompt[c]), 1242 | ) 1243 | 1244 | attn_map = out[:, :, :num_edit_tokens[c]] # there is no startoftext 1245 | 1246 | # average over all tokens 1247 | assert(attn_map.shape[2]==num_edit_tokens[c]) 1248 | attn_map = torch.sum(attn_map, dim=2) 1249 | 1250 | # gaussian_smoothing TODO 1251 | attn_map = F.pad(attn_map.unsqueeze(0).unsqueeze(0), (1, 1, 1, 1), mode="reflect") 1252 | attn_map = self.smoothing(attn_map).squeeze(0).squeeze(0) 1253 | 1254 | # create binary mask 1255 | tmp = torch.quantile(attn_map.flatten(),edit_threshold_c) 1256 | attn_mask = torch.where(attn_map >= tmp, 1.0, 0.0) 1257 | 1258 | # resolution must match latent space dimension 1259 | attn_mask = F.interpolate( 1260 | attn_mask.unsqueeze(0).unsqueeze(0), 1261 | noise_guidance_edit_tmp.shape[-2:] # 64,64 1262 | )[0,0,:,:] 1263 | 1264 | if not use_intersect_mask: 1265 | noise_guidance_edit_tmp = noise_guidance_edit_tmp * attn_mask 1266 | 1267 | if use_intersect_mask: 1268 | noise_guidance_edit_tmp_quantile = torch.abs(noise_guidance_edit_tmp) 1269 | noise_guidance_edit_tmp_quantile = torch.sum(noise_guidance_edit_tmp_quantile, dim=1, keepdim=True) 1270 | noise_guidance_edit_tmp_quantile = noise_guidance_edit_tmp_quantile.repeat(1,noise_guidance_edit_tmp.shape[1],1,1) 1271 | 1272 | if noise_guidance_edit_tmp_quantile.dtype == torch.float32: 1273 | tmp = torch.quantile( 1274 | noise_guidance_edit_tmp_quantile.flatten(start_dim=2), 1275 | edit_threshold_c, 1276 | dim=2, 1277 | keepdim=False, 1278 | ) 1279 | else: 1280 | tmp = torch.quantile( 1281 | noise_guidance_edit_tmp_quantile.flatten(start_dim=2).to(torch.float32), 1282 | edit_threshold_c, 1283 | dim=2, 1284 | keepdim=False, 1285 | ).to(noise_guidance_edit_tmp_quantile.dtype) 1286 | 1287 | sega_mask = torch.where( 1288 | noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None], 1289 | torch.ones_like(noise_guidance_edit_tmp), 1290 | torch.zeros_like(noise_guidance_edit_tmp), 1291 | ) 1292 | 1293 | intersect_mask = sega_mask * attn_mask 1294 | noise_guidance_edit_tmp = noise_guidance_edit_tmp * intersect_mask 1295 | 1296 | elif not use_cross_attn_mask: 1297 | # calculate quantile 1298 | noise_guidance_edit_tmp_quantile = torch.abs(noise_guidance_edit_tmp) 1299 | noise_guidance_edit_tmp_quantile = torch.sum(noise_guidance_edit_tmp_quantile, dim=1, keepdim=True) 1300 | noise_guidance_edit_tmp_quantile = noise_guidance_edit_tmp_quantile.repeat(1,noise_guidance_edit_tmp.shape[1],1,1) 1301 | 1302 | # torch.quantile function expects float32 1303 | if noise_guidance_edit_tmp_quantile.dtype == torch.float32: 1304 | tmp = torch.quantile( 1305 | noise_guidance_edit_tmp_quantile.flatten(start_dim=2), 1306 | edit_threshold_c, 1307 | dim=2, 1308 | keepdim=False, 1309 | ) 1310 | 
else: 1311 | tmp = torch.quantile( 1312 | noise_guidance_edit_tmp_quantile.flatten(start_dim=2).to(torch.float32), 1313 | edit_threshold_c, 1314 | dim=2, 1315 | keepdim=False, 1316 | ).to(noise_guidance_edit_tmp_quantile.dtype) 1317 | 1318 | noise_guidance_edit_tmp = torch.where( 1319 | noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None], 1320 | noise_guidance_edit_tmp, 1321 | torch.zeros_like(noise_guidance_edit_tmp), 1322 | ) 1323 | 1324 | noise_guidance_edit[c, :, :, :, :] = noise_guidance_edit_tmp 1325 | 1326 | warmup_inds = torch.tensor(warmup_inds).to(self.device) 1327 | if len(noise_pred_edit_concepts) > warmup_inds.shape[0] > 0: 1328 | concept_weights = concept_weights.to("cpu") # Offload to cpu 1329 | noise_guidance_edit = noise_guidance_edit.to("cpu") 1330 | 1331 | concept_weights_tmp = torch.index_select(concept_weights.to(self.device), 0, warmup_inds) 1332 | concept_weights_tmp = torch.where( 1333 | concept_weights_tmp < 0, torch.zeros_like(concept_weights_tmp), concept_weights_tmp 1334 | ) 1335 | concept_weights_tmp = concept_weights_tmp / concept_weights_tmp.sum(dim=0) 1336 | # concept_weights_tmp = torch.nan_to_num(concept_weights_tmp) 1337 | 1338 | noise_guidance_edit_tmp = torch.index_select( 1339 | noise_guidance_edit.to(self.device), 0, warmup_inds 1340 | ) 1341 | noise_guidance_edit_tmp = torch.einsum( 1342 | "cb,cbijk->bijk", concept_weights_tmp, noise_guidance_edit_tmp 1343 | ) 1344 | noise_guidance_edit_tmp = noise_guidance_edit_tmp 1345 | noise_guidance = noise_guidance + noise_guidance_edit_tmp 1346 | 1347 | self.sem_guidance[i] = noise_guidance_edit_tmp.detach().cpu() 1348 | 1349 | del noise_guidance_edit_tmp 1350 | del concept_weights_tmp 1351 | concept_weights = concept_weights.to(self.device) 1352 | noise_guidance_edit = noise_guidance_edit.to(self.device) 1353 | 1354 | concept_weights = torch.where( 1355 | concept_weights < 0, torch.zeros_like(concept_weights), concept_weights 1356 | ) 1357 | 1358 | concept_weights = torch.nan_to_num(concept_weights) 1359 | 1360 | noise_guidance_edit = torch.einsum("cb,cbijk->bijk", concept_weights, noise_guidance_edit) 1361 | 1362 | noise_guidance_edit = noise_guidance_edit + edit_momentum_scale * edit_momentum 1363 | 1364 | edit_momentum = edit_mom_beta * edit_momentum + (1 - edit_mom_beta) * noise_guidance_edit 1365 | 1366 | if warmup_inds.shape[0] == len(noise_pred_edit_concepts): 1367 | #print(noise_guidance.device, noise_guidance_edit.device) 1368 | noise_guidance = noise_guidance + noise_guidance_edit 1369 | self.sem_guidance[i] = noise_guidance_edit.detach().cpu() 1370 | 1371 | noise_pred = noise_pred_uncond + noise_guidance 1372 | noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) 1373 | 1374 | if self.scheduler.config.variance_type not in ["learned", "learned_range"]: 1375 | noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1) 1376 | 1377 | # compute the previous noisy sample x_t -> x_t-1 1378 | if use_ddpm: 1379 | idx = t_to_idx[int(t)] 1380 | intermediate_images = self.scheduler.step( 1381 | noise_pred, t, intermediate_images, variance_noise=zs[idx], **extra_step_kwargs, return_dict=False 1382 | )[0] 1383 | else: 1384 | intermediate_images = self.scheduler.step( 1385 | noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False 1386 | )[0] 1387 | 1388 | if use_cross_attn_mask: 1389 | # step callback 1390 | store_step = i in attn_store_steps 1391 | if store_step: 1392 | print(f"storing attention for step {i}") 1393 | self.attention_store.between_steps(store_step) 1394 | 
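# NOTE: `between_steps` is invoked every iteration so that `step_store` is cleared for the next step;
# maps are only accumulated into `attention_store` for the steps listed in `attn_store_steps`, which is
# purely for visualization. The cross-attention masks above read the current `step_store` directly.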
1395 | # call the callback, if provided 1396 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 1397 | progress_bar.update() 1398 | if callback is not None and i % callback_steps == 0: 1399 | callback(i, t, intermediate_images) 1400 | 1401 | image = intermediate_images 1402 | 1403 | if output_type == "pil": 1404 | # 8. Post-processing 1405 | image = (image / 2 + 0.5).clamp(0, 1) 1406 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() 1407 | 1408 | # 9. Run safety checker 1409 | image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) 1410 | 1411 | # 10. Convert to PIL 1412 | image = self.numpy_to_pil(image) 1413 | 1414 | # 11. Apply watermark 1415 | if self.watermarker is not None: 1416 | image = self.watermarker.apply_watermark(image, self.unet.config.sample_size) 1417 | elif output_type == "pt": 1418 | nsfw_detected = None 1419 | watermark_detected = None 1420 | 1421 | if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: 1422 | self.unet_offload_hook.offload() 1423 | else: 1424 | # 8. Post-processing 1425 | image = (image / 2 + 0.5).clamp(0, 1) 1426 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() 1427 | 1428 | # 9. Run safety checker 1429 | image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) 1430 | 1431 | # Offload last model to CPU 1432 | if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: 1433 | self.final_offload_hook.offload() 1434 | 1435 | if not return_dict: 1436 | return (image, nsfw_detected, watermark_detected) 1437 | 1438 | return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected) 1439 | 1440 | 1441 | @torch.no_grad() 1442 | def invert(self, 1443 | image_path: str, 1444 | source_prompt: str = "", 1445 | source_guidance_scale: float = 3.5, 1446 | num_inversion_steps: int = 100, 1447 | skip: float = .15, 1448 | eta: float = 1.0, 1449 | generator: Optional[torch.Generator] = None 1450 | ): 1451 | """ 1452 | Inverts a real image according to Algorithm 1 in https://arxiv.org/pdf/2304.06140.pdf, 1453 | based on the code in https://github.com/inbarhub/DDPM_inversion 1454 | 1455 | Returns: 1456 | zs - noise maps 1457 | xts - intermediate inverted latents 1458 | """ 1459 | self.eta = eta 1460 | assert(self.eta > 0) 1461 | 1462 | device = self._execution_device 1463 | dtype = self.text_encoder.dtype 1464 | 1465 | train_steps = self.scheduler.config.num_train_timesteps 1466 | timesteps = torch.from_numpy( 1467 | np.linspace(train_steps - skip * train_steps - 1, 0, num_inversion_steps).astype(np.int64)).to(self.device) 1468 | #timesteps += self.scheduler.config.steps_offset 1469 | 1470 | self.num_inversion_steps = timesteps.shape[0] 1471 | self.scheduler.num_inference_steps = timesteps.shape[0] 1472 | self.scheduler.timesteps = timesteps 1473 | #print(timesteps) 1474 | 1475 | reset_dpm(self.scheduler) 1476 | 1477 | # Reset attn processor, we do not want to store attn maps during inversion 1478 | self.unet.set_attn_processor(AttnAddedKVProcessor()) 1479 | 1480 | # 1. get embeddings 1481 | text_embeddings, uncond_embedding, _, _ = self.encode_prompt(source_prompt) 1482 | prompt_embeds = torch.cat([uncond_embedding, text_embeddings]) 1483 | 1484 | # 2. open image 1485 | image = self.crop(image_path) 1486 | x0 = self.preprocess_image(image) 1487 | x0 = x0.to(device=device, dtype=dtype) 1488 | self.batch_size = x0.shape[0] 1489 | 1490 | # 3.
find zs and xts 1491 | variance_noise_shape = ( 1492 | self.num_inversion_steps, 1493 | self.batch_size, 1494 | self.unet.config.in_channels, 1495 | self.unet.sample_size, 1496 | self.unet.sample_size) 1497 | 1498 | # intermediate latents 1499 | t_to_idx = {int(v):k for k,v in enumerate(timesteps)} 1500 | xts = torch.zeros(size=variance_noise_shape, device=device, dtype=uncond_embedding.dtype) 1501 | 1502 | for t in reversed(timesteps): 1503 | idx = self.num_inversion_steps-t_to_idx[int(t)] - 1 1504 | noise = randn_tensor(shape=x0.shape, generator=generator, device=device, dtype=x0.dtype) 1505 | xts[idx] = self.scheduler.add_noise(x0, noise, t) 1506 | xts = torch.cat([x0.unsqueeze(0), xts], dim=0) 1507 | 1508 | # noise maps 1509 | zs = torch.zeros(size=variance_noise_shape, device=device, dtype=uncond_embedding.dtype) 1510 | 1511 | if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: 1512 | self.text_encoder_offload_hook.offload() 1513 | 1514 | for t in tqdm(timesteps): 1515 | idx = self.num_inversion_steps-t_to_idx[int(t)]-1 1516 | 1517 | # 1. predict noise residual 1518 | xt = xts[idx+1] 1519 | model_input = torch.cat([xt] * 2) 1520 | noise_pred = self.unet( 1521 | model_input, 1522 | timestep=t, 1523 | encoder_hidden_states=prompt_embeds, 1524 | cross_attention_kwargs=None, 1525 | return_dict=False 1526 | )[0] 1527 | 1528 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 1529 | noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1) 1530 | noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1) 1531 | 1532 | noise_pred = noise_pred_uncond + (noise_pred_text - noise_pred_uncond) * source_guidance_scale 1533 | 1534 | xtm1 = xts[idx] 1535 | z, xtm1_corrected = compute_noise(self.scheduler, xtm1, xt, t, noise_pred, eta) 1536 | zs[idx] = z 1537 | 1538 | # correction to avoid error accumulation 1539 | xts[idx] = xtm1_corrected 1540 | 1541 | # TODO: I don't think that the noise map for the last step should be discarded ?! 1542 | # if not zs is None: 1543 | # zs[-1] = torch.zeros_like(zs[-1]) 1544 | 1545 | self.init_images = xts[-1].expand(1, -1, -1, -1) 1546 | zs = zs.flip(0) 1547 | self.zs = zs 1548 | 1549 | if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: 1550 | self.unet_offload_hook.offload() 1551 | 1552 | if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: 1553 | self.final_offload_hook.offload() 1554 | 1555 | return zs, xts 1556 | 1557 | 1558 | # Copied from pipelines.StableDiffusion.CycleDiffusionPipeline.compute_noise 1559 | def compute_noise_ddim(scheduler, prev_latents, latents, timestep, noise_pred, eta): 1560 | # 1. get previous step value (=t-1) 1561 | prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps 1562 | 1563 | # 2. compute alphas, betas 1564 | alpha_prod_t = scheduler.alphas_cumprod[timestep] 1565 | alpha_prod_t_prev = ( 1566 | scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod 1567 | ) 1568 | 1569 | beta_prod_t = 1 - alpha_prod_t 1570 | 1571 | # 3. compute predicted original sample from predicted noise also called 1572 | # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 1573 | pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5) 1574 | 1575 | # 4. 
Clip "predicted x_0" 1576 | if scheduler.config.clip_sample: 1577 | pred_original_sample = torch.clamp(pred_original_sample, -1, 1) 1578 | 1579 | # 5. compute variance: "sigma_t(η)" -> see formula (16) 1580 | # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) 1581 | variance = scheduler._get_variance(timestep, prev_timestep) 1582 | std_dev_t = eta * variance ** (0.5) 1583 | 1584 | # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 1585 | pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * noise_pred 1586 | 1587 | # modifed so that updated xtm1 is returned as well (to avoid error accumulation) 1588 | mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction 1589 | 1590 | if variance > 0.0: 1591 | noise = (prev_latents - mu_xt) / (variance ** (0.5) * eta) 1592 | else: 1593 | noise = torch.Tensor([0.0]).to(latents.device) 1594 | 1595 | return noise, mu_xt + ( eta * variance ** 0.5 )*noise 1596 | 1597 | def compute_noise_sde_dpm_pp_2nd(scheduler, prev_latents, latents, timestep, noise_pred, eta): 1598 | def first_order_update(model_output, timestep, prev_timestep, sample): 1599 | lambda_t, lambda_s = scheduler.lambda_t[prev_timestep], scheduler.lambda_t[timestep] 1600 | alpha_t, alpha_s = scheduler.alpha_t[prev_timestep], scheduler.alpha_t[timestep] 1601 | sigma_t, sigma_s = scheduler.sigma_t[prev_timestep], scheduler.sigma_t[timestep] 1602 | h = lambda_t - lambda_s 1603 | 1604 | mu_xt = ( 1605 | (sigma_t / sigma_s * torch.exp(-h)) * sample 1606 | + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output 1607 | ) 1608 | sigma = sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) 1609 | 1610 | if sigma > 0.0: 1611 | noise = (prev_latents - mu_xt) / sigma 1612 | else: 1613 | noise = torch.Tensor([0.0]).to(sample.device) 1614 | 1615 | prev_sample = mu_xt + sigma * noise 1616 | 1617 | return noise, prev_sample 1618 | 1619 | def second_order_update(model_output_list, timestep_list, prev_timestep, sample): 1620 | t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] 1621 | m0, m1 = model_output_list[-1], model_output_list[-2] 1622 | lambda_t, lambda_s0, lambda_s1 = scheduler.lambda_t[t], scheduler.lambda_t[s0], scheduler.lambda_t[s1] 1623 | alpha_t, alpha_s0 = scheduler.alpha_t[t], scheduler.alpha_t[s0] 1624 | sigma_t, sigma_s0 = scheduler.sigma_t[t], scheduler.sigma_t[s0] 1625 | h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 1626 | r0 = h_0 / h 1627 | D0, D1 = m0, (1.0 / r0) * (m0 - m1) 1628 | 1629 | mu_xt = ( 1630 | (sigma_t / sigma_s0 * torch.exp(-h)) * sample 1631 | + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 1632 | + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 1633 | ) 1634 | sigma = sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) 1635 | 1636 | if sigma > 0.0: 1637 | noise = (prev_latents - mu_xt) / sigma 1638 | else: 1639 | noise = torch.Tensor([0.0]).to(sample.device) 1640 | 1641 | prev_sample = mu_xt + sigma * noise 1642 | 1643 | return noise, prev_sample 1644 | 1645 | step_index = (scheduler.timesteps == timestep).nonzero() 1646 | if len(step_index) == 0: 1647 | step_index = len(scheduler.timesteps) - 1 1648 | else: 1649 | step_index = step_index.item() 1650 | 1651 | prev_timestep = 0 if step_index == len(scheduler.timesteps) - 1 else scheduler.timesteps[step_index + 1] 1652 | 1653 | model_output = scheduler.convert_model_output(noise_pred, timestep, latents) 1654 | 1655 | for i in range(scheduler.config.solver_order - 1): 1656 | scheduler.model_outputs[i] = 
scheduler.model_outputs[i + 1] 1657 | scheduler.model_outputs[-1] = model_output 1658 | 1659 | if scheduler.lower_order_nums < 1: 1660 | noise, prev_sample = first_order_update(model_output, timestep, prev_timestep, latents) 1661 | else: 1662 | timestep_list = [scheduler.timesteps[step_index - 1], timestep] 1663 | noise, prev_sample = second_order_update(scheduler.model_outputs, timestep_list, prev_timestep, latents) 1664 | 1665 | if scheduler.lower_order_nums < scheduler.config.solver_order: 1666 | scheduler.lower_order_nums += 1 1667 | 1668 | return noise, prev_sample 1669 | 1670 | 1671 | def compute_noise(scheduler, *args): 1672 | if isinstance(scheduler, DDIMScheduler): 1673 | return compute_noise_ddim(scheduler, *args) 1674 | elif isinstance(scheduler, 1675 | DPMSolverMultistepSchedulerInject) and scheduler.config.algorithm_type == 'sde-dpmsolver++' \ 1676 | and scheduler.config.solver_order == 2: 1677 | return compute_noise_sde_dpm_pp_2nd(scheduler, *args) 1678 | else: 1679 | raise NotImplementedError 1680 | -------------------------------------------------------------------------------- /src/leditspp/pipeline_stable_diffusion_ledits.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import warnings 3 | from itertools import repeat 4 | from typing import Callable, List, Optional, Union 5 | 6 | import torch 7 | from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer 8 | 9 | from diffusers.image_processor import VaeImageProcessor 10 | from diffusers.models import AutoencoderKL, UNet2DConditionModel 11 | from diffusers.models.attention_processor import AttnProcessor, Attention 12 | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker 13 | from diffusers.schedulers import DDIMScheduler 14 | from .scheduling_dpmsolver_multistep_inject import DPMSolverMultistepSchedulerInject 15 | from diffusers.utils import logging, randn_tensor 16 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 17 | from diffusers.pipelines.semantic_stable_diffusion import SemanticStableDiffusionPipelineOutput 18 | 19 | import numpy as np 20 | from PIL import Image 21 | from tqdm import tqdm 22 | import torch.nn.functional as F 23 | import math 24 | from collections.abc import Iterable 25 | 26 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 27 | 28 | 29 | class AttentionStore(): 30 | @staticmethod 31 | def get_empty_store(): 32 | return {"down_cross": [], "mid_cross": [], "up_cross": [], 33 | "down_self": [], "mid_self": [], "up_self": []} 34 | 35 | def __call__(self, attn, is_cross: bool, place_in_unet: str, editing_prompts, PnP=False): 36 | # attn.shape = batch_size * head_size, seq_len query, seq_len_key 37 | if attn.shape[1] <= self.max_size: 38 | bs = 1 + int(PnP) + editing_prompts 39 | skip = 2 if PnP else 1 # skip PnP & unconditional 40 | attn = torch.stack(attn.split(self.batch_size)).permute(1, 0, 2, 3) 41 | source_batch_size = int(attn.shape[1] // bs) 42 | self.forward( 43 | attn[:, skip * source_batch_size:], 44 | is_cross, 45 | place_in_unet) 46 | 47 | def forward(self, attn, is_cross: bool, place_in_unet: str): 48 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" 49 | 50 | self.step_store[key].append(attn) 51 | 52 | def between_steps(self, store_step=True): 53 | if store_step: 54 | if self.average: 55 | if len(self.attention_store) == 0: 56 | self.attention_store = self.step_store 57 | else: 58 | for key in self.attention_store: 59 | for i in 
range(len(self.attention_store[key])): 60 | self.attention_store[key][i] += self.step_store[key][i] 61 | else: 62 | if len(self.attention_store) == 0: 63 | self.attention_store = [self.step_store] 64 | else: 65 | self.attention_store.append(self.step_store) 66 | 67 | self.cur_step += 1 68 | self.step_store = self.get_empty_store() 69 | 70 | def get_attention(self, step: int): 71 | if self.average: 72 | attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in 73 | self.attention_store} 74 | else: 75 | assert (step is not None) 76 | attention = self.attention_store[step] 77 | return attention 78 | 79 | def aggregate_attention(self, attention_maps, prompts, res: int, 80 | from_where: List[str], is_cross: bool, select: int 81 | ): 82 | out = [[] for x in range(self.batch_size)] 83 | num_pixels = res ** 2 84 | for location in from_where: 85 | for bs_item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]: 86 | for batch, item in enumerate(bs_item): 87 | if item.shape[1] == num_pixels: 88 | cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select] 89 | out[batch].append(cross_maps) 90 | 91 | out = torch.stack([torch.cat(x, dim=0) for x in out]) 92 | # average over heads 93 | out = out.sum(1) / out.shape[1] 94 | return out 95 | 96 | def __init__(self, average: bool, batch_size=1, max_resolution=16): 97 | self.step_store = self.get_empty_store() 98 | self.attention_store = [] 99 | self.cur_step = 0 100 | self.average = average 101 | self.batch_size = batch_size 102 | self.max_size = max_resolution ** 2 103 | 104 | 105 | class CrossAttnProcessor: 106 | 107 | def __init__(self, attention_store, place_in_unet, PnP, editing_prompts): 108 | self.attnstore = attention_store 109 | self.place_in_unet = place_in_unet 110 | self.editing_prompts = editing_prompts 111 | self.PnP = PnP 112 | 113 | def __call__( 114 | self, 115 | attn: Attention, 116 | hidden_states, 117 | encoder_hidden_states=None, 118 | attention_mask=None, 119 | temb=None, 120 | ): 121 | assert (not attn.residual_connection) 122 | assert (attn.spatial_norm is None) 123 | assert (attn.group_norm is None) 124 | assert (hidden_states.ndim != 4) 125 | assert (encoder_hidden_states is not None) # is cross 126 | 127 | batch_size, sequence_length, _ = ( 128 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 129 | ) 130 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 131 | 132 | query = attn.to_q(hidden_states) 133 | 134 | if encoder_hidden_states is None: 135 | encoder_hidden_states = hidden_states 136 | elif attn.norm_cross: 137 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 138 | 139 | key = attn.to_k(encoder_hidden_states) 140 | value = attn.to_v(encoder_hidden_states) 141 | 142 | query = attn.head_to_batch_dim(query) 143 | key = attn.head_to_batch_dim(key) 144 | value = attn.head_to_batch_dim(value) 145 | 146 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 147 | self.attnstore(attention_probs, 148 | is_cross=True, 149 | place_in_unet=self.place_in_unet, 150 | editing_prompts=self.editing_prompts, 151 | PnP=self.PnP) 152 | 153 | hidden_states = torch.bmm(attention_probs, value) 154 | hidden_states = attn.batch_to_head_dim(hidden_states) 155 | 156 | # linear proj 157 | hidden_states = attn.to_out[0](hidden_states) 158 | # dropout 159 | hidden_states = attn.to_out[1](hidden_states) 160 | 161 | hidden_states = hidden_states / 
attn.rescale_output_factor 162 | return hidden_states 163 | 164 | 165 | # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionAttendAndExcitePipeline.GaussianSmoothing 166 | class GaussianSmoothing(): 167 | 168 | def __init__(self, device): 169 | kernel_size = [3, 3] 170 | sigma = [0.5, 0.5] 171 | 172 | # The gaussian kernel is the product of the gaussian function of each dimension. 173 | kernel = 1 174 | meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size]) 175 | for size, std, mgrid in zip(kernel_size, sigma, meshgrids): 176 | mean = (size - 1) / 2 177 | kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2)) 178 | 179 | # Make sure sum of values in gaussian kernel equals 1. 180 | kernel = kernel / torch.sum(kernel) 181 | 182 | # Reshape to depthwise convolutional weight 183 | kernel = kernel.view(1, 1, *kernel.size()) 184 | kernel = kernel.repeat(1, *[1] * (kernel.dim() - 1)) 185 | 186 | self.weight = kernel.to(device) 187 | 188 | def __call__(self, input): 189 | """ 190 | Arguments: 191 | Apply gaussian filter to input. 192 | input (torch.Tensor): Input to apply gaussian filter on. 193 | Returns: 194 | filtered (torch.Tensor): Filtered output. 195 | """ 196 | return F.conv2d(input, weight=self.weight.to(input.dtype)) 197 | 198 | 199 | def load_512(image_path, sizes=(512,768), left=0, right=0, top=0, bottom=0, device=None, dtype=None): 200 | def pre_process(im, sizes, left=0, right=0, top=0, bottom=0): 201 | if type(im) is str: 202 | image = np.array(Image.open(im).convert('RGB'))[:, :, :3] 203 | elif isinstance(im, Image.Image): 204 | image = np.array((im).convert('RGB'))[:, :, :3] 205 | else: 206 | image = im 207 | 208 | h, w, c = image.shape 209 | left = min(left, w - 1) 210 | right = min(right, w - left - 1) 211 | top = min(top, h - left - 1) 212 | bottom = min(bottom, h - top - 1) 213 | image = image[top:h - bottom, left:w - right] 214 | 215 | ar = max(*image.shape[:2]) / min(*image.shape[:2]) 216 | 217 | if ar > 1.25: 218 | h_max = image.shape[0] > image.shape[1] 219 | if h_max: 220 | resized = Image.fromarray(image).resize((sizes[0], sizes[1])) 221 | else: 222 | resized = Image.fromarray(image).resize((sizes[1], sizes[0])) 223 | image = np.array(resized) 224 | 225 | else: 226 | image = np.array(Image.fromarray(image).resize((sizes[0], sizes[0]))) 227 | image = torch.from_numpy(image).float().permute(2, 0, 1) 228 | return image 229 | 230 | tmps = [] 231 | if isinstance(image_path, list): 232 | for item in image_path: 233 | tmps.append(pre_process(item, sizes, left, right, top, bottom)) 234 | else: 235 | tmps.append(pre_process(image_path, sizes, left, right, top, bottom)) 236 | image = torch.stack(tmps) / 127.5 - 1 237 | 238 | image = image.to(device=device, dtype=dtype) 239 | return image 240 | 241 | 242 | # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionAttendAndExcitePipeline.GaussianSmoothing 243 | 244 | def reset_dpm(scheduler): 245 | if isinstance(scheduler, DPMSolverMultistepSchedulerInject): 246 | scheduler.model_outputs = [ 247 | None, 248 | ] * scheduler.config.solver_order 249 | scheduler.lower_order_nums = 0 250 | 251 | 252 | class StableDiffusionPipeline_LEDITS(DiffusionPipeline): 253 | r""" 254 | Pipeline for text-to-image generation with latent editing. 255 | 256 | This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the 257 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 258 | 259 | This model builds on the implementation of [`StableDiffusionPipeline`] 260 | 261 | Args: 262 | vae ([`AutoencoderKL`]): 263 | Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. 264 | text_encoder ([`CLIPTextModel`]): 265 | Frozen text-encoder. Stable Diffusion uses the text portion of 266 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically 267 | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 268 | tokenizer (`CLIPTokenizer`): 269 | Tokenizer of class 270 | [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). 271 | unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. 272 | scheduler ([`SchedulerMixin`]): 273 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of 274 | [`DDIMScheduler`] or [`DPMSolverMultistepSchedulerInject`]. 275 | safety_checker ([`StableDiffusionSafetyChecker`]): 276 | Classification module that estimates whether generated images could be considered offensive or harmful. 277 | Please refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. 278 | feature_extractor ([`CLIPImageProcessor`]): 279 | Model that extracts features from generated images to be used as inputs for the `safety_checker`. 280 | """ 281 | 282 | _optional_components = ["safety_checker", "feature_extractor"] 283 | 284 | def __init__( 285 | self, 286 | vae: AutoencoderKL, 287 | text_encoder: CLIPTextModel, 288 | tokenizer: CLIPTokenizer, 289 | unet: UNet2DConditionModel, 290 | scheduler: Union[DDIMScheduler,DPMSolverMultistepSchedulerInject], 291 | safety_checker: StableDiffusionSafetyChecker, 292 | feature_extractor: CLIPImageProcessor, 293 | requires_safety_checker: bool = True, 294 | ): 295 | super().__init__() 296 | 297 | if not isinstance(scheduler, (DDIMScheduler, DPMSolverMultistepSchedulerInject)): 298 | scheduler = DPMSolverMultistepSchedulerInject.from_config(scheduler.config, algorithm_type="sde-dpmsolver++", solver_order=2) 299 | logger.warning("This pipeline only supports DDIMScheduler and DPMSolverMultistepSchedulerInject. " 300 | "The scheduler has been changed to DPMSolverMultistepSchedulerInject.") 301 | 302 | if safety_checker is None and requires_safety_checker: 303 | logger.warning( 304 | f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" 305 | " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered" 306 | " results in services or applications open to the public. Both the diffusers team and Hugging Face" 307 | " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" 308 | " it only for use-cases that involve analyzing network behavior or auditing its results. For more" 309 | " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." 310 | ) 311 | 312 | if safety_checker is not None and feature_extractor is None: 313 | raise ValueError( 314 | f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" 315 | " checker.
If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 316 | ) 317 | 318 | self.register_modules( 319 | vae=vae, 320 | text_encoder=text_encoder, 321 | tokenizer=tokenizer, 322 | unet=unet, 323 | scheduler=scheduler, 324 | safety_checker=safety_checker, 325 | feature_extractor=feature_extractor, 326 | ) 327 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 328 | self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) 329 | self.register_to_config(requires_safety_checker=requires_safety_checker) 330 | 331 | def progress_bar(self, iterable=None, total=None, verbose=True): 332 | if not hasattr(self, "_progress_bar_config"): 333 | self._progress_bar_config = {} 334 | elif not isinstance(self._progress_bar_config, dict): 335 | raise ValueError( 336 | f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." 337 | ) 338 | if not verbose: 339 | return iterable 340 | elif iterable is not None: 341 | return tqdm(iterable, **self._progress_bar_config) 342 | elif total is not None: 343 | return tqdm(total=total, **self._progress_bar_config) 344 | else: 345 | raise ValueError("Either `total` or `iterable` has to be defined.") 346 | 347 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker 348 | def run_safety_checker(self, image, device, dtype): 349 | if self.safety_checker is None: 350 | has_nsfw_concept = None 351 | else: 352 | if torch.is_tensor(image): 353 | feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") 354 | else: 355 | feature_extractor_input = self.image_processor.numpy_to_pil(image) 356 | safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) 357 | image, has_nsfw_concept = self.safety_checker( 358 | images=image, clip_input=safety_checker_input.pixel_values.to(dtype) 359 | ) 360 | return image, has_nsfw_concept 361 | 362 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents 363 | def decode_latents(self, latents): 364 | warnings.warn( 365 | "The decode_latents method is deprecated and will be removed in a future version. Please" 366 | " use VaeImageProcessor instead", 367 | FutureWarning, 368 | ) 369 | latents = 1 / self.vae.config.scaling_factor * latents 370 | image = self.vae.decode(latents, return_dict=False)[0] 371 | image = (image / 2 + 0.5).clamp(0, 1) 372 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 373 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() 374 | return image 375 | 376 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs 377 | def prepare_extra_step_kwargs(self, eta): 378 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 379 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
380 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 381 | # and should be between [0, 1] 382 | 383 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) 384 | extra_step_kwargs = {} 385 | if accepts_eta: 386 | extra_step_kwargs["eta"] = eta 387 | 388 | return extra_step_kwargs 389 | 390 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs 391 | def check_inputs( 392 | self, 393 | prompt, 394 | height, 395 | width, 396 | callback_steps, 397 | negative_prompt=None, 398 | prompt_embeds=None, 399 | negative_prompt_embeds=None, 400 | ): 401 | if height % 8 != 0 or width % 8 != 0: 402 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 403 | 404 | if (callback_steps is None) or ( 405 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) 406 | ): 407 | raise ValueError( 408 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" 409 | f" {type(callback_steps)}." 410 | ) 411 | 412 | if prompt is not None and prompt_embeds is not None: 413 | raise ValueError( 414 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 415 | " only forward one of the two." 416 | ) 417 | elif prompt is None and prompt_embeds is None: 418 | raise ValueError( 419 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 420 | ) 421 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): 422 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 423 | 424 | if negative_prompt is not None and negative_prompt_embeds is not None: 425 | raise ValueError( 426 | f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" 427 | f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 428 | ) 429 | 430 | if prompt_embeds is not None and negative_prompt_embeds is not None: 431 | if prompt_embeds.shape != negative_prompt_embeds.shape: 432 | raise ValueError( 433 | "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" 434 | f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" 435 | f" {negative_prompt_embeds.shape}." 
436 | ) 437 | 438 | # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents 439 | def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, latents): 440 | #shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) 441 | 442 | #if latents.shape != shape: 443 | # raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") 444 | 445 | latents = latents.to(device) 446 | 447 | # scale the initial noise by the standard deviation required by the scheduler 448 | latents = latents * self.scheduler.init_noise_sigma 449 | return latents 450 | 451 | def prepare_unet(self, attention_store, PnP: bool = False): 452 | attn_procs = {} 453 | for name in self.unet.attn_processors.keys(): 454 | if name.startswith("mid_block"): 455 | place_in_unet = "mid" 456 | elif name.startswith("up_blocks"): 457 | place_in_unet = "up" 458 | elif name.startswith("down_blocks"): 459 | place_in_unet = "down" 460 | else: 461 | continue 462 | 463 | if "attn2" in name and place_in_unet != 'mid': 464 | attn_procs[name] = CrossAttnProcessor( 465 | attention_store=attention_store, 466 | place_in_unet=place_in_unet, 467 | PnP=PnP, 468 | editing_prompts=self.enabled_editing_prompts) 469 | else: 470 | attn_procs[name] = AttnProcessor() 471 | 472 | self.unet.set_attn_processor(attn_procs) 473 | 474 | @torch.no_grad() 475 | def __call__( 476 | self, 477 | negative_prompt: Optional[Union[str, List[str]]] = None, 478 | output_type: Optional[str] = "pil", 479 | return_dict: bool = True, 480 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 481 | callback_steps: int = 1, 482 | editing_prompt: Optional[Union[str, List[str]]] = None, 483 | editing_prompt_embeddings: Optional[torch.Tensor] = None, 484 | reverse_editing_direction: Optional[Union[bool, List[bool]]] = False, 485 | edit_guidance_scale: Optional[Union[float, List[float]]] = 5, 486 | edit_warmup_steps: Optional[Union[int, List[int]]] = 0, 487 | edit_cooldown_steps: Optional[Union[int, List[int]]] = None, 488 | edit_threshold: Optional[Union[float, List[float]]] = 0.9, 489 | user_mask: Optional[torch.FloatTensor] = None, 490 | edit_weights: Optional[List[float]] = None, 491 | sem_guidance: Optional[List[torch.Tensor]] = None, 492 | verbose=True, 493 | use_cross_attn_mask: bool = False, 494 | # Attention store (just for visualization purposes) 495 | attn_store_steps: Optional[List[int]] = [], 496 | store_averaged_over_steps: bool = True, 497 | use_intersect_mask: bool = False 498 | ): 499 | r""" 500 | Function invoked when calling the pipeline for generation. 501 | 502 | Args: 503 | prompt (`str` or `List[str]`): 504 | The prompt or prompts to guide the image generation. 505 | height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 506 | The height in pixels of the generated image. 507 | width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 508 | The width in pixels of the generated image. 509 | num_inference_steps (`int`, *optional*, defaults to 50): 510 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 511 | expense of slower inference. 512 | guidance_scale (`float`, *optional*, defaults to 7.5): 513 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 514 | `guidance_scale` is defined as `w` of equation 2. 
of [Imagen 515 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 516 | 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, 517 | usually at the expense of lower image quality. 518 | negative_prompt (`str` or `List[str]`, *optional*): 519 | The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored 520 | if `guidance_scale` is less than `1`). 521 | num_images_per_prompt (`int`, *optional*, defaults to 1): 522 | The number of images to generate per prompt. 523 | eta (`float`, *optional*, defaults to 0.0): 524 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to 525 | [`schedulers.DDIMScheduler`], will be ignored for others. 526 | generator (`torch.Generator`, *optional*): 527 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) 528 | to make generation deterministic. 529 | latents (`torch.FloatTensor`, *optional*): 530 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image 531 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 532 | tensor will be generated by sampling using the supplied random `generator`. 533 | output_type (`str`, *optional*, defaults to `"pil"`): 534 | The output format of the generated image. Choose between 535 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 536 | return_dict (`bool`, *optional*, defaults to `True`): 537 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a 538 | plain tuple. 539 | callback (`Callable`, *optional*): 540 | A function that will be called every `callback_steps` steps during inference. The function will be 541 | called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. 542 | callback_steps (`int`, *optional*, defaults to 1): 543 | The frequency at which the `callback` function will be called. If not specified, the callback will be 544 | called at every step. 545 | editing_prompt (`str` or `List[str]`, *optional*): 546 | The prompt or prompts to use for Semantic guidance. Semantic guidance is disabled by setting 547 | `editing_prompt = None`. Guidance direction of prompt should be specified via 548 | `reverse_editing_direction`. 549 | editing_prompt_embeddings (`torch.Tensor`, *optional*): 550 | Pre-computed embeddings to use for semantic guidance. Guidance direction of embedding should be 551 | specified via `reverse_editing_direction`. 552 | reverse_editing_direction (`bool` or `List[bool]`, *optional*, defaults to `False`): 553 | Whether the corresponding prompt in `editing_prompt` should be increased or decreased. 554 | edit_guidance_scale (`float` or `List[float]`, *optional*, defaults to 5): 555 | Guidance scale for semantic guidance. If provided as a list, values should correspond to `editing_prompt`. 556 | `edit_guidance_scale` is defined as `s_e` of equation 6 of [SEGA 557 | Paper](https://arxiv.org/pdf/2301.12247.pdf). 558 | edit_warmup_steps (`int` or `List[int]`, *optional*, defaults to 0): 559 | Number of diffusion steps (for each prompt) for which semantic guidance will not be applied. Momentum 560 | will still be calculated for those steps and applied once all warmup periods are over.
561 | `edit_warmup_steps` is defined as `delta` (δ) of [SEGA Paper](https://arxiv.org/pdf/2301.12247.pdf). 562 | edit_cooldown_steps (`int` or `List[int]`, *optional*, defaults to `None`): 563 | Number of diffusion steps (for each prompt) after which semantic guidance will no longer be applied. 564 | edit_threshold (`float` or `List[float]`, *optional*, defaults to 0.9): 565 | Threshold of semantic guidance. 566 | edit_momentum_scale (`float`, *optional*, defaults to 0.1): 567 | Scale of the momentum to be added to the semantic guidance at each diffusion step. If set to 0.0 568 | momentum will be disabled. Momentum is already built up during warmup, i.e. for diffusion steps smaller 569 | than `edit_warmup_steps`. Momentum will only be added to latent guidance once all warmup periods are 570 | finished. `edit_momentum_scale` is defined as `s_m` of equation 7 of [SEGA 571 | Paper](https://arxiv.org/pdf/2301.12247.pdf). 572 | edit_mom_beta (`float`, *optional*, defaults to 0.4): 573 | Defines how semantic guidance momentum builds up. `edit_mom_beta` indicates how much of the previous 574 | momentum will be kept. Momentum is already built up during warmup, i.e. for diffusion steps smaller 575 | than `edit_warmup_steps`. `edit_mom_beta` is defined as `beta_m` (β) of equation 8 of [SEGA 576 | Paper](https://arxiv.org/pdf/2301.12247.pdf). 577 | edit_weights (`List[float]`, *optional*, defaults to `None`): 578 | Indicates how much each individual concept should influence the overall guidance. If no weights are 579 | provided all concepts are applied equally. `edit_weights` is defined as `g_i` of equation 9 of [SEGA 580 | Paper](https://arxiv.org/pdf/2301.12247.pdf). 581 | sem_guidance (`List[torch.Tensor]`, *optional*): 582 | List of pre-generated guidance vectors to be applied at generation. Length of the list has to 583 | correspond to `num_inference_steps`. 584 | 585 | Returns: 586 | [`~pipelines.semantic_stable_diffusion.SemanticStableDiffusionPipelineOutput`] or `tuple`: 587 | [`~pipelines.semantic_stable_diffusion.SemanticStableDiffusionPipelineOutput`] if `return_dict` is True, 588 | otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images, and the 589 | second element is a list of `bool`s denoting whether the corresponding generated image likely represents 590 | "not-safe-for-work" (nsfw) content, according to the `safety_checker`. 591 | """ 592 | eta = self.eta 593 | num_images_per_prompt = 1 594 | latents = self.init_latents 595 | 596 | use_ddpm = True 597 | zs = self.zs 598 | reset_dpm(self.scheduler) 599 | 600 | if use_intersect_mask: 601 | use_cross_attn_mask = True 602 | 603 | if use_cross_attn_mask: 604 | self.smoothing = GaussianSmoothing(self.device) 605 | 606 | org_prompt = "" 607 | 608 | # 2.
Define call parameters 609 | batch_size = self.batch_size 610 | 611 | if editing_prompt: 612 | enable_edit_guidance = True 613 | if isinstance(editing_prompt, str): 614 | editing_prompt = [editing_prompt] 615 | self.enabled_editing_prompts = len(editing_prompt) 616 | elif editing_prompt_embeddings is not None: 617 | enable_edit_guidance = True 618 | self.enabled_editing_prompts = editing_prompt_embeddings.shape[0] 619 | else: 620 | self.enabled_editing_prompts = 0 621 | enable_edit_guidance = False 622 | 623 | if enable_edit_guidance: 624 | # get safety text embeddings 625 | if editing_prompt_embeddings is None: 626 | edit_concepts_input = self.tokenizer( 627 | [x for item in editing_prompt for x in repeat(item, batch_size)], 628 | padding="max_length", 629 | max_length=self.tokenizer.model_max_length, 630 | truncation=True, 631 | return_tensors="pt", 632 | return_length=True 633 | ) 634 | 635 | num_edit_tokens = edit_concepts_input.length - 2 # not counting startoftext and endoftext 636 | edit_concepts_input_ids = edit_concepts_input.input_ids 637 | untruncated_ids = self.tokenizer( 638 | [x for item in editing_prompt for x in repeat(item, batch_size)], 639 | padding="longest", 640 | return_tensors="pt").input_ids 641 | 642 | if untruncated_ids.shape[-1] >= edit_concepts_input_ids.shape[-1] and not torch.equal( 643 | edit_concepts_input_ids, untruncated_ids 644 | ): 645 | removed_text = self.tokenizer.batch_decode( 646 | untruncated_ids[:, self.tokenizer.model_max_length - 1: -1] 647 | ) 648 | logger.warning( 649 | "The following part of your input was truncated because CLIP can only handle sequences up to" 650 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 651 | ) 652 | 653 | edit_concepts = self.text_encoder(edit_concepts_input_ids.to(self.device))[0] 654 | else: 655 | edit_concepts = editing_prompt_embeddings.to(self.device).repeat(batch_size, 1, 1) 656 | 657 | # duplicate text embeddings for each generation per prompt, using mps friendly method 658 | bs_embed_edit, seq_len_edit, _ = edit_concepts.shape 659 | edit_concepts = edit_concepts.repeat(1, num_images_per_prompt, 1) 660 | edit_concepts = edit_concepts.view(bs_embed_edit * num_images_per_prompt, seq_len_edit, -1) 661 | 662 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 663 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 664 | # corresponds to doing no classifier free guidance. 665 | # get unconditional embeddings for classifier free guidance 666 | 667 | 668 | uncond_tokens: List[str] 669 | if negative_prompt is None: 670 | uncond_tokens = [""] 671 | elif isinstance(negative_prompt, str): 672 | uncond_tokens = [negative_prompt] 673 | elif batch_size != len(negative_prompt): 674 | raise ValueError( 675 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 676 | f" has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 677 | " the batch size of `prompt`." 
678 | ) 679 | else: 680 | uncond_tokens = negative_prompt 681 | 682 | max_length = self.tokenizer.model_max_length 683 | uncond_input = self.tokenizer( 684 | uncond_tokens, 685 | padding="max_length", 686 | max_length=max_length, 687 | truncation=True, 688 | return_tensors="pt", 689 | ) 690 | uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] 691 | 692 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 693 | seq_len = uncond_embeddings.shape[1] 694 | uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1) 695 | uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) 696 | 697 | # For classifier free guidance, we need to do two forward passes. 698 | # Here we concatenate the unconditional and text embeddings into a single batch 699 | # to avoid doing two forward passes 700 | if enable_edit_guidance: 701 | text_embeddings = torch.cat([uncond_embeddings, edit_concepts]) 702 | self.text_cross_attention_maps = \ 703 | ([editing_prompt] if isinstance(editing_prompt, str) else editing_prompt) 704 | else: 705 | text_embeddings = torch.cat([uncond_embeddings]) 706 | 707 | # 4. Prepare timesteps 708 | #self.scheduler.set_timesteps(num_inference_steps, device=self.device) 709 | timesteps = self.scheduler.timesteps 710 | 711 | if use_ddpm: 712 | t_to_idx = {int(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])} 713 | timesteps = timesteps[-zs.shape[0]:] 714 | 715 | if use_cross_attn_mask: 716 | self.attention_store = AttentionStore(average=store_averaged_over_steps, batch_size=batch_size) 717 | self.prepare_unet(self.attention_store, PnP=False) 718 | # 5. Prepare latent variables 719 | num_channels_latents = self.unet.config.in_channels 720 | latents = self.prepare_latents( 721 | batch_size * num_images_per_prompt, 722 | num_channels_latents, 723 | None, 724 | None, 725 | text_embeddings.dtype, 726 | self.device, 727 | latents, 728 | ) 729 | 730 | # 6. Prepare extra step kwargs. 
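# NOTE: `eta` here is `self.eta` as recorded during inversion; `prepare_extra_step_kwargs` only forwards
# it if the configured scheduler's `step()` signature accepts an `eta` argument, so it is silently dropped
# for schedulers other than DDIM (see the comments in `prepare_extra_step_kwargs` above).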
731 | extra_step_kwargs = self.prepare_extra_step_kwargs(eta) 732 | 733 | self.uncond_estimates = None 734 | self.edit_estimates = None 735 | self.sem_guidance = None 736 | self.activation_mask = None 737 | 738 | for i, t in enumerate(self.progress_bar(timesteps, verbose=verbose)): 739 | # expand the latents if we are doing classifier free guidance 740 | 741 | if enable_edit_guidance: 742 | latent_model_input = torch.cat([latents] * (1 + self.enabled_editing_prompts)) 743 | else: 744 | latent_model_input = latents 745 | 746 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 747 | 748 | text_embed_input = text_embeddings 749 | 750 | # predict the noise residual 751 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embed_input).sample 752 | 753 | 754 | noise_pred_out = noise_pred.chunk(1 + self.enabled_editing_prompts) # [b,4, 64, 64] 755 | noise_pred_uncond = noise_pred_out[0] 756 | noise_pred_edit_concepts = noise_pred_out[1:] 757 | 758 | # default text guidance 759 | noise_guidance = torch.zeros_like(noise_pred_uncond) 760 | 761 | if self.uncond_estimates is None: 762 | self.uncond_estimates = torch.zeros((len(timesteps), *noise_pred_uncond.shape)) 763 | self.uncond_estimates[i] = noise_pred_uncond.detach().cpu() 764 | 765 | if sem_guidance is not None and len(sem_guidance) > i: 766 | edit_guidance = sem_guidance[i].to(self.device) 767 | noise_guidance = noise_guidance + edit_guidance 768 | 769 | elif enable_edit_guidance: 770 | if self.activation_mask is None: 771 | self.activation_mask = torch.zeros( 772 | (len(timesteps), len(noise_pred_edit_concepts), *noise_pred_edit_concepts[0].shape) 773 | ) 774 | if self.edit_estimates is None and enable_edit_guidance: 775 | self.edit_estimates = torch.zeros( 776 | (len(timesteps), len(noise_pred_edit_concepts), *noise_pred_edit_concepts[0].shape) 777 | ) 778 | 779 | if self.sem_guidance is None: 780 | self.sem_guidance = torch.zeros((len(timesteps), *noise_pred_uncond.shape)) 781 | 782 | concept_weights = torch.zeros( 783 | (len(noise_pred_edit_concepts), noise_guidance.shape[0]), 784 | device=self.device, 785 | dtype=noise_guidance.dtype, 786 | ) 787 | noise_guidance_edit = torch.zeros( 788 | (len(noise_pred_edit_concepts), *noise_guidance.shape), 789 | device=self.device, 790 | dtype=noise_guidance.dtype, 791 | ) 792 | warmup_inds = [] 793 | # noise_guidance_edit = torch.zeros_like(noise_guidance) 794 | for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts): 795 | self.edit_estimates[i, c] = noise_pred_edit_concept 796 | if isinstance(edit_warmup_steps, list): 797 | edit_warmup_steps_c = edit_warmup_steps[c] 798 | else: 799 | edit_warmup_steps_c = edit_warmup_steps 800 | if i >= edit_warmup_steps_c: 801 | warmup_inds.append(c) 802 | else: 803 | continue 804 | 805 | if isinstance(edit_guidance_scale, list): 806 | edit_guidance_scale_c = edit_guidance_scale[c] 807 | else: 808 | edit_guidance_scale_c = edit_guidance_scale 809 | 810 | if isinstance(edit_threshold, list): 811 | edit_threshold_c = edit_threshold[c] 812 | else: 813 | edit_threshold_c = edit_threshold 814 | if isinstance(reverse_editing_direction, list): 815 | reverse_editing_direction_c = reverse_editing_direction[c] 816 | else: 817 | reverse_editing_direction_c = reverse_editing_direction 818 | if edit_weights: 819 | edit_weight_c = edit_weights[c] 820 | else: 821 | edit_weight_c = 1.0 822 | 823 | if isinstance(edit_cooldown_steps, list): 824 | edit_cooldown_steps_c = edit_cooldown_steps[c] 825 | elif 
edit_cooldown_steps is None: 826 | edit_cooldown_steps_c = i + 1 827 | else: 828 | edit_cooldown_steps_c = edit_cooldown_steps 829 | 830 | if i >= edit_cooldown_steps_c: 831 | noise_guidance_edit[c, :, :, :, :] = torch.zeros_like(noise_pred_edit_concept) 832 | continue 833 | 834 | noise_guidance_edit_tmp = noise_pred_edit_concept - noise_pred_uncond 835 | # tmp_weights = (noise_pred_text - noise_pred_edit_concept).sum(dim=(1, 2, 3)) 836 | tmp_weights = (noise_guidance - noise_pred_edit_concept).sum(dim=(1, 2, 3)) 837 | 838 | tmp_weights = torch.full_like(tmp_weights, edit_weight_c) # * (1 / enabled_editing_prompts) 839 | if reverse_editing_direction_c: 840 | noise_guidance_edit_tmp = noise_guidance_edit_tmp * -1 841 | concept_weights[c, :] = tmp_weights 842 | 843 | noise_guidance_edit_tmp = noise_guidance_edit_tmp * edit_guidance_scale_c 844 | 845 | if user_mask is not None: 846 | noise_guidance_edit_tmp = noise_guidance_edit_tmp * user_mask 847 | 848 | if use_cross_attn_mask: 849 | out = self.attention_store.aggregate_attention( 850 | attention_maps=self.attention_store.step_store, 851 | prompts=self.text_cross_attention_maps, 852 | res=16, 853 | from_where=["up", "down"], 854 | is_cross=True, 855 | select=self.text_cross_attention_maps.index(editing_prompt[c]), 856 | ) 857 | attn_map = out[:, :, :, 1:1 + num_edit_tokens[c]] # 0 -> startoftext 858 | 859 | # average over all tokens 860 | assert (attn_map.shape[3] == num_edit_tokens[c]) 861 | attn_map = torch.sum(attn_map, dim=3) 862 | 863 | # gaussian_smoothing 864 | attn_map = F.pad(attn_map.unsqueeze(1), (1, 1, 1, 1), mode="reflect") 865 | attn_map = self.smoothing(attn_map).squeeze(1) 866 | 867 | # create binary mask 868 | tmp = torch.quantile(attn_map.flatten(start_dim=1), edit_threshold_c, dim=1) 869 | attn_mask = torch.where(attn_map >= tmp.unsqueeze(1).unsqueeze(1).repeat(1,16,16), 1.0, 0.0) 870 | 871 | # resolution must match latent space dimension 872 | attn_mask = F.interpolate( 873 | attn_mask.unsqueeze(1), 874 | noise_guidance_edit_tmp.shape[-2:] # 64,64 875 | ).repeat(1, 4, 1, 1) 876 | self.activation_mask[i, c] = attn_mask.detach().cpu() 877 | if not use_intersect_mask: 878 | noise_guidance_edit_tmp = noise_guidance_edit_tmp * attn_mask 879 | 880 | if use_intersect_mask: 881 | noise_guidance_edit_tmp_quantile = torch.abs(noise_guidance_edit_tmp) 882 | noise_guidance_edit_tmp_quantile = torch.sum(noise_guidance_edit_tmp_quantile, dim=1, 883 | keepdim=True) 884 | noise_guidance_edit_tmp_quantile = noise_guidance_edit_tmp_quantile.repeat(1, 4, 1, 1) 885 | 886 | # torch.quantile function expects float32 887 | if noise_guidance_edit_tmp_quantile.dtype == torch.float32: 888 | tmp = torch.quantile( 889 | noise_guidance_edit_tmp_quantile.flatten(start_dim=2), 890 | edit_threshold_c, 891 | dim=2, 892 | keepdim=False, 893 | ) 894 | else: 895 | tmp = torch.quantile( 896 | noise_guidance_edit_tmp_quantile.flatten(start_dim=2).to(torch.float32), 897 | edit_threshold_c, 898 | dim=2, 899 | keepdim=False, 900 | ).to(noise_guidance_edit_tmp_quantile.dtype) 901 | 902 | intersect_mask = torch.where( 903 | noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None], 904 | torch.ones_like(noise_guidance_edit_tmp), 905 | torch.zeros_like(noise_guidance_edit_tmp), 906 | ) * attn_mask 907 | 908 | self.activation_mask[i, c] = intersect_mask.detach().cpu() 909 | 910 | noise_guidance_edit_tmp = noise_guidance_edit_tmp * intersect_mask 911 | 912 | elif not use_cross_attn_mask: 913 | # calculate quantile 914 | noise_guidance_edit_tmp_quantile = 
torch.abs(noise_guidance_edit_tmp) 915 | noise_guidance_edit_tmp_quantile = torch.sum(noise_guidance_edit_tmp_quantile, dim=1, 916 | keepdim=True) 917 | noise_guidance_edit_tmp_quantile = noise_guidance_edit_tmp_quantile.repeat(1, 4, 1, 1) 918 | 919 | # torch.quantile function expects float32 920 | if noise_guidance_edit_tmp_quantile.dtype == torch.float32: 921 | tmp = torch.quantile( 922 | noise_guidance_edit_tmp_quantile.flatten(start_dim=2), 923 | edit_threshold_c, 924 | dim=2, 925 | keepdim=False, 926 | ) 927 | else: 928 | tmp = torch.quantile( 929 | noise_guidance_edit_tmp_quantile.flatten(start_dim=2).to(torch.float32), 930 | edit_threshold_c, 931 | dim=2, 932 | keepdim=False, 933 | ).to(noise_guidance_edit_tmp_quantile.dtype) 934 | 935 | self.activation_mask[i, c] = torch.where( 936 | noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None], 937 | torch.ones_like(noise_guidance_edit_tmp), 938 | torch.zeros_like(noise_guidance_edit_tmp), 939 | ).detach().cpu() 940 | 941 | noise_guidance_edit_tmp = torch.where( 942 | noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None], 943 | noise_guidance_edit_tmp, 944 | torch.zeros_like(noise_guidance_edit_tmp), 945 | ) 946 | 947 | noise_guidance_edit[c, :, :, :, :] = noise_guidance_edit_tmp 948 | 949 | warmup_inds = torch.tensor(warmup_inds).to(self.device) 950 | concept_weights = torch.index_select(concept_weights, 0, warmup_inds) 951 | concept_weights = torch.where( 952 | concept_weights < 0, torch.zeros_like(concept_weights), concept_weights 953 | ) 954 | 955 | concept_weights = torch.nan_to_num(concept_weights) 956 | 957 | noise_guidance_edit = torch.einsum("cb,cbijk->bijk", concept_weights, noise_guidance_edit) 958 | 959 | noise_guidance = noise_guidance + noise_guidance_edit 960 | self.sem_guidance[i] = noise_guidance_edit.detach().cpu() 961 | 962 | noise_pred = noise_pred_uncond + noise_guidance 963 | 964 | # compute the previous noisy sample x_t -> x_t-1 965 | if use_ddpm: 966 | idx = t_to_idx[int(t)] 967 | latents = self.scheduler.step(noise_pred, t, latents, variance_noise=zs[idx], 968 | **extra_step_kwargs).prev_sample 969 | 970 | else: # if not use_ddpm: 971 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample 972 | 973 | # step callback 974 | if use_cross_attn_mask: 975 | store_step = i in attn_store_steps 976 | if store_step: 977 | print(f"storing attention for step {i}") 978 | self.attention_store.between_steps(store_step) 979 | 980 | # call the callback, if provided 981 | if callback is not None and i % callback_steps == 0: 982 | callback(i, t, latents) 983 | 984 | # 8. 
Post-processing 985 | if not output_type == "latent": 986 | image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] 987 | image, has_nsfw_concept = self.run_safety_checker(image, self.device, text_embeddings.dtype) 988 | else: 989 | image = latents 990 | has_nsfw_concept = None 991 | 992 | if has_nsfw_concept is None: 993 | do_denormalize = [True] * image.shape[0] 994 | else: 995 | do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] 996 | 997 | image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) 998 | 999 | if not return_dict: 1000 | return (image, has_nsfw_concept) 1001 | 1002 | return SemanticStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) 1003 | 1004 | def encode_text(self, prompts): 1005 | text_inputs = self.tokenizer( 1006 | prompts, 1007 | padding="max_length", 1008 | max_length=self.tokenizer.model_max_length, 1009 | return_tensors="pt", 1010 | ) 1011 | text_input_ids = text_inputs.input_ids 1012 | 1013 | if text_input_ids.shape[-1] > self.tokenizer.model_max_length: 1014 | removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length:]) 1015 | logger.warning( 1016 | "The following part of your input was truncated because CLIP can only handle sequences up to" 1017 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 1018 | ) 1019 | text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] 1020 | text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] 1021 | 1022 | return text_embeddings 1023 | 1024 | @torch.no_grad() 1025 | def invert(self, 1026 | image_path: str, 1027 | source_prompt: str = "", 1028 | source_guidance_scale=3.5, 1029 | num_inversion_steps: int = 30, 1030 | skip: float = 0.15, 1031 | eta: float = 1.0, 1032 | generator: Optional[torch.Generator] = None, 1033 | verbose=True, 1034 | ): 1035 | """ 1036 | Inverts a real image according to Algorihm 1 in https://arxiv.org/pdf/2304.06140.pdf, 1037 | based on the code in https://github.com/inbarhub/DDPM_inversion 1038 | 1039 | returns: 1040 | zs - noise maps 1041 | xts - intermediate inverted latents 1042 | """ 1043 | 1044 | self.eta = eta 1045 | assert (self.eta > 0) 1046 | 1047 | train_steps = self.scheduler.config.num_train_timesteps 1048 | timesteps = torch.from_numpy( 1049 | np.linspace(train_steps - skip * train_steps - 1, 1, num_inversion_steps).astype(np.int64)).to(self.device) 1050 | 1051 | 1052 | self.num_inversion_steps = timesteps.shape[0] 1053 | self.scheduler.num_inference_steps = timesteps.shape[0] 1054 | self.scheduler.timesteps = timesteps 1055 | 1056 | 1057 | # 1. get embeddings 1058 | 1059 | uncond_embedding = self.encode_text("") 1060 | 1061 | # 2. encode image 1062 | x0 = self.encode_image(image_path, dtype=uncond_embedding.dtype) 1063 | self.batch_size = x0.shape[0] 1064 | 1065 | if not source_prompt == "": 1066 | text_embeddings = self.encode_text(source_prompt).repeat((self.batch_size, 1, 1)) 1067 | uncond_embedding = uncond_embedding.repeat((self.batch_size, 1, 1)) 1068 | # autoencoder reconstruction 1069 | image_rec = self.vae.decode(x0 / self.vae.config.scaling_factor, return_dict=False)[0] 1070 | image_rec = self.image_processor.postprocess(image_rec, output_type="pil") 1071 | 1072 | # 3. 
find zs and xts 1073 | variance_noise_shape = ( 1074 | self.num_inversion_steps, 1075 | *x0.shape) 1076 | 1077 | # intermediate latents 1078 | t_to_idx = {int(v): k for k, v in enumerate(timesteps)} 1079 | xts = torch.zeros(size=variance_noise_shape, device=self.device, dtype=uncond_embedding.dtype) 1080 | 1081 | for t in reversed(timesteps): 1082 | idx = self.num_inversion_steps-t_to_idx[int(t)] - 1 1083 | noise = randn_tensor(shape=x0.shape, generator=generator, device=self.device, dtype=x0.dtype) 1084 | xts[idx] = self.scheduler.add_noise(x0, noise, t) 1085 | xts = torch.cat([x0.unsqueeze(0), xts], dim=0) 1086 | 1087 | reset_dpm(self.scheduler) 1088 | # noise maps 1089 | zs = torch.zeros(size=variance_noise_shape, device=self.device, dtype=uncond_embedding.dtype) 1090 | 1091 | for t in self.progress_bar(timesteps, verbose=verbose): 1092 | 1093 | idx = self.num_inversion_steps-t_to_idx[int(t)]-1 1094 | # 1. predict noise residual 1095 | xt = xts[idx+1] 1096 | 1097 | noise_pred = self.unet(xt, timestep=t, encoder_hidden_states=uncond_embedding).sample 1098 | 1099 | if not source_prompt == "": 1100 | noise_pred_cond = self.unet(xt, timestep=t, encoder_hidden_states=text_embeddings).sample 1101 | noise_pred = noise_pred + source_guidance_scale * (noise_pred_cond - noise_pred) 1102 | 1103 | xtm1 = xts[idx] 1104 | z, xtm1_corrected = compute_noise(self.scheduler, xtm1, xt, t, noise_pred, eta) 1105 | zs[idx] = z 1106 | 1107 | # correction to avoid error accumulation 1108 | xts[idx] = xtm1_corrected 1109 | 1110 | # TODO: I don't think that the noise map for the last step should be discarded ?! 1111 | # if not zs is None: 1112 | # zs[-1] = torch.zeros_like(zs[-1]) 1113 | self.init_latents = xts[-1].expand(self.batch_size, -1, -1, -1) 1114 | zs = zs.flip(0) 1115 | self.zs = zs 1116 | 1117 | 1118 | 1119 | return zs, xts, image_rec 1120 | 1121 | @torch.no_grad() 1122 | def encode_image(self, image_path, dtype=None): 1123 | image = load_512(image_path, 1124 | sizes=(int(self.unet.sample_size * self.vae_scale_factor), int(self.unet.sample_size * self.vae_scale_factor*1.5)), 1125 | device=self.device, 1126 | dtype=dtype) 1127 | x0 = self.vae.encode(image).latent_dist.mode() 1128 | x0 = self.vae.config.scaling_factor * x0 1129 | return x0 1130 | 1131 | def compute_noise_ddim(scheduler, prev_latents, latents, timestep, noise_pred, eta): 1132 | # 1. get previous step value (=t-1) 1133 | prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps 1134 | 1135 | # 2. compute alphas, betas 1136 | alpha_prod_t = scheduler.alphas_cumprod[timestep] 1137 | alpha_prod_t_prev = ( 1138 | scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod 1139 | ) 1140 | 1141 | beta_prod_t = 1 - alpha_prod_t 1142 | 1143 | # 3. compute predicted original sample from predicted noise also called 1144 | # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 1145 | pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5) 1146 | 1147 | # 4. Clip "predicted x_0" 1148 | if scheduler.config.clip_sample: 1149 | pred_original_sample = torch.clamp(pred_original_sample, -1, 1) 1150 | 1151 | # 5. compute variance: "sigma_t(η)" -> see formula (16) 1152 | # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) 1153 | variance = scheduler._get_variance(timestep, prev_timestep) 1154 | std_dev_t = eta * variance ** (0.5) 1155 | 1156 | # 6. 
compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 1157 | pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t ** 2) ** (0.5) * noise_pred 1158 | 1159 | # modifed so that updated xtm1 is returned as well (to avoid error accumulation) 1160 | mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction 1161 | noise = (prev_latents - mu_xt) / (variance ** (0.5) * eta) 1162 | 1163 | return noise, mu_xt + (eta * variance ** 0.5) * noise 1164 | 1165 | # Copied from pipelines.StableDiffusion.CycleDiffusionPipeline.compute_noise 1166 | def compute_noise_sde_dpm_pp_2nd(scheduler, prev_latents, latents, timestep, noise_pred, eta): 1167 | 1168 | def first_order_update(model_output, timestep, prev_timestep, sample): 1169 | lambda_t, lambda_s = scheduler.lambda_t[prev_timestep], scheduler.lambda_t[timestep] 1170 | alpha_t, alpha_s = scheduler.alpha_t[prev_timestep], scheduler.alpha_t[timestep] 1171 | sigma_t, sigma_s = scheduler.sigma_t[prev_timestep], scheduler.sigma_t[timestep] 1172 | h = lambda_t - lambda_s 1173 | 1174 | mu_xt = ( 1175 | (sigma_t / sigma_s * torch.exp(-h)) * sample 1176 | + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output 1177 | ) 1178 | sigma = sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) 1179 | 1180 | noise = (prev_latents - mu_xt) / sigma 1181 | 1182 | prev_sample = mu_xt + sigma * noise 1183 | 1184 | return noise, prev_sample 1185 | def second_order_update(model_output_list, timestep_list, prev_timestep, sample): 1186 | t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] 1187 | m0, m1 = model_output_list[-1], model_output_list[-2] 1188 | lambda_t, lambda_s0, lambda_s1 = scheduler.lambda_t[t], scheduler.lambda_t[s0], scheduler.lambda_t[s1] 1189 | alpha_t, alpha_s0 = scheduler.alpha_t[t], scheduler.alpha_t[s0] 1190 | sigma_t, sigma_s0 = scheduler.sigma_t[t], scheduler.sigma_t[s0] 1191 | h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 1192 | r0 = h_0 / h 1193 | D0, D1 = m0, (1.0 / r0) * (m0 - m1) 1194 | 1195 | mu_xt = ( 1196 | (sigma_t / sigma_s0 * torch.exp(-h)) * sample 1197 | + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 1198 | + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 1199 | ) 1200 | sigma = sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) 1201 | 1202 | noise = (prev_latents - mu_xt) / sigma 1203 | 1204 | prev_sample = mu_xt + sigma * noise 1205 | 1206 | return noise, prev_sample 1207 | 1208 | step_index = (scheduler.timesteps == timestep).nonzero() 1209 | if len(step_index) == 0: 1210 | step_index = len(scheduler.timesteps) - 1 1211 | else: 1212 | step_index = step_index.item() 1213 | 1214 | prev_timestep = 0 if step_index == len(scheduler.timesteps) - 1 else scheduler.timesteps[step_index + 1] 1215 | 1216 | model_output = scheduler.convert_model_output(noise_pred, timestep, latents) 1217 | 1218 | for i in range(scheduler.config.solver_order - 1): 1219 | scheduler.model_outputs[i] = scheduler.model_outputs[i + 1] 1220 | scheduler.model_outputs[-1] = model_output 1221 | 1222 | if scheduler.lower_order_nums < 1: 1223 | noise, prev_sample = first_order_update(model_output, timestep, prev_timestep, latents) 1224 | else: 1225 | timestep_list = [scheduler.timesteps[step_index - 1], timestep] 1226 | noise, prev_sample = second_order_update(scheduler.model_outputs, timestep_list, prev_timestep, latents) 1227 | 1228 | if scheduler.lower_order_nums < scheduler.config.solver_order: 1229 | scheduler.lower_order_nums += 1 1230 | 1231 | return noise, prev_sample 1232 | 1233 | def 
compute_noise(scheduler, *args): 1234 | if isinstance(scheduler, DDIMScheduler): 1235 | return compute_noise_ddim(scheduler, *args) 1236 | elif isinstance(scheduler, DPMSolverMultistepSchedulerInject) and scheduler.config.algorithm_type == 'sde-dpmsolver++'\ 1237 | and scheduler.config.solver_order == 2: 1238 | return compute_noise_sde_dpm_pp_2nd(scheduler, *args) 1239 | else: 1240 | raise NotImplementedError 1241 | -------------------------------------------------------------------------------- /src/leditspp/scheduling_dpmsolver_multistep_inject.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver 16 | 17 | import math 18 | from typing import List, Optional, Tuple, Union 19 | 20 | import numpy as np 21 | import torch 22 | 23 | from diffusers.configuration_utils import ConfigMixin, register_to_config 24 | from diffusers.utils import randn_tensor 25 | from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput 26 | 27 | 28 | # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar 29 | def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): 30 | """ 31 | Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of 32 | (1-beta) over time from t = [0,1]. 33 | 34 | Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up 35 | to that part of the diffusion process. 36 | 37 | 38 | Args: 39 | num_diffusion_timesteps (`int`): the number of betas to produce. 40 | max_beta (`float`): the maximum beta to use; use values lower than 1 to 41 | prevent singularities. 42 | 43 | Returns: 44 | betas (`np.ndarray`): the betas used by the scheduler to step the model outputs 45 | """ 46 | 47 | def alpha_bar(time_step): 48 | return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 49 | 50 | betas = [] 51 | for i in range(num_diffusion_timesteps): 52 | t1 = i / num_diffusion_timesteps 53 | t2 = (i + 1) / num_diffusion_timesteps 54 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 55 | return torch.tensor(betas, dtype=torch.float32) 56 | 57 | 58 | class DPMSolverMultistepSchedulerInject(SchedulerMixin, ConfigMixin): 59 | """ 60 | DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with 61 | the convergence order guarantee. Empirically, sampling by DPM-Solver with only 20 steps can generate high-quality 62 | samples, and it can generate quite good samples even in only 10 steps. 
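
    This "Inject" variant additionally accepts a pre-computed `variance_noise` map in
    `step`, which the LEDITS++ pipelines use to re-inject the noise maps recovered
    during image inversion instead of sampling fresh noise.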
63 | 64 | For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095 65 | 66 | Currently, we support the multistep DPM-Solver for both noise prediction models and data prediction models. We 67 | recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. 68 | 69 | We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space 70 | diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic 71 | thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as 72 | stable-diffusion). 73 | 74 | We also support the SDE variant of DPM-Solver and DPM-Solver++, which is a fast SDE solver for the reverse 75 | diffusion SDE. Currently we only support the first-order and second-order solvers. We recommend using the 76 | second-order `sde-dpmsolver++`. 77 | 78 | [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` 79 | function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. 80 | [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and 81 | [`~SchedulerMixin.from_pretrained`] functions. 82 | 83 | Args: 84 | num_train_timesteps (`int`): number of diffusion steps used to train the model. 85 | beta_start (`float`): the starting `beta` value of inference. 86 | beta_end (`float`): the final `beta` value. 87 | beta_schedule (`str`): 88 | the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from 89 | `linear`, `scaled_linear`, or `squaredcos_cap_v2`. 90 | trained_betas (`np.ndarray`, optional): 91 | option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. 92 | solver_order (`int`, default `2`): 93 | the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided 94 | sampling, and `solver_order=3` for unconditional sampling. 95 | prediction_type (`str`, default `epsilon`, optional): 96 | prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion 97 | process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 98 | https://imagen.research.google/video/paper.pdf) 99 | thresholding (`bool`, default `False`): 100 | whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). 101 | For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to 102 | use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion 103 | models (such as stable-diffusion). 104 | dynamic_thresholding_ratio (`float`, default `0.995`): 105 | the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen 106 | (https://arxiv.org/abs/2205.11487). 107 | sample_max_value (`float`, default `1.0`): 108 | the threshold value for dynamic thresholding. Valid only when `thresholding=True` and 109 | `algorithm_type="dpmsolver++`. 110 | algorithm_type (`str`, default `dpmsolver++`): 111 | the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++` or `sde-dpmsolver` or 112 | `sde-dpmsolver++`. 
The `dpmsolver` type implements the algorithms in https://arxiv.org/abs/2206.00927, and 113 | the `dpmsolver++` type implements the algorithms in https://arxiv.org/abs/2211.01095. We recommend to use 114 | `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling (e.g. stable-diffusion). 115 | solver_type (`str`, default `midpoint`): 116 | the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects 117 | the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are 118 | slightly better, so we recommend to use the `midpoint` type. 119 | lower_order_final (`bool`, default `True`): 120 | whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically 121 | find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10. 122 | use_karras_sigmas (`bool`, *optional*, defaults to `False`): 123 | This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the 124 | noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence 125 | of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. 126 | lambda_min_clipped (`float`, default `-inf`): 127 | the clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical for 128 | cosine (squaredcos_cap_v2) noise schedule. 129 | variance_type (`str`, *optional*): 130 | Set to "learned" or "learned_range" for diffusion models that predict variance. For example, OpenAI's 131 | guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the 132 | Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on 133 | diffusion ODEs. whether the model's output contains the predicted Gaussian variance. For example, OpenAI's 134 | guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the 135 | Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on 136 | diffusion ODEs. 137 | """ 138 | 139 | _compatibles = [e.name for e in KarrasDiffusionSchedulers] 140 | order = 1 141 | 142 | @register_to_config 143 | def __init__( 144 | self, 145 | num_train_timesteps: int = 1000, 146 | beta_start: float = 0.0001, 147 | beta_end: float = 0.02, 148 | beta_schedule: str = "linear", 149 | trained_betas: Optional[Union[np.ndarray, List[float]]] = None, 150 | solver_order: int = 2, 151 | prediction_type: str = "epsilon", 152 | thresholding: bool = False, 153 | dynamic_thresholding_ratio: float = 0.995, 154 | sample_max_value: float = 1.0, 155 | algorithm_type: str = "dpmsolver++", 156 | solver_type: str = "midpoint", 157 | lower_order_final: bool = True, 158 | use_karras_sigmas: Optional[bool] = False, 159 | lambda_min_clipped: float = -float("inf"), 160 | variance_type: Optional[str] = None, 161 | ): 162 | if trained_betas is not None: 163 | self.betas = torch.tensor(trained_betas, dtype=torch.float32) 164 | elif beta_schedule == "linear": 165 | self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) 166 | elif beta_schedule == "scaled_linear": 167 | # this schedule is very specific to the latent diffusion model. 
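            # ("scaled_linear" keeps the betas linear in sqrt(beta)-space between
            #  sqrt(beta_start) and sqrt(beta_end) and then squares them, i.e. the
            #  schedule used by Stable Diffusion.)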
168 | self.betas = ( 169 | torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 170 | ) 171 | elif beta_schedule == "squaredcos_cap_v2": 172 | # Glide cosine schedule 173 | self.betas = betas_for_alpha_bar(num_train_timesteps) 174 | else: 175 | raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") 176 | 177 | self.alphas = 1.0 - self.betas 178 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) 179 | # Currently we only support VP-type noise schedule 180 | self.alpha_t = torch.sqrt(self.alphas_cumprod) 181 | self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) 182 | self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) 183 | 184 | # standard deviation of the initial noise distribution 185 | self.init_noise_sigma = 1.0 186 | 187 | # settings for DPM-Solver 188 | if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]: 189 | if algorithm_type == "deis": 190 | self.register_to_config(algorithm_type="dpmsolver++") 191 | else: 192 | raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") 193 | 194 | if solver_type not in ["midpoint", "heun"]: 195 | if solver_type in ["logrho", "bh1", "bh2"]: 196 | self.register_to_config(solver_type="midpoint") 197 | else: 198 | raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") 199 | 200 | # setable values 201 | self.num_inference_steps = None 202 | timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() 203 | self.timesteps = torch.from_numpy(timesteps) 204 | self.model_outputs = [None] * solver_order 205 | self.lower_order_nums = 0 206 | self.use_karras_sigmas = use_karras_sigmas 207 | 208 | def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): 209 | """ 210 | Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. 211 | 212 | Args: 213 | num_inference_steps (`int`): 214 | the number of diffusion steps used when generating samples with a pre-trained model. 215 | device (`str` or `torch.device`, optional): 216 | the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 217 | """ 218 | # Clipping the minimum of all lambda(t) for numerical stability. 219 | # This is critical for cosine (squaredcos_cap_v2) noise schedule. 220 | clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped) 221 | timesteps = ( 222 | np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1) 223 | .round()[::-1][:-1] 224 | .copy() 225 | .astype(np.int64) 226 | ) 227 | 228 | if self.use_karras_sigmas: 229 | sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) 230 | log_sigmas = np.log(sigmas) 231 | sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) 232 | timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() 233 | timesteps = np.flip(timesteps).copy().astype(np.int64) 234 | 235 | # when num_inference_steps == num_train_timesteps, we can end up with 236 | # duplicates in timesteps. 
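        # np.unique returns the index of the first occurrence of each value; sorting
        # those indices preserves the original (descending) timestep order while
        # dropping the repeats.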
237 | _, unique_indices = np.unique(timesteps, return_index=True) 238 | timesteps = timesteps[np.sort(unique_indices)] 239 | 240 | self.timesteps = torch.from_numpy(timesteps).to(device) 241 | 242 | self.num_inference_steps = len(timesteps) 243 | 244 | self.model_outputs = [ 245 | None, 246 | ] * self.config.solver_order 247 | self.lower_order_nums = 0 248 | 249 | # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample 250 | def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: 251 | """ 252 | "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the 253 | prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by 254 | s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing 255 | pixels from saturation at each step. We find that dynamic thresholding results in significantly better 256 | photorealism as well as better image-text alignment, especially when using very large guidance weights." 257 | 258 | https://arxiv.org/abs/2205.11487 259 | """ 260 | dtype = sample.dtype 261 | batch_size, channels, height, width = sample.shape 262 | 263 | if dtype not in (torch.float32, torch.float64): 264 | sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half 265 | 266 | # Flatten sample for doing quantile calculation along each image 267 | sample = sample.reshape(batch_size, channels * height * width) 268 | 269 | abs_sample = sample.abs() # "a certain percentile absolute pixel value" 270 | 271 | s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) 272 | s = torch.clamp( 273 | s, min=1, max=self.config.sample_max_value 274 | ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] 275 | 276 | s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 277 | sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" 278 | 279 | sample = sample.reshape(batch_size, channels, height, width) 280 | sample = sample.to(dtype) 281 | 282 | return sample 283 | 284 | # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t 285 | def _sigma_to_t(self, sigma, log_sigmas): 286 | # get log sigma 287 | log_sigma = np.log(sigma) 288 | 289 | # get distribution 290 | dists = log_sigma - log_sigmas[:, np.newaxis] 291 | 292 | # get sigmas range 293 | low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) 294 | high_idx = low_idx + 1 295 | 296 | low = log_sigmas[low_idx] 297 | high = log_sigmas[high_idx] 298 | 299 | # interpolate sigmas 300 | w = (low - log_sigma) / (low - high) 301 | w = np.clip(w, 0, 1) 302 | 303 | # transform interpolation to time range 304 | t = (1 - w) * low_idx + w * high_idx 305 | t = t.reshape(sigma.shape) 306 | return t 307 | 308 | # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras 309 | def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: 310 | """Constructs the noise schedule of Karras et al. 
(2022).""" 311 | 312 | sigma_min: float = in_sigmas[-1].item() 313 | sigma_max: float = in_sigmas[0].item() 314 | 315 | rho = 7.0 # 7.0 is the value used in the paper 316 | ramp = np.linspace(0, 1, num_inference_steps) 317 | min_inv_rho = sigma_min ** (1 / rho) 318 | max_inv_rho = sigma_max ** (1 / rho) 319 | sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho 320 | return sigmas 321 | 322 | def convert_model_output( 323 | self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor 324 | ) -> torch.FloatTensor: 325 | """ 326 | Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs. 327 | 328 | DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to 329 | discretize an integral of the data prediction model. So we need to first convert the model output to the 330 | corresponding type to match the algorithm. 331 | 332 | Note that the algorithm type and the model type is decoupled. That is to say, we can use either DPM-Solver or 333 | DPM-Solver++ for both noise prediction model and data prediction model. 334 | 335 | Args: 336 | model_output (`torch.FloatTensor`): direct output from learned diffusion model. 337 | timestep (`int`): current discrete timestep in the diffusion chain. 338 | sample (`torch.FloatTensor`): 339 | current instance of sample being created by diffusion process. 340 | 341 | Returns: 342 | `torch.FloatTensor`: the converted model output. 343 | """ 344 | 345 | # DPM-Solver++ needs to solve an integral of the data prediction model. 346 | if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: 347 | if self.config.prediction_type == "epsilon": 348 | # DPM-Solver and DPM-Solver++ only need the "mean" output. 349 | if self.config.variance_type in ["learned", "learned_range"]: 350 | model_output = model_output[:, :3] 351 | alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] 352 | x0_pred = (sample - sigma_t * model_output) / alpha_t 353 | elif self.config.prediction_type == "sample": 354 | x0_pred = model_output 355 | elif self.config.prediction_type == "v_prediction": 356 | alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] 357 | x0_pred = alpha_t * sample - sigma_t * model_output 358 | else: 359 | raise ValueError( 360 | f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" 361 | " `v_prediction` for the DPMSolverMultistepScheduler." 362 | ) 363 | 364 | if self.config.thresholding: 365 | x0_pred = self._threshold_sample(x0_pred) 366 | 367 | return x0_pred 368 | 369 | # DPM-Solver needs to solve an integral of the noise prediction model. 370 | elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: 371 | if self.config.prediction_type == "epsilon": 372 | # DPM-Solver and DPM-Solver++ only need the "mean" output. 
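                # (for models that also predict the variance, only the leading
                #  channels holding the predicted mean are kept and the variance
                #  channels are discarded)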
373 | if self.config.variance_type in ["learned", "learned_range"]: 374 | epsilon = model_output[:, :3] 375 | else: 376 | epsilon = model_output 377 | elif self.config.prediction_type == "sample": 378 | alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] 379 | epsilon = (sample - alpha_t * model_output) / sigma_t 380 | elif self.config.prediction_type == "v_prediction": 381 | alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] 382 | epsilon = alpha_t * model_output + sigma_t * sample 383 | else: 384 | raise ValueError( 385 | f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" 386 | " `v_prediction` for the DPMSolverMultistepScheduler." 387 | ) 388 | 389 | if self.config.thresholding: 390 | alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] 391 | x0_pred = (sample - sigma_t * epsilon) / alpha_t 392 | x0_pred = self._threshold_sample(x0_pred) 393 | epsilon = (sample - alpha_t * x0_pred) / sigma_t 394 | 395 | return epsilon 396 | 397 | def dpm_solver_first_order_update( 398 | self, 399 | model_output: torch.FloatTensor, 400 | timestep: int, 401 | prev_timestep: int, 402 | sample: torch.FloatTensor, 403 | noise: Optional[torch.FloatTensor] = None, 404 | ) -> torch.FloatTensor: 405 | """ 406 | One step for the first-order DPM-Solver (equivalent to DDIM). 407 | 408 | See https://arxiv.org/abs/2206.00927 for the detailed derivation. 409 | 410 | Args: 411 | model_output (`torch.FloatTensor`): direct output from learned diffusion model. 412 | timestep (`int`): current discrete timestep in the diffusion chain. 413 | prev_timestep (`int`): previous discrete timestep in the diffusion chain. 414 | sample (`torch.FloatTensor`): 415 | current instance of sample being created by diffusion process. 416 | 417 | Returns: 418 | `torch.FloatTensor`: the sample tensor at the previous timestep. 419 | """ 420 | lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep] 421 | alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] 422 | sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep] 423 | h = lambda_t - lambda_s 424 | if self.config.algorithm_type == "dpmsolver++": 425 | 426 | x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output 427 | elif self.config.algorithm_type == "dpmsolver": 428 | 429 | x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output 430 | elif self.config.algorithm_type == "sde-dpmsolver++": 431 | assert noise is not None 432 | x_t = ( 433 | (sigma_t / sigma_s * torch.exp(-h)) * sample 434 | + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output 435 | + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise 436 | ) 437 | elif self.config.algorithm_type == "sde-dpmsolver": 438 | assert noise is not None 439 | x_t = ( 440 | (alpha_t / alpha_s) * sample 441 | - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output 442 | + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise 443 | ) 444 | return x_t 445 | 446 | def multistep_dpm_solver_second_order_update( 447 | self, 448 | model_output_list: List[torch.FloatTensor], 449 | timestep_list: List[int], 450 | prev_timestep: int, 451 | sample: torch.FloatTensor, 452 | noise: Optional[torch.FloatTensor] = None, 453 | ) -> torch.FloatTensor: 454 | """ 455 | One step for the second-order multistep DPM-Solver. 456 | 457 | Args: 458 | model_output_list (`List[torch.FloatTensor]`): 459 | direct outputs from learned diffusion model at current and latter timesteps. 
460 | timestep (`int`): current and latter discrete timestep in the diffusion chain. 461 | prev_timestep (`int`): previous discrete timestep in the diffusion chain. 462 | sample (`torch.FloatTensor`): 463 | current instance of sample being created by diffusion process. 464 | 465 | Returns: 466 | `torch.FloatTensor`: the sample tensor at the previous timestep. 467 | """ 468 | t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] 469 | m0, m1 = model_output_list[-1], model_output_list[-2] 470 | lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1] 471 | alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] 472 | sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] 473 | h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 474 | r0 = h_0 / h 475 | D0, D1 = m0, (1.0 / r0) * (m0 - m1) 476 | if self.config.algorithm_type == "dpmsolver++": 477 | # See https://arxiv.org/abs/2211.01095 for detailed derivations 478 | if self.config.solver_type == "midpoint": 479 | x_t = ( 480 | (sigma_t / sigma_s0) * sample 481 | - (alpha_t * (torch.exp(-h) - 1.0)) * D0 482 | - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 483 | ) 484 | elif self.config.solver_type == "heun": 485 | x_t = ( 486 | (sigma_t / sigma_s0) * sample 487 | - (alpha_t * (torch.exp(-h) - 1.0)) * D0 488 | + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 489 | ) 490 | elif self.config.algorithm_type == "dpmsolver": 491 | 492 | # See https://arxiv.org/abs/2206.00927 for detailed derivations 493 | if self.config.solver_type == "midpoint": 494 | x_t = ( 495 | (alpha_t / alpha_s0) * sample 496 | - (sigma_t * (torch.exp(h) - 1.0)) * D0 497 | - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 498 | ) 499 | elif self.config.solver_type == "heun": 500 | x_t = ( 501 | (alpha_t / alpha_s0) * sample 502 | - (sigma_t * (torch.exp(h) - 1.0)) * D0 503 | - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 504 | ) 505 | elif self.config.algorithm_type == "sde-dpmsolver++": 506 | assert noise is not None 507 | if self.config.solver_type == "midpoint": 508 | x_t = ( 509 | (sigma_t / sigma_s0 * torch.exp(-h)) * sample 510 | + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 511 | + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 512 | + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise 513 | ) 514 | elif self.config.solver_type == "heun": 515 | x_t = ( 516 | (sigma_t / sigma_s0 * torch.exp(-h)) * sample 517 | + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 518 | + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 519 | + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise 520 | ) 521 | elif self.config.algorithm_type == "sde-dpmsolver": 522 | assert noise is not None 523 | if self.config.solver_type == "midpoint": 524 | x_t = ( 525 | (alpha_t / alpha_s0) * sample 526 | - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 527 | - (sigma_t * (torch.exp(h) - 1.0)) * D1 528 | + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise 529 | ) 530 | elif self.config.solver_type == "heun": 531 | x_t = ( 532 | (alpha_t / alpha_s0) * sample 533 | - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 534 | - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 535 | + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise 536 | ) 537 | return x_t 538 | 539 | def multistep_dpm_solver_third_order_update( 540 | self, 541 | model_output_list: List[torch.FloatTensor], 542 | timestep_list: List[int], 543 | prev_timestep: int, 544 | sample: torch.FloatTensor, 545 | ) -> torch.FloatTensor: 546 | """ 547 | One step for the third-order multistep 
DPM-Solver. 548 | 549 | Args: 550 | model_output_list (`List[torch.FloatTensor]`): 551 | direct outputs from learned diffusion model at current and latter timesteps. 552 | timestep (`int`): current and latter discrete timestep in the diffusion chain. 553 | prev_timestep (`int`): previous discrete timestep in the diffusion chain. 554 | sample (`torch.FloatTensor`): 555 | current instance of sample being created by diffusion process. 556 | 557 | Returns: 558 | `torch.FloatTensor`: the sample tensor at the previous timestep. 559 | """ 560 | t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] 561 | m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] 562 | lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( 563 | self.lambda_t[t], 564 | self.lambda_t[s0], 565 | self.lambda_t[s1], 566 | self.lambda_t[s2], 567 | ) 568 | alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] 569 | sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] 570 | h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 571 | r0, r1 = h_0 / h, h_1 / h 572 | D0 = m0 573 | D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) 574 | D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) 575 | D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) 576 | if self.config.algorithm_type == "dpmsolver++": 577 | # See https://arxiv.org/abs/2206.00927 for detailed derivations 578 | x_t = ( 579 | (sigma_t / sigma_s0) * sample 580 | - (alpha_t * (torch.exp(-h) - 1.0)) * D0 581 | + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 582 | - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 583 | ) 584 | elif self.config.algorithm_type == "dpmsolver": 585 | # See https://arxiv.org/abs/2206.00927 for detailed derivations 586 | x_t = ( 587 | (alpha_t / alpha_s0) * sample 588 | - (sigma_t * (torch.exp(h) - 1.0)) * D0 589 | - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 590 | - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 591 | ) 592 | return x_t 593 | 594 | def step( 595 | self, 596 | model_output: torch.FloatTensor, 597 | timestep: int, 598 | sample: torch.FloatTensor, 599 | generator=None, 600 | return_dict: bool = True, 601 | variance_noise: Optional[torch.FloatTensor] = None, 602 | ) -> Union[SchedulerOutput, Tuple]: 603 | """ 604 | Step function propagating the sample with the multistep DPM-Solver. 605 | 606 | Args: 607 | model_output (`torch.FloatTensor`): direct output from learned diffusion model. 608 | timestep (`int`): current discrete timestep in the diffusion chain. 609 | sample (`torch.FloatTensor`): 610 | current instance of sample being created by diffusion process. 611 | return_dict (`bool`): option for returning tuple rather than SchedulerOutput class 612 | 613 | Returns: 614 | [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is 615 | True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. 
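
            If `variance_noise` is provided, it is used by the SDE solver variants in
            place of noise sampled from `generator`.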
616 | 617 | """ 618 | if self.num_inference_steps is None: 619 | raise ValueError( 620 | "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" 621 | ) 622 | 623 | if isinstance(timestep, torch.Tensor): 624 | timestep = timestep.to(self.timesteps.device) 625 | step_index = (self.timesteps == timestep).nonzero() 626 | if len(step_index) == 0: 627 | step_index = len(self.timesteps) - 1 628 | else: 629 | step_index = step_index.item() 630 | prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] 631 | lower_order_final = ( 632 | (step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 633 | ) 634 | lower_order_second = ( 635 | (step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 636 | ) 637 | 638 | model_output = self.convert_model_output(model_output, timestep, sample) 639 | for i in range(self.config.solver_order - 1): 640 | self.model_outputs[i] = self.model_outputs[i + 1] 641 | self.model_outputs[-1] = model_output 642 | 643 | if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None: 644 | noise = randn_tensor( 645 | model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype 646 | ) 647 | elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: 648 | noise = variance_noise 649 | else: 650 | noise = None 651 | 652 | if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: 653 | prev_sample = self.dpm_solver_first_order_update( 654 | model_output, timestep, prev_timestep, sample, noise=noise 655 | ) 656 | elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: 657 | timestep_list = [self.timesteps[step_index - 1], timestep] 658 | prev_sample = self.multistep_dpm_solver_second_order_update( 659 | self.model_outputs, timestep_list, prev_timestep, sample, noise=noise 660 | ) 661 | else: 662 | raise NotImplementedError() 663 | 664 | if self.lower_order_nums < self.config.solver_order: 665 | self.lower_order_nums += 1 666 | 667 | if not return_dict: 668 | return (prev_sample,) 669 | 670 | return SchedulerOutput(prev_sample=prev_sample) 671 | 672 | def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: 673 | """ 674 | Ensures interchangeability with schedulers that need to scale the denoising model input depending on the 675 | current timestep. 
676 | 677 | Args: 678 | sample (`torch.FloatTensor`): input sample 679 | 680 | Returns: 681 | `torch.FloatTensor`: scaled input sample 682 | """ 683 | return sample 684 | 685 | # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise 686 | def add_noise( 687 | self, 688 | original_samples: torch.FloatTensor, 689 | noise: torch.FloatTensor, 690 | timesteps: torch.IntTensor, 691 | ) -> torch.FloatTensor: 692 | # Make sure alphas_cumprod and timestep have same device and dtype as original_samples 693 | alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) 694 | timesteps = timesteps.to(original_samples.device) 695 | 696 | sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 697 | sqrt_alpha_prod = sqrt_alpha_prod.flatten() 698 | while len(sqrt_alpha_prod.shape) < len(original_samples.shape): 699 | sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) 700 | 701 | sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 702 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() 703 | while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): 704 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) 705 | 706 | noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise 707 | return noisy_samples 708 | 709 | def __len__(self): 710 | return self.config.num_train_timesteps 711 | --------------------------------------------------------------------------------
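A minimal, self-contained sketch (not part of the repository) of how the injected scheduler's API fits together: it instantiates `DPMSolverMultistepSchedulerInject` with the second-order SDE DPM-Solver++ settings that the inversion helper `compute_noise` above expects and steps it manually, passing a pre-computed `variance_noise` at every step. It assumes the package has been installed via `setup.py` so that `leditspp` is importable; the step count, latent shape, and random tensors are illustrative stand-ins for real U-Net predictions and inverted noise maps.

import torch
from leditspp.scheduling_dpmsolver_multistep_inject import DPMSolverMultistepSchedulerInject

# second-order SDE DPM-Solver++ (the configuration compute_noise above checks for)
scheduler = DPMSolverMultistepSchedulerInject(algorithm_type="sde-dpmsolver++", solver_order=2)
scheduler.set_timesteps(num_inference_steps=10, device="cpu")

sample = torch.randn(1, 4, 64, 64)  # stand-in for a latent x_T
for t in scheduler.timesteps:
    model_output = torch.randn_like(sample)    # stand-in for unet(sample, t, ...).sample
    injected_noise = torch.randn_like(sample)  # stand-in for a noise map z_t (zs[idx] in the pipelines)
    sample = scheduler.step(model_output, t, sample, variance_noise=injected_noise).prev_sample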