├── .gitignore ├── LICENSE ├── README.md ├── examples ├── SemanticGuidance.ipynb ├── TheStableArtist.ipynb └── teaser.png ├── requirements.txt ├── setup.py └── src └── semdiffusers ├── __init__.py └── pipeline_latent_edit_diffusion.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 ml-research@TUDarmstadt 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Semantic Guidance for Diffusion 2 | 3 | Official Implementation of the [Paper](https://arxiv.org/abs/2301.12247) **SEGA: Instructing Diffusion using Semantic Dimensions**. 4 | 5 | You may find the implementation of the previous [pre-print](http://arxiv.org/abs/2212.06013) **The Stable Artist: Interacting with Concepts in Diffusion Latent Space** under the tag [StableArtist](https://github.com/ml-research/semantic-image-editing/tree/StableArtist). 6 | 7 | ## Interactive Demo 8 | An interactive demonstration is available in Colab and on Huggingface [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/semantic-image-editing/blob/main/examples/SemanticGuidance.ipynb) [![Huggingface Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/AIML-TUDA/semantic-diffusion) 9 | 10 | ![Examples](./examples/teaser.png) 11 | 12 | ## Installation 13 | SEGA is fully integrated in the ```diffusers``` library as ```SemanticStableDiffusionPipeline```. Just install diffusers to use it: 14 | 15 | ```cmd 16 | pip install diffusers 17 | ``` 18 | 19 | Alternatively you can clone this repository and install it locally by running 20 | 21 | ```cmd 22 | git clone https://github.com/ml-research/semantic-image-editing.git 23 | cd ./semantic-image-editing 24 | pip install . 
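# optionally, for local development, pip's editable mode can be used instead:
# pip install -e .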
25 | ```
26 | or install it directly from git
27 | ```cmd
28 | pip install git+https://github.com/ml-research/semantic-image-editing.git
29 | ```
30 | 
31 | ## Usage
32 | This repository provides a new diffusion pipeline supporting semantic image editing based on the [diffusers](https://github.com/huggingface/diffusers) library.
33 | The ```SemanticEditPipeline``` extends the ```StableDiffusionPipeline``` and can therefore be loaded from a Stable Diffusion checkpoint as shown below.
34 | 
35 | 
36 | ```python
37 | from semdiffusers import SemanticEditPipeline
38 | device = 'cuda'
39 | 
40 | pipe = SemanticEditPipeline.from_pretrained(
41 |     "runwayml/stable-diffusion-v1-5",
42 | ).to(device)
43 | ```
44 | or load the corresponding pipeline in diffusers:
45 | 
46 | ```python
47 | from diffusers import SemanticStableDiffusionPipeline
48 | device = 'cuda'
49 | pipe = SemanticStableDiffusionPipeline.from_pretrained(
50 |     "runwayml/stable-diffusion-v1-5",
51 | ).to(device)
52 | ```
53 | 
54 | A typical call to the pipeline looks like this:
55 | ```python
56 | import torch
57 | gen = torch.Generator(device=device)
58 | 
59 | gen.manual_seed(21)
60 | out = pipe(prompt='a photo of the face of a woman', generator=gen, num_images_per_prompt=1, guidance_scale=7,
61 |            editing_prompt=['smiling, smile',  # Concepts to apply
62 |                            'glasses, wearing glasses',
63 |                            'curls, wavy hair, curly hair',
64 |                            'beard, full beard, mustache'],
65 |            reverse_editing_direction=[False, False, False, False],  # Direction of guidance, i.e. increase all concepts
66 |            edit_warmup_steps=[10, 10, 10, 10],  # Warmup period for each concept
67 |            edit_guidance_scale=[4, 5, 5, 5.4],  # Guidance scale for each concept
68 |            edit_threshold=[0.99, 0.975, 0.925, 0.96],  # Threshold for each concept, i.e. the percentile of the latent space that is discarded; threshold=0.99 uses only 1% of the latent dimensions
69 |            edit_momentum_scale=0.3,  # Momentum scale that will be added to the latent guidance
70 |            edit_mom_beta=0.6,  # Momentum beta
71 |            edit_weights=[1, 1, 1, 1]  # Weights of the individual concepts against each other
72 |            )
73 | images = out.images
74 | 
75 | ```
76 | 
77 | ## Citation
78 | If you like or use our work, please cite us:
79 | ```bibtex
80 | @article{brack2023Sega,
81 |   title={SEGA: Instructing Diffusion using Semantic Dimensions},
82 |   author={Manuel Brack and Felix Friedrich and Dominik Hintersdorf and Lukas Struppek and Patrick Schramowski and Kristian Kersting},
83 |   year={2023},
84 |   journal={NeurIPS}
85 | }
86 | ```
87 | 
88 | 
--------------------------------------------------------------------------------
/examples/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-research/semantic-image-editing/d284e44077cc6d94389693e1db6b1e5495f9e0b8/examples/teaser.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | diffusers
2 | Pillow<10.0
3 | accelerate>=0.11.0
4 | torch>=1.4
5 | torchvision
6 | transformers>=4.21.0
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import pkg_resources
4 | from setuptools import setup, find_packages
5 | 
6 | setup(
7 |     name="semdiffusers",
8 |     version="1.0.0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
9 |     description="Semantic Image Editing",
10 |     long_description=open("README.md", "r", encoding="utf-8").read(),
11 |     long_description_content_type="text/markdown",
12 |     keywords="deep learning",
13 |     license="MIT",
14 |     author="Manuel Brack",
15 |     author_email="brac@cs.tu-darmstadt.de",
16 |     url="https://github.com/ml-research/semantic-image-editing",
17 |     package_dir={"": "src"},
18 |     packages=find_packages("src"),
19 |     install_requires=[
20 |         str(r)
21 |         for r in pkg_resources.parse_requirements(
22 |             open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
23 |         )
24 |     ],
25 |     include_package_data=True,
26 | )
27 | 
--------------------------------------------------------------------------------
/src/semdiffusers/__init__.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List, Optional, Union
3 | 
4 | import numpy as np
5 | 
6 | import PIL
7 | from PIL import Image
8 | 
9 | from diffusers.utils import BaseOutput, is_torch_available, is_transformers_available
10 | 
11 | 
12 | @dataclass
13 | class SemanticEditPipelineOutput(BaseOutput):
14 |     """
15 |     Output class for the latent editing pipeline.
16 | 
17 |     Args:
18 |         images (`List[PIL.Image.Image]` or `np.ndarray`)
19 |             List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
20 |             num_channels)`. PIL images or numpy array represent the denoised images of the diffusion pipeline.
21 |         inappropriate_content_detected (`List[bool]`)
22 |             List of flags denoting whether the corresponding generated image likely represents inappropriate content,
23 |             or `None` if safety checking could not be performed.
24 | """ 25 | 26 | images: Union[List[PIL.Image.Image], np.ndarray] 27 | inappropriate_content_detected: Optional[List[bool]] 28 | 29 | 30 | if is_transformers_available() and is_torch_available(): 31 | from .pipeline_latent_edit_diffusion import SemanticEditPipeline 32 | -------------------------------------------------------------------------------- /src/semdiffusers/pipeline_latent_edit_diffusion.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import warnings 3 | from itertools import repeat 4 | from typing import Callable, List, Optional, Union 5 | 6 | import torch 7 | 8 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer 9 | 10 | from diffusers.configuration_utils import FrozenDict 11 | from diffusers.models import AutoencoderKL, UNet2DConditionModel 12 | from diffusers.pipeline_utils import DiffusionPipeline 13 | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker 14 | from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler 15 | from diffusers.utils import deprecate, logging 16 | from . import SemanticEditPipelineOutput 17 | 18 | 19 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 20 | 21 | 22 | class SemanticEditPipeline(DiffusionPipeline): 23 | r""" 24 | Pipeline for text-to-image generation with latent editing. 25 | 26 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the 27 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 28 | 29 | This model builds on the implementation of ['StableDiffusionPipeline'] 30 | 31 | Args: 32 | vae ([`AutoencoderKL`]): 33 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 34 | text_encoder ([`CLIPTextModel`]): 35 | Frozen text-encoder. Stable Diffusion uses the text portion of 36 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically 37 | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 38 | tokenizer (`CLIPTokenizer`): 39 | Tokenizer of class 40 | [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). 41 | unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. 42 | scheduler ([`SchedulerMixin`]): 43 | A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of 44 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 45 | safety_checker ([`Q16SafetyChecker`]): 46 | Classification module that estimates whether generated images could be considered offensive or harmful. 47 | Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. 48 | feature_extractor ([`CLIPFeatureExtractor`]): 49 | Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
50 | """ 51 | 52 | def __init__( 53 | self, 54 | vae: AutoencoderKL, 55 | text_encoder: CLIPTextModel, 56 | tokenizer: CLIPTokenizer, 57 | unet: UNet2DConditionModel, 58 | scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], 59 | safety_checker: StableDiffusionSafetyChecker, 60 | feature_extractor: CLIPFeatureExtractor, 61 | ): 62 | super().__init__() 63 | 64 | if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: 65 | deprecation_message = ( 66 | f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" 67 | f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " 68 | "to update the config accordingly as leaving `steps_offset` might led to incorrect results" 69 | " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," 70 | " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" 71 | " file" 72 | ) 73 | deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) 74 | new_config = dict(scheduler.config) 75 | new_config["steps_offset"] = 1 76 | scheduler._internal_dict = FrozenDict(new_config) 77 | 78 | if safety_checker is None: 79 | warnings.warn( 80 | f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" 81 | " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" 82 | " results in services or applications open to the public. Both the diffusers team and Hugging Face" 83 | " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" 84 | " it only for use-cases that involve analyzing network behavior or auditing its results. For more" 85 | " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." 86 | ) 87 | 88 | self.register_modules( 89 | vae=vae, 90 | text_encoder=text_encoder, 91 | tokenizer=tokenizer, 92 | unet=unet, 93 | scheduler=scheduler, 94 | safety_checker=safety_checker, 95 | feature_extractor=feature_extractor, 96 | ) 97 | 98 | def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): 99 | r""" 100 | Enable sliced attention computation. 101 | 102 | When this option is enabled, the attention module will split the input tensor in slices, to compute attention 103 | in several steps. This is useful to save some memory in exchange for a small speed decrease. 104 | 105 | Args: 106 | slice_size (`str` or `int`, *optional*, defaults to `"auto"`): 107 | When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If 108 | a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, 109 | `attention_head_dim` must be a multiple of `slice_size`. 110 | """ 111 | if slice_size == "auto": 112 | # half the attention head size is usually a good trade-off between 113 | # speed and memory 114 | slice_size = self.unet.config.attention_head_dim // 2 115 | self.unet.set_attention_slice(slice_size) 116 | 117 | def disable_attention_slicing(self): 118 | r""" 119 | Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go 120 | back to computing attention in one step. 
121 | """ 122 | # set slice_size = `None` to disable `attention slicing` 123 | self.enable_attention_slicing(None) 124 | 125 | @torch.no_grad() 126 | def __call__( 127 | self, 128 | prompt: Union[str, List[str]], 129 | height: int = 512, 130 | width: int = 512, 131 | num_inference_steps: int = 50, 132 | guidance_scale: float = 7.5, 133 | negative_prompt: Optional[Union[str, List[str]]] = None, 134 | num_images_per_prompt: Optional[int] = 1, 135 | eta: float = 0.0, 136 | generator: Optional[torch.Generator] = None, 137 | latents: Optional[torch.FloatTensor] = None, 138 | output_type: Optional[str] = "pil", 139 | return_dict: bool = True, 140 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 141 | callback_steps: Optional[int] = 1, 142 | editing_prompt: Optional[Union[str, List[str]]] = None, 143 | editing_prompt_prompt_embeddings=None, 144 | reverse_editing_direction: Optional[Union[bool, List[bool]]] = False, 145 | edit_guidance_scale: Optional[Union[float, List[float]]] = 500, 146 | edit_warmup_steps: Optional[Union[int, List[int]]] = 10, 147 | edit_cooldown_steps: Optional[Union[int, List[int]]] = None, 148 | edit_threshold: Optional[Union[float, List[float]]] = None, 149 | edit_momentum_scale: Optional[float] = 0.1, 150 | edit_mom_beta: Optional[float] = 0.4, 151 | edit_weights: Optional[List[float]] = None, 152 | sem_guidance = None, 153 | **kwargs, 154 | ): 155 | r""" 156 | Function invoked when calling the pipeline for generation. 157 | 158 | Args: 159 | prompt (`str` or `List[str]`): 160 | The prompt or prompts to guide the image generation. 161 | height (`int`, *optional*, defaults to 512): 162 | The height in pixels of the generated image. 163 | width (`int`, *optional*, defaults to 512): 164 | The width in pixels of the generated image. 165 | num_inference_steps (`int`, *optional*, defaults to 50): 166 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 167 | expense of slower inference. 168 | guidance_scale (`float`, *optional*, defaults to 7.5): 169 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 170 | `guidance_scale` is defined as `w` of equation 2. of [Imagen 171 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 172 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 173 | usually at the expense of lower image quality. 174 | negative_prompt (`str` or `List[str]`, *optional*): 175 | The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored 176 | if `guidance_scale` is less than `1`). 177 | num_images_per_prompt (`int`, *optional*, defaults to 1): 178 | The number of images to generate per prompt. 179 | eta (`float`, *optional*, defaults to 0.0): 180 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to 181 | [`schedulers.DDIMScheduler`], will be ignored for others. 182 | generator (`torch.Generator`, *optional*): 183 | A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation 184 | deterministic. 185 | latents (`torch.FloatTensor`, *optional*): 186 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image 187 | generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents
188 |                 tensor will be generated by sampling using the supplied random `generator`.
189 |             output_type (`str`, *optional*, defaults to `"pil"`):
190 |                 The output format of the generated image. Choose between
191 |                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
192 |             return_dict (`bool`, *optional*, defaults to `True`):
193 |                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
194 |                 plain tuple.
195 |             callback (`Callable`, *optional*):
196 |                 A function that will be called every `callback_steps` steps during inference. The function will be
197 |                 called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
198 |             callback_steps (`int`, *optional*, defaults to 1):
199 |                 The frequency at which the `callback` function will be called. If not specified, the callback will be
200 |                 called at every step.
201 |             editing_prompt (`str` or `List[str]`, *optional*):
202 |                 The prompt or prompts to use for semantic guidance. Semantic guidance is disabled by setting
203 |                 `editing_prompt = None`. The guidance direction of each prompt should be specified via
204 |                 `reverse_editing_direction`.
205 |             reverse_editing_direction (`bool` or `List[bool]`, *optional*):
206 |                 Whether the corresponding prompt in `editing_prompt` should be increased or decreased.
207 |             edit_guidance_scale (`float` or `List[float]`, *optional*, defaults to 500):
208 |                 Guidance scale for semantic guidance. If provided as a list, values should correspond to `editing_prompt`.
209 |             edit_warmup_steps (`int` or `List[int]`, *optional*, defaults to 10):
210 |                 Number of diffusion steps (for each prompt) for which semantic guidance will not be applied. Momentum
211 |                 will still be calculated for those steps and applied once all warmup periods are over.
212 |             edit_cooldown_steps (`int` or `List[int]`, *optional*, defaults to `None`):
213 |                 Number of diffusion steps (for each prompt) after which semantic guidance will no longer be applied. If `None`, guidance is applied until the end of generation.
214 |             edit_threshold (`float` or `List[float]`, *optional*, defaults to `None`):
215 |                 Threshold of semantic guidance. Guidance is only applied to the latent dimensions whose absolute edit guidance exceeds this quantile, i.e. larger values restrict the edit to fewer dimensions.
216 |             edit_momentum_scale (`float`, *optional*, defaults to 0.1):
217 |                 Scale of the momentum to be added to the semantic guidance at each diffusion step. If set to 0.0, momentum
218 |                 will be disabled. Momentum is already built up during warmup, i.e. for diffusion steps smaller than
219 |                 `edit_warmup_steps`. Momentum will only be added to the latent guidance once all warmup periods are
220 |                 finished.
221 |             edit_mom_beta (`float`, *optional*, defaults to 0.4):
222 |                 Defines how semantic guidance momentum builds up. `edit_mom_beta` indicates how much of the previous
223 |                 momentum will be kept. Momentum is already built up during warmup, i.e. for diffusion steps smaller
224 |                 than `edit_warmup_steps`.
225 |             edit_weights (`List[float]`, *optional*, defaults to `None`):
226 |                 Indicates how much each individual concept should influence the overall guidance. If no weights are
227 |                 provided, all concepts are applied equally.
228 | 
229 |         Returns:
230 |             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
231 |             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is `True`, otherwise a `tuple`.
232 | When returning a tuple, the first element is a list with the generated images, and the second element is a 233 | list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" 234 | (nsfw) content, according to the `safety_checker`. 235 | """ 236 | 237 | if isinstance(prompt, str): 238 | batch_size = 1 239 | elif isinstance(prompt, list): 240 | batch_size = len(prompt) 241 | else: 242 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 243 | 244 | if height % 8 != 0 or width % 8 != 0: 245 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 246 | 247 | if (callback_steps is None) or ( 248 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) 249 | ): 250 | raise ValueError( 251 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" 252 | f" {type(callback_steps)}." 253 | ) 254 | 255 | if editing_prompt: 256 | enable_edit_guidance = True 257 | if isinstance(editing_prompt, str): 258 | editing_prompt = [editing_prompt] 259 | enabled_editing_prompts = len(editing_prompt) 260 | elif editing_prompt_prompt_embeddings is not None: 261 | enable_edit_guidance = True 262 | enabled_editing_prompts = editing_prompt_prompt_embeddings.shape[0] 263 | else: 264 | enabled_editing_prompts = 0 265 | enable_edit_guidance = False 266 | 267 | 268 | # get prompt text embeddings 269 | text_inputs = self.tokenizer( 270 | prompt, 271 | padding="max_length", 272 | max_length=self.tokenizer.model_max_length, 273 | return_tensors="pt", 274 | ) 275 | text_input_ids = text_inputs.input_ids 276 | 277 | if text_input_ids.shape[-1] > self.tokenizer.model_max_length: 278 | removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) 279 | logger.warning( 280 | "The following part of your input was truncated because CLIP can only handle sequences up to" 281 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 282 | ) 283 | text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] 284 | text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] 285 | 286 | 287 | # duplicate text embeddings for each generation per prompt, using mps friendly method 288 | bs_embed, seq_len, _ = text_embeddings.shape 289 | text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) 290 | text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) 291 | 292 | if enable_edit_guidance: 293 | # get safety text embeddings 294 | if editing_prompt_prompt_embeddings is None: 295 | edit_concepts_input = self.tokenizer( 296 | [x for item in editing_prompt for x in repeat(item, batch_size)], 297 | padding="max_length", 298 | max_length=self.tokenizer.model_max_length, 299 | return_tensors="pt", 300 | ) 301 | 302 | edit_concepts_input_ids = edit_concepts_input.input_ids 303 | 304 | if edit_concepts_input_ids.shape[-1] > self.tokenizer.model_max_length: 305 | removed_text = self.tokenizer.batch_decode( 306 | edit_concepts_input_ids[:, self.tokenizer.model_max_length :] 307 | ) 308 | logger.warning( 309 | "The following part of your input was truncated because CLIP can only handle sequences up to" 310 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 311 | ) 312 | edit_concepts_input_ids = edit_concepts_input_ids[:, : self.tokenizer.model_max_length] 313 | edit_concepts = self.text_encoder(edit_concepts_input_ids.to(self.device))[0] 314 | else: 315 
| edit_concepts = editing_prompt_prompt_embeddings.to(self.device).repeat(batch_size, 1, 1) 316 | 317 | # duplicate text embeddings for each generation per prompt, using mps friendly method 318 | bs_embed_edit, seq_len_edit, _ = edit_concepts.shape 319 | edit_concepts = edit_concepts.repeat(1, num_images_per_prompt, 1) 320 | edit_concepts = edit_concepts.view(bs_embed_edit * num_images_per_prompt, seq_len_edit, -1) 321 | 322 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 323 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 324 | # corresponds to doing no classifier free guidance. 325 | do_classifier_free_guidance = guidance_scale > 1.0 326 | # get unconditional embeddings for classifier free guidance 327 | 328 | 329 | if do_classifier_free_guidance: 330 | uncond_tokens: List[str] 331 | if negative_prompt is None: 332 | uncond_tokens = [""] 333 | elif type(prompt) is not type(negative_prompt): 334 | raise TypeError( 335 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 336 | f" {type(prompt)}." 337 | ) 338 | elif isinstance(negative_prompt, str): 339 | uncond_tokens = [negative_prompt] 340 | elif batch_size != len(negative_prompt): 341 | raise ValueError( 342 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 343 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 344 | " the batch size of `prompt`." 345 | ) 346 | else: 347 | uncond_tokens = negative_prompt 348 | 349 | max_length = text_input_ids.shape[-1] 350 | uncond_input = self.tokenizer( 351 | uncond_tokens, 352 | padding="max_length", 353 | max_length=max_length, 354 | truncation=True, 355 | return_tensors="pt", 356 | ) 357 | uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] 358 | 359 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 360 | seq_len = uncond_embeddings.shape[1] 361 | uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1) 362 | uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) 363 | 364 | # For classifier free guidance, we need to do two forward passes. 365 | # Here we concatenate the unconditional and text embeddings into a single batch 366 | # to avoid doing two forward passes 367 | if enable_edit_guidance: 368 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings, edit_concepts]) 369 | else: 370 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 371 | # get the initial random noise unless the user supplied it 372 | 373 | # Unlike in other pipelines, latents need to be generated in the target device 374 | # for 1-to-1 results reproducibility with the CompVis implementation. 375 | # However this currently doesn't work in `mps`. 
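        # Note: latents are sampled at 1/8 of the requested pixel resolution (the VAE downsampling factor)
        # with `unet.in_channels` channels; the batch dimension covers `batch_size * num_images_per_prompt` samples.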
376 | latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8) 377 | latents_dtype = text_embeddings.dtype 378 | if latents is None: 379 | if self.device.type == "mps": 380 | # randn does not exist on mps 381 | latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to( 382 | self.device 383 | ) 384 | else: 385 | latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) 386 | else: 387 | if latents.shape != latents_shape: 388 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") 389 | latents = latents.to(self.device) 390 | 391 | # set timesteps 392 | self.scheduler.set_timesteps(num_inference_steps) 393 | 394 | # Some schedulers like PNDM have timesteps as arrays 395 | # It's more optimized to move all timesteps to correct device beforehand 396 | timesteps_tensor = self.scheduler.timesteps.to(self.device) 397 | 398 | # scale the initial noise by the standard deviation required by the scheduler 399 | latents = latents * self.scheduler.init_noise_sigma 400 | 401 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 402 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 403 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 404 | # and should be between [0, 1] 405 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) 406 | extra_step_kwargs = {} 407 | if accepts_eta: 408 | extra_step_kwargs["eta"] = eta 409 | 410 | # Initialize edit_momentum to None 411 | edit_momentum = None 412 | 413 | self.uncond_estimates = None 414 | self.text_estimates = None 415 | self.edit_estimates = None 416 | self.sem_guidance = None 417 | 418 | for i, t in enumerate(self.progress_bar(timesteps_tensor)): 419 | # expand the latents if we are doing classifier free guidance 420 | latent_model_input = ( 421 | torch.cat([latents] * (2 + enabled_editing_prompts)) if do_classifier_free_guidance else latents 422 | ) 423 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 424 | 425 | # predict the noise residual 426 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample 427 | 428 | # perform guidance 429 | if do_classifier_free_guidance: 430 | noise_pred_out = noise_pred.chunk(2 + enabled_editing_prompts) # [b,4, 64, 64] 431 | noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1] 432 | noise_pred_edit_concepts = noise_pred_out[2:] 433 | 434 | # default text guidance 435 | noise_guidance = guidance_scale * (noise_pred_text - noise_pred_uncond) 436 | # noise_guidance = (noise_pred_text - noise_pred_edit_concepts[0]) 437 | 438 | if self.uncond_estimates is None: 439 | self.uncond_estimates = torch.zeros((num_inference_steps+1, *noise_pred_uncond.shape)) 440 | self.uncond_estimates[i] = noise_pred_uncond.detach().cpu() 441 | 442 | if self.text_estimates is None: 443 | self.text_estimates = torch.zeros((num_inference_steps+1, *noise_pred_text.shape)) 444 | self.text_estimates[i] = noise_pred_text.detach().cpu() 445 | 446 | if self.edit_estimates is None and enable_edit_guidance: 447 | self.edit_estimates = torch.zeros((num_inference_steps+1, len(noise_pred_edit_concepts), *noise_pred_edit_concepts[0].shape)) 448 | 449 | if self.sem_guidance is None: 450 | self.sem_guidance = torch.zeros((num_inference_steps + 1, *noise_pred_text.shape)) 451 | 452 | 
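        # Edit momentum is an exponential moving average of the semantic guidance term (controlled by
        # `edit_mom_beta` below); it is initialized lazily on the first step so its shape matches the
        # per-step noise guidance.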
if edit_momentum is None: 453 | edit_momentum = torch.zeros_like(noise_guidance) 454 | 455 | if enable_edit_guidance: 456 | 457 | concept_weights = torch.zeros( 458 | (len(noise_pred_edit_concepts), noise_guidance.shape[0]), device=self.device 459 | ) 460 | noise_guidance_edit = torch.zeros( 461 | (len(noise_pred_edit_concepts), *noise_guidance.shape), device=self.device 462 | ) 463 | # noise_guidance_edit = torch.zeros_like(noise_guidance) 464 | warmup_inds = [] 465 | for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts): 466 | self.edit_estimates[i, c] = noise_pred_edit_concept 467 | if isinstance(edit_guidance_scale, list): 468 | edit_guidance_scale_c = edit_guidance_scale[c] 469 | else: 470 | edit_guidance_scale_c = edit_guidance_scale 471 | 472 | if isinstance(edit_threshold, list): 473 | edit_threshold_c = edit_threshold[c] 474 | else: 475 | edit_threshold_c = edit_threshold 476 | if isinstance(reverse_editing_direction, list): 477 | reverse_editing_direction_c = reverse_editing_direction[c] 478 | else: 479 | reverse_editing_direction_c = reverse_editing_direction 480 | if edit_weights: 481 | edit_weight_c = edit_weights[c] 482 | else: 483 | edit_weight_c = 1.0 484 | if isinstance(edit_warmup_steps, list): 485 | edit_warmup_steps_c = edit_warmup_steps[c] 486 | else: 487 | edit_warmup_steps_c = edit_warmup_steps 488 | 489 | if isinstance(edit_cooldown_steps, list): 490 | edit_cooldown_steps_c = edit_cooldown_steps[c] 491 | elif edit_cooldown_steps is None: 492 | edit_cooldown_steps_c = i + 1 493 | else: 494 | edit_cooldown_steps_c = edit_cooldown_steps 495 | if i >= edit_warmup_steps_c: 496 | warmup_inds.append(c) 497 | if i >= edit_cooldown_steps_c: 498 | noise_guidance_edit[c, :, :, :, :] = torch.zeros_like(noise_pred_edit_concept) 499 | continue 500 | 501 | noise_guidance_edit_tmp = noise_pred_edit_concept - noise_pred_uncond 502 | # tmp_weights = (noise_pred_text - noise_pred_edit_concept).sum(dim=(1, 2, 3)) 503 | tmp_weights = (noise_guidance - noise_pred_edit_concept).sum(dim=(1, 2, 3)) 504 | 505 | tmp_weights = torch.full_like(tmp_weights, edit_weight_c) #* (1 / enabled_editing_prompts) 506 | if reverse_editing_direction_c: 507 | noise_guidance_edit_tmp = noise_guidance_edit_tmp * -1 508 | concept_weights[c, :] = tmp_weights 509 | 510 | noise_guidance_edit_tmp = noise_guidance_edit_tmp * edit_guidance_scale_c 511 | tmp = torch.quantile(torch.abs(noise_guidance_edit_tmp).flatten(start_dim=2), edit_threshold_c, dim=2, keepdim=False) 512 | noise_guidance_edit_tmp = torch.where( 513 | torch.abs(noise_guidance_edit_tmp) >= tmp[:, :, None, None] 514 | , noise_guidance_edit_tmp 515 | , torch.zeros_like(noise_guidance_edit_tmp) 516 | ) 517 | noise_guidance_edit[c, :, :, :, :] = noise_guidance_edit_tmp 518 | 519 | 520 | # noise_guidance_edit = noise_guidance_edit + noise_guidance_edit_tmp 521 | 522 | warmup_inds = torch.tensor(warmup_inds).to(self.device) 523 | if len(noise_pred_edit_concepts) > warmup_inds.shape[0] > 0: 524 | concept_weights = concept_weights.to("cpu") # Offload to cpu 525 | noise_guidance_edit = noise_guidance_edit.to("cpu") 526 | 527 | concept_weights_tmp = torch.index_select(concept_weights.to(self.device), 0, warmup_inds) 528 | concept_weights_tmp = torch.where( 529 | concept_weights_tmp < 0, torch.zeros_like(concept_weights_tmp), concept_weights_tmp 530 | ) 531 | concept_weights_tmp = concept_weights_tmp / concept_weights_tmp.sum(dim=0) 532 | # concept_weights_tmp = torch.nan_to_num(concept_weights_tmp) 533 | 534 | noise_guidance_edit_tmp = 
torch.index_select( 535 | noise_guidance_edit.to(self.device), 0, warmup_inds 536 | ) 537 | noise_guidance_edit_tmp = torch.einsum( 538 | "cb,cbijk->bijk", concept_weights_tmp, noise_guidance_edit_tmp 539 | ) 540 | noise_guidance_edit_tmp = noise_guidance_edit_tmp 541 | noise_guidance = noise_guidance + noise_guidance_edit_tmp 542 | 543 | self.sem_guidance[i] = noise_guidance_edit_tmp.detach().cpu() 544 | 545 | del noise_guidance_edit_tmp 546 | del concept_weights_tmp 547 | concept_weights = concept_weights.to(self.device) 548 | noise_guidance_edit = noise_guidance_edit.to(self.device) 549 | 550 | concept_weights = torch.where( 551 | concept_weights < 0, torch.zeros_like(concept_weights), concept_weights 552 | ) 553 | 554 | concept_weights = torch.nan_to_num(concept_weights) 555 | noise_guidance_edit = torch.einsum("cb,cbijk->bijk", concept_weights, noise_guidance_edit) 556 | 557 | noise_guidance_edit = noise_guidance_edit + edit_momentum_scale * edit_momentum 558 | 559 | edit_momentum = edit_mom_beta * edit_momentum + (1 - edit_mom_beta) * noise_guidance_edit 560 | 561 | if warmup_inds.shape[0] == len(noise_pred_edit_concepts): 562 | noise_guidance = noise_guidance + noise_guidance_edit 563 | self.sem_guidance[i] = noise_guidance_edit.detach().cpu() 564 | 565 | if sem_guidance is not None: 566 | edit_guidance = sem_guidance[i].to(self.device) 567 | noise_guidance = noise_guidance + edit_guidance 568 | 569 | noise_pred = noise_pred_uncond + noise_guidance 570 | 571 | # compute the previous noisy sample x_t -> x_t-1 572 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample 573 | 574 | # call the callback, if provided 575 | if callback is not None and i % callback_steps == 0: 576 | callback(i, t, latents) 577 | 578 | latents = 1 / 0.18215 * latents 579 | image = self.vae.decode(latents).sample 580 | 581 | image = (image / 2 + 0.5).clamp(0, 1) 582 | 583 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 584 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() 585 | 586 | if self.safety_checker is not None: 587 | safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to( 588 | self.device 589 | ) 590 | image, has_nsfw_concept = self.safety_checker( 591 | images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype) 592 | ) 593 | else: 594 | has_nsfw_concept = None 595 | 596 | if output_type == "pil": 597 | image = self.numpy_to_pil(image) 598 | 599 | if not return_dict: 600 | return (image, has_nsfw_concept) 601 | 602 | return SemanticEditPipelineOutput(images=image, inappropriate_content_detected=has_nsfw_concept) 603 | --------------------------------------------------------------------------------