├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── images ├── image.png ├── inpainting.png ├── inpainting_process.gif └── mask.png ├── pipelines ├── __init__.py └── pipeline_stable_diffusion_img2img_simple_inpaint.py ├── pyproject.toml ├── requirements.txt └── run_simple_inpainting.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | .idea/ 161 | 162 | results/ 163 | .DS_Store 164 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | exclude: _pb2\.py$ 2 | repos: 3 | - repo: https://github.com/psf/black 4 | rev: 22.3.0 5 | hooks: 6 | - id: black 7 | args: [ ] 8 | - repo: https://github.com/pre-commit/pre-commit-hooks 9 | rev: v4.0.1 10 | hooks: 11 | - id: check-docstring-first 12 | - id: check-json 13 | - id: check-merge-conflict 14 | - id: debug-statements 15 | - id: end-of-file-fixer 16 | - id: trailing-whitespace 17 | - id: requirements-txt-fixer 18 | - repo: https://github.com/pre-commit/pygrep-hooks 19 | rev: v1.9.0 20 | hooks: 21 | - id: python-check-mock-methods 22 | - id: python-use-type-annotations 23 | - repo: https://github.com/pre-commit/mirrors-mypy 24 | rev: 'v0.991' 25 | hooks: 26 | - id: mypy 27 | args: [--ignore-missing-imports, --warn-no-return, --warn-redundant-casts, --disallow-incomplete-defs] 28 | additional_dependencies: [types-all] 29 | - repo: https://github.com/PyCQA/isort 30 | rev: 5.12.0 31 | hooks: 32 | - id: isort 33 | args: [ --profile, black, --filter-files ] 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Vadim Titko 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Diffusers Inpainting 2 | 3 | ![License](https://img.shields.io/github/license/Vadbeg/diffusers-inpainting) 4 | 5 | This is a repository for image inpainting with Stable Diffusion finetunes that 6 | weren't trained on the inpainting task. The code is based on a pipeline from the Hugging Face 🤗 Diffusers library. 7 | 8 | It is a simple learning project; for real use cases it is better to use 9 | [StableDiffusionInpaintPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py) 10 | or 11 | [StableDiffusionInpaintPipelineLegacy](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py) 12 | from the 🤗 Diffusers library. 13 | 14 | ![inpainting](images/inpainting.png) 15 | 16 | ## Installation 17 | 18 | 1. Create a virtual environment: 19 | ```shell 20 | virtualenv -p python3.9 .venv && source .venv/bin/activate 21 | ``` 22 | 2. Install all requirements: 23 | ```shell 24 | pip install -r requirements.txt 25 | ``` 26 | 3. Use the project 🎉 27 | 28 | ## Usage 29 | 30 | To run inpainting, use the command below (a Python usage sketch is also shown further below): 31 | ```shell 32 | python run_simple_inpainting.py --device cuda:1 --prompt "Face of a yellow cat, high resolution, sitting on a park bench" --strength 0.95 --seed 0 33 | ``` 34 | 35 | ## Diffusion inpainting process 36 | 37 | ![inpainting](images/inpainting_process.gif) 38 | 39 | This GIF was created by decoding the latent features at each step of the diffusion process. 40 | 41 | ## Built With 42 | 43 | * [🤗 Diffusers](https://github.com/huggingface/diffusers) - Hugging Face diffusion models library 44 | * [Typer](https://typer.tiangolo.com/) - CLI framework 45 | 46 | 47 | ## License 48 | 49 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
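
## Python usage

The pipeline can also be called directly from Python, mirroring what `run_simple_inpainting.py` does. Below is a minimal sketch: the checkpoint ID is only an example (the script defaults to `redstonehero/Yiffymix_Diffusers`), a CUDA device is assumed, and `save_step` is an illustrative callback, not part of the repository, that decodes intermediate latents in the same per-step way the process GIF above illustrates.

```python
import torch
import PIL.Image

from pipelines import StableDiffusionSimpleInpaintingPipeline

# Any non-inpainting Stable Diffusion 1.x checkpoint should work here;
# the model ID below is just an example.
pipeline = StableDiffusionSimpleInpaintingPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
).to("cuda")

image = PIL.Image.open("images/image.png").convert("RGB")
mask = PIL.Image.open("images/mask.png").convert("RGB")

# Keep the inputs small, as the CLI script does with --max-image-side.
image.thumbnail((512, 512))
mask.thumbnail((512, 512))


def save_step(step, timestep, latents):
    # Decode the current latents with the VAE and save them as a frame;
    # frames like these can later be assembled into a GIF.
    with torch.no_grad():
        frame = pipeline.vae.decode(
            latents / pipeline.vae.config.scaling_factor, return_dict=False
        )[0]
    frame = pipeline.image_processor.postprocess(frame, output_type="pil")[0]
    frame.save(f"step_{step:03d}.png")


result = pipeline(
    prompt="Face of a yellow cat, high resolution, sitting on a park bench",
    image=image,
    mask_image=mask,
    strength=0.95,
    num_inference_steps=50,
    guidance_scale=7.5,
    generator=torch.manual_seed(0),
    callback=save_step,
    callback_steps=5,
).images[0]
result.save("inpainting_result.png")
```

The intermediate frames written by the callback can then be combined into an animation with any external tool.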
50 | 51 | ## Authors 52 | 53 | * **Vadim Titko** aka *Vadbeg* - 54 | [LinkedIn](https://www.linkedin.com/in/vadtitko) | 55 | [GitHub](https://github.com/Vadbeg) 56 | -------------------------------------------------------------------------------- /images/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vadbeg/diffusers-inpainting/3efd045e431ddfb40019809554285c5d3e62722e/images/image.png -------------------------------------------------------------------------------- /images/inpainting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vadbeg/diffusers-inpainting/3efd045e431ddfb40019809554285c5d3e62722e/images/inpainting.png -------------------------------------------------------------------------------- /images/inpainting_process.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vadbeg/diffusers-inpainting/3efd045e431ddfb40019809554285c5d3e62722e/images/inpainting_process.gif -------------------------------------------------------------------------------- /images/mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vadbeg/diffusers-inpainting/3efd045e431ddfb40019809554285c5d3e62722e/images/mask.png -------------------------------------------------------------------------------- /pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline_stable_diffusion_img2img_simple_inpaint import ( 2 | StableDiffusionSimpleInpaintingPipeline, 3 | ) 4 | -------------------------------------------------------------------------------- /pipelines/pipeline_stable_diffusion_img2img_simple_inpaint.py: -------------------------------------------------------------------------------- 1 | """ 2 | Pipeline for inferencing stable diffusion model in img2img+inpainting using simple masking technique 3 | """ 4 | 5 | import inspect 6 | import warnings 7 | from typing import Any, Callable, Dict, List, Optional, Union 8 | 9 | import numpy as np 10 | import PIL 11 | import PIL.Image 12 | import torch 13 | from diffusers.configuration_utils import FrozenDict 14 | from diffusers.image_processor import VaeImageProcessor 15 | from diffusers.loaders import ( 16 | FromSingleFileMixin, 17 | LoraLoaderMixin, 18 | TextualInversionLoaderMixin, 19 | ) 20 | from diffusers.models import AutoencoderKL, UNet2DConditionModel 21 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 22 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput 23 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import ( 24 | EXAMPLE_DOC_STRING, 25 | ) 26 | from diffusers.pipelines.stable_diffusion.safety_checker import ( 27 | StableDiffusionSafetyChecker, 28 | ) 29 | from diffusers.schedulers import KarrasDiffusionSchedulers 30 | from diffusers.utils import ( 31 | deprecate, 32 | is_accelerate_available, 33 | is_accelerate_version, 34 | logging, 35 | randn_tensor, 36 | replace_example_docstring, 37 | ) 38 | from packaging import version 39 | from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer 40 | 41 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 42 | 43 | 44 | def prepare_mask( 45 | mask: Union[PIL.Image.Image, np.ndarray, torch.Tensor] 46 | ) -> torch.Tensor: 47 | """ 48 | Prepares a mask to be 
consumed by the Stable Diffusion pipeline. This means that this input will be 49 | converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``1`` for 50 | the ``mask``. 51 | 52 | The ``mask`` will be binarized (``mask >= 0.5``) and cast to ``torch.float32`` too. 53 | 54 | Args: 55 | mask (``PIL.Image.Image``, ``np.ndarray`` or ``torch.Tensor``): The mask to apply to the image, i.e. regions to inpaint. 56 | It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` 57 | ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. 58 | 59 | 60 | Raises: 61 | ValueError: ``mask`` is a ``torch.Tensor`` with values outside the 62 | ``[0, 1]`` range. 63 | TypeError: ``mask`` is not a ``PIL.Image``, ``np.ndarray``, 64 | ``torch.Tensor`` or a list of these. 65 | 66 | Returns: 67 | torch.Tensor: mask as ``torch.Tensor`` with 4 dimensions: ``batch x channels x height x width``. 68 | """ 69 | if not isinstance(mask, (PIL.Image.Image, np.ndarray, torch.Tensor, list)): 70 | raise TypeError( 71 | f"`mask` should be a PIL.Image, np.ndarray, torch.Tensor or a list of these, but is {type(mask)}" 72 | ) 73 | if isinstance(mask, torch.Tensor): 74 | 75 | # Batch and add channel dim for single mask 76 | if mask.ndim == 2: 77 | mask = mask.unsqueeze(0).unsqueeze(0) 78 | 79 | # Batch single mask or add channel dim 80 | if mask.ndim == 3: 81 | # Single batched mask, no channel dim or single mask not batched but channel dim 82 | if mask.shape[0] == 1: 83 | mask = mask.unsqueeze(0) 84 | 85 | # Batched masks no channel dim 86 | else: 87 | mask = mask.unsqueeze(1) 88 | 89 | # Check mask is in [0, 1] 90 | if mask.min() < 0 or mask.max() > 1: 91 | raise ValueError("Mask should be in [0, 1] range") 92 | 93 | # Binarize mask 94 | mask[mask < 0.5] = 0 95 | mask[mask >= 0.5] = 1 96 | else: 97 | # preprocess mask 98 | if isinstance(mask, (PIL.Image.Image, np.ndarray)): 99 | mask = [mask] 100 | 101 | if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): 102 | mask = np.concatenate( 103 | [np.array(m.convert("L"))[None, None, :] for m in mask], axis=0 104 | ) 105 | mask = mask.astype(np.float32) / 255.0 106 | elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): 107 | mask = np.concatenate([m[None, None, :] for m in mask], axis=0) 108 | 109 | mask[mask < 0.5] = 0 110 | mask[mask >= 0.5] = 1 111 | mask = torch.from_numpy(mask) 112 | 113 | return mask 114 | 115 | 116 | class StableDiffusionSimpleInpaintingPipeline( 117 | DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin 118 | ): 119 | r""" 120 | Pipeline for text-guided image inpainting (img2img with simple latent masking) using Stable Diffusion. 121 | 122 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the 123 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 124 | 125 | In addition the pipeline inherits the following loading methods: 126 | - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] 127 | - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] 128 | - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] 129 | 130 | as well as the following saving methods: 131 | - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] 132 | 133 | Args: 134 | vae ([`AutoencoderKL`]): 135 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 136 | text_encoder ([`CLIPTextModel`]): 137 | Frozen text-encoder.
Stable Diffusion uses the text portion of 138 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically 139 | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 140 | tokenizer (`CLIPTokenizer`): 141 | Tokenizer of class 142 | [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). 143 | unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. 144 | scheduler ([`SchedulerMixin`]): 145 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of 146 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 147 | safety_checker ([`StableDiffusionSafetyChecker`]): 148 | Classification module that estimates whether generated images could be considered offensive or harmful. 149 | Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. 150 | feature_extractor ([`CLIPImageProcessor`]): 151 | Model that extracts features from generated images to be used as inputs for the `safety_checker`. 152 | """ 153 | _optional_components = ["safety_checker", "feature_extractor"] 154 | 155 | def __init__( 156 | self, 157 | vae: AutoencoderKL, 158 | text_encoder: CLIPTextModel, 159 | tokenizer: CLIPTokenizer, 160 | unet: UNet2DConditionModel, 161 | scheduler: KarrasDiffusionSchedulers, 162 | safety_checker: StableDiffusionSafetyChecker, 163 | feature_extractor: CLIPImageProcessor, 164 | requires_safety_checker: bool = True, 165 | ): 166 | super().__init__() 167 | 168 | if ( 169 | hasattr(scheduler.config, "steps_offset") 170 | and scheduler.config.steps_offset != 1 171 | ): 172 | deprecation_message = ( 173 | f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" 174 | f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " 175 | "to update the config accordingly as leaving `steps_offset` might led to incorrect results" 176 | " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," 177 | " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" 178 | " file" 179 | ) 180 | deprecate( 181 | "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False 182 | ) 183 | new_config = dict(scheduler.config) 184 | new_config["steps_offset"] = 1 185 | scheduler._internal_dict = FrozenDict(new_config) 186 | 187 | if ( 188 | hasattr(scheduler.config, "clip_sample") 189 | and scheduler.config.clip_sample is True 190 | ): 191 | deprecation_message = ( 192 | f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." 193 | " `clip_sample` should be set to False in the configuration file. Please make sure to update the" 194 | " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" 195 | " future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" 196 | " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" 197 | ) 198 | deprecate( 199 | "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False 200 | ) 201 | new_config = dict(scheduler.config) 202 | new_config["clip_sample"] = False 203 | scheduler._internal_dict = FrozenDict(new_config) 204 | 205 | if safety_checker is None and requires_safety_checker: 206 | logger.warning( 207 | f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" 208 | " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered" 209 | " results in services or applications open to the public. Both the diffusers team and Hugging Face" 210 | " strongly recommend keeping the safety filter enabled in all public facing circumstances, disabling" 211 | " it only for use-cases that involve analyzing network behavior or auditing its results. For more" 212 | " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." 213 | ) 214 | 215 | if safety_checker is not None and feature_extractor is None: 216 | raise ValueError( 217 | f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" 218 | " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 219 | ) 220 | 221 | is_unet_version_less_0_9_0 = hasattr( 222 | unet.config, "_diffusers_version" 223 | ) and version.parse( 224 | version.parse(unet.config._diffusers_version).base_version 225 | ) < version.parse( 226 | "0.9.0.dev0" 227 | ) 228 | is_unet_sample_size_less_64 = ( 229 | hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 230 | ) 231 | if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: 232 | deprecation_message = ( 233 | "The configuration file of the unet has set the default `sample_size` to smaller than" 234 | " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" 235 | " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" 236 | " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" 237 | " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" 238 | " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" 239 | " in the config might lead to incorrect results in future versions.
If you have downloaded this" 240 | " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" 241 | " the `unet/config.json` file" 242 | ) 243 | deprecate( 244 | "sample_size<64", "1.0.0", deprecation_message, standard_warn=False 245 | ) 246 | new_config = dict(unet.config) 247 | new_config["sample_size"] = 64 248 | unet._internal_dict = FrozenDict(new_config) 249 | 250 | self.register_modules( 251 | vae=vae, 252 | text_encoder=text_encoder, 253 | tokenizer=tokenizer, 254 | unet=unet, 255 | scheduler=scheduler, 256 | safety_checker=safety_checker, 257 | feature_extractor=feature_extractor, 258 | ) 259 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 260 | self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) 261 | self.register_to_config(requires_safety_checker=requires_safety_checker) 262 | 263 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload 264 | def enable_sequential_cpu_offload(self, gpu_id=0): 265 | r""" 266 | Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, 267 | text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a 268 | `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. 269 | Note that offloading happens on a submodule basis. Memory savings are higher than with 270 | `enable_model_cpu_offload`, but performance is lower. 271 | """ 272 | if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): 273 | from accelerate import cpu_offload 274 | else: 275 | raise ImportError( 276 | "`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher" 277 | ) 278 | 279 | device = torch.device(f"cuda:{gpu_id}") 280 | 281 | if self.device.type != "cpu": 282 | self.to("cpu", silence_dtype_warnings=True) 283 | torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) 284 | 285 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: 286 | cpu_offload(cpu_offloaded_model, device) 287 | 288 | if self.safety_checker is not None: 289 | cpu_offload( 290 | self.safety_checker, execution_device=device, offload_buffers=True 291 | ) 292 | 293 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload 294 | def enable_model_cpu_offload(self, gpu_id=0): 295 | r""" 296 | Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared 297 | to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` 298 | method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with 299 | `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. 300 | """ 301 | if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): 302 | from accelerate import cpu_offload_with_hook 303 | else: 304 | raise ImportError( 305 | "`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher." 
306 | ) 307 | 308 | device = torch.device(f"cuda:{gpu_id}") 309 | 310 | if self.device.type != "cpu": 311 | self.to("cpu", silence_dtype_warnings=True) 312 | torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) 313 | 314 | hook = None 315 | for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: 316 | _, hook = cpu_offload_with_hook( 317 | cpu_offloaded_model, device, prev_module_hook=hook 318 | ) 319 | 320 | if self.safety_checker is not None: 321 | _, hook = cpu_offload_with_hook( 322 | self.safety_checker, device, prev_module_hook=hook 323 | ) 324 | 325 | # We'll offload the last model manually. 326 | self.final_offload_hook = hook 327 | 328 | @property 329 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device 330 | def _execution_device(self): 331 | r""" 332 | Returns the device on which the pipeline's models will be executed. After calling 333 | `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module 334 | hooks. 335 | """ 336 | if not hasattr(self.unet, "_hf_hook"): 337 | return self.device 338 | for module in self.unet.modules(): 339 | if ( 340 | hasattr(module, "_hf_hook") 341 | and hasattr(module._hf_hook, "execution_device") 342 | and module._hf_hook.execution_device is not None 343 | ): 344 | return torch.device(module._hf_hook.execution_device) 345 | return self.device 346 | 347 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt 348 | def _encode_prompt( # type: ignore 349 | self, 350 | prompt, 351 | device, 352 | num_images_per_prompt, 353 | do_classifier_free_guidance, 354 | negative_prompt=None, 355 | prompt_embeds: Optional[torch.FloatTensor] = None, 356 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 357 | lora_scale: Optional[float] = None, 358 | ): 359 | r""" 360 | Encodes the prompt into text encoder hidden states. 361 | 362 | Args: 363 | prompt (`str` or `List[str]`, *optional*): 364 | prompt to be encoded 365 | device: (`torch.device`): 366 | torch device 367 | num_images_per_prompt (`int`): 368 | number of images that should be generated per prompt 369 | do_classifier_free_guidance (`bool`): 370 | whether to use classifier free guidance or not 371 | negative_prompt (`str` or `List[str]`, *optional*): 372 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 373 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is 374 | less than `1`). 375 | prompt_embeds (`torch.FloatTensor`, *optional*): 376 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 377 | provided, text embeddings will be generated from `prompt` input argument. 378 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 379 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 380 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 381 | argument. 382 | lora_scale (`float`, *optional*): 383 | A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
384 | """ 385 | # set lora scale so that monkey patched LoRA 386 | # function of text encoder can correctly access it 387 | if lora_scale is not None and isinstance(self, LoraLoaderMixin): 388 | self._lora_scale = lora_scale 389 | 390 | if prompt is not None and isinstance(prompt, str): 391 | batch_size = 1 392 | elif prompt is not None and isinstance(prompt, list): 393 | batch_size = len(prompt) 394 | else: 395 | batch_size = prompt_embeds.shape[0] # type: ignore 396 | 397 | if prompt_embeds is None: 398 | # textual inversion: procecss multi-vector tokens if necessary 399 | if isinstance(self, TextualInversionLoaderMixin): 400 | prompt = self.maybe_convert_prompt(prompt, self.tokenizer) 401 | 402 | text_inputs = self.tokenizer( 403 | prompt, 404 | padding="max_length", 405 | max_length=self.tokenizer.model_max_length, 406 | truncation=True, 407 | return_tensors="pt", 408 | ) 409 | text_input_ids = text_inputs.input_ids 410 | untruncated_ids = self.tokenizer( 411 | prompt, padding="longest", return_tensors="pt" 412 | ).input_ids 413 | 414 | if untruncated_ids.shape[-1] >= text_input_ids.shape[ 415 | -1 416 | ] and not torch.equal(text_input_ids, untruncated_ids): 417 | removed_text = self.tokenizer.batch_decode( 418 | untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] 419 | ) 420 | logger.warning( 421 | "The following part of your input was truncated because CLIP can only handle sequences up to" 422 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 423 | ) 424 | 425 | if ( 426 | hasattr(self.text_encoder.config, "use_attention_mask") 427 | and self.text_encoder.config.use_attention_mask 428 | ): 429 | attention_mask = text_inputs.attention_mask.to(device) 430 | else: 431 | attention_mask = None 432 | 433 | prompt_embeds = self.text_encoder( 434 | text_input_ids.to(device), 435 | attention_mask=attention_mask, 436 | ) 437 | prompt_embeds = prompt_embeds[0] 438 | 439 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 440 | 441 | bs_embed, seq_len, _ = prompt_embeds.shape 442 | # duplicate text embeddings for each generation per prompt, using mps friendly method 443 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 444 | prompt_embeds = prompt_embeds.view( 445 | bs_embed * num_images_per_prompt, seq_len, -1 446 | ) 447 | 448 | # get unconditional embeddings for classifier free guidance 449 | if do_classifier_free_guidance and negative_prompt_embeds is None: 450 | uncond_tokens: List[str] 451 | if negative_prompt is None: 452 | uncond_tokens = [""] * batch_size 453 | elif prompt is not None and type(prompt) is not type(negative_prompt): 454 | raise TypeError( 455 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 456 | f" {type(prompt)}." 457 | ) 458 | elif isinstance(negative_prompt, str): 459 | uncond_tokens = [negative_prompt] 460 | elif batch_size != len(negative_prompt): 461 | raise ValueError( 462 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 463 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 464 | " the batch size of `prompt`." 
465 | ) 466 | else: 467 | uncond_tokens = negative_prompt 468 | 469 | # textual inversion: procecss multi-vector tokens if necessary 470 | if isinstance(self, TextualInversionLoaderMixin): 471 | uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) 472 | 473 | max_length = prompt_embeds.shape[1] 474 | uncond_input = self.tokenizer( 475 | uncond_tokens, 476 | padding="max_length", 477 | max_length=max_length, 478 | truncation=True, 479 | return_tensors="pt", 480 | ) 481 | 482 | if ( 483 | hasattr(self.text_encoder.config, "use_attention_mask") 484 | and self.text_encoder.config.use_attention_mask 485 | ): 486 | attention_mask = uncond_input.attention_mask.to(device) 487 | else: 488 | attention_mask = None 489 | 490 | negative_prompt_embeds = self.text_encoder( 491 | uncond_input.input_ids.to(device), 492 | attention_mask=attention_mask, 493 | ) 494 | negative_prompt_embeds = negative_prompt_embeds[0] 495 | 496 | if do_classifier_free_guidance: 497 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 498 | seq_len = negative_prompt_embeds.shape[1] # type: ignore 499 | 500 | negative_prompt_embeds = negative_prompt_embeds.to( # type: ignore 501 | dtype=self.text_encoder.dtype, device=device 502 | ) 503 | 504 | negative_prompt_embeds = negative_prompt_embeds.repeat( 505 | 1, num_images_per_prompt, 1 506 | ) 507 | negative_prompt_embeds = negative_prompt_embeds.view( 508 | batch_size * num_images_per_prompt, seq_len, -1 509 | ) 510 | 511 | # For classifier free guidance, we need to do two forward passes. 512 | # Here we concatenate the unconditional and text embeddings into a single batch 513 | # to avoid doing two forward passes 514 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) 515 | 516 | return prompt_embeds 517 | 518 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker 519 | def run_safety_checker(self, image, device, dtype): 520 | if self.safety_checker is None: 521 | has_nsfw_concept = None 522 | else: 523 | if torch.is_tensor(image): 524 | feature_extractor_input = self.image_processor.postprocess( 525 | image, output_type="pil" 526 | ) 527 | else: 528 | feature_extractor_input = self.image_processor.numpy_to_pil(image) 529 | safety_checker_input = self.feature_extractor( 530 | feature_extractor_input, return_tensors="pt" 531 | ).to(device) 532 | image, has_nsfw_concept = self.safety_checker( 533 | images=image, clip_input=safety_checker_input.pixel_values.to(dtype) 534 | ) 535 | return image, has_nsfw_concept 536 | 537 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents 538 | def decode_latents(self, latents): 539 | warnings.warn( 540 | "The decode_latents method is deprecated and will be removed in a future version. 
Please" 541 | " use VaeImageProcessor instead", 542 | FutureWarning, 543 | ) 544 | latents = 1 / self.vae.config.scaling_factor * latents 545 | image = self.vae.decode(latents, return_dict=False)[0] 546 | image = (image / 2 + 0.5).clamp(0, 1) 547 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 548 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() 549 | return image 550 | 551 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs 552 | def prepare_extra_step_kwargs(self, generator, eta): 553 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 554 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 555 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 556 | # and should be between [0, 1] 557 | 558 | accepts_eta = "eta" in set( 559 | inspect.signature(self.scheduler.step).parameters.keys() 560 | ) 561 | extra_step_kwargs = {} 562 | if accepts_eta: 563 | extra_step_kwargs["eta"] = eta 564 | 565 | # check if the scheduler accepts generator 566 | accepts_generator = "generator" in set( 567 | inspect.signature(self.scheduler.step).parameters.keys() 568 | ) 569 | if accepts_generator: 570 | extra_step_kwargs["generator"] = generator 571 | return extra_step_kwargs 572 | 573 | def check_inputs( 574 | self, 575 | prompt, 576 | strength, 577 | callback_steps, 578 | negative_prompt=None, 579 | prompt_embeds=None, 580 | negative_prompt_embeds=None, 581 | ): 582 | if strength < 0 or strength > 1: 583 | raise ValueError( 584 | f"The value of strength should in [0.0, 1.0] but is {strength}" 585 | ) 586 | 587 | if (callback_steps is None) or ( 588 | callback_steps is not None 589 | and (not isinstance(callback_steps, int) or callback_steps <= 0) 590 | ): 591 | raise ValueError( 592 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" 593 | f" {type(callback_steps)}." 594 | ) 595 | 596 | if prompt is not None and prompt_embeds is not None: 597 | raise ValueError( 598 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 599 | " only forward one of the two." 600 | ) 601 | elif prompt is None and prompt_embeds is None: 602 | raise ValueError( 603 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 604 | ) 605 | elif prompt is not None and ( 606 | not isinstance(prompt, str) and not isinstance(prompt, list) 607 | ): 608 | raise ValueError( 609 | f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" 610 | ) 611 | 612 | if negative_prompt is not None and negative_prompt_embeds is not None: 613 | raise ValueError( 614 | f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" 615 | f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 616 | ) 617 | 618 | if prompt_embeds is not None and negative_prompt_embeds is not None: 619 | if prompt_embeds.shape != negative_prompt_embeds.shape: 620 | raise ValueError( 621 | "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" 622 | f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" 623 | f" {negative_prompt_embeds.shape}." 
624 | ) 625 | 626 | def get_timesteps(self, num_inference_steps, strength, device): 627 | # get the original timestep using init_timestep 628 | init_timestep = min(int(num_inference_steps * strength), num_inference_steps) 629 | 630 | t_start = max(num_inference_steps - init_timestep, 0) 631 | timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] 632 | 633 | return timesteps, num_inference_steps - t_start 634 | 635 | def prepare_latents( 636 | self, 637 | image, 638 | timestep, 639 | batch_size, 640 | num_images_per_prompt, 641 | dtype, 642 | device, 643 | generator=None, 644 | ): 645 | if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): 646 | raise ValueError( 647 | f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" 648 | ) 649 | 650 | image = image.to(device=device, dtype=dtype) 651 | 652 | batch_size = batch_size * num_images_per_prompt 653 | 654 | if image.shape[1] == 4: 655 | init_latents = image 656 | 657 | else: 658 | if isinstance(generator, list) and len(generator) != batch_size: 659 | raise ValueError( 660 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 661 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 662 | ) 663 | 664 | elif isinstance(generator, list): 665 | init_latents = [ 666 | self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) 667 | for i in range(batch_size) 668 | ] 669 | init_latents = torch.cat(init_latents, dim=0) 670 | else: 671 | init_latents = self.vae.encode(image).latent_dist.sample(generator) 672 | 673 | init_latents = self.vae.config.scaling_factor * init_latents 674 | 675 | if ( 676 | batch_size > init_latents.shape[0] 677 | and batch_size % init_latents.shape[0] == 0 678 | ): 679 | # expand init_latents for batch_size 680 | deprecation_message = ( 681 | f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" 682 | " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" 683 | " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" 684 | " your script to pass as many initial images as text prompts to suppress this warning." 685 | ) 686 | deprecate( 687 | "len(prompt) != len(image)", 688 | "1.0.0", 689 | deprecation_message, 690 | standard_warn=False, 691 | ) 692 | additional_image_per_prompt = batch_size // init_latents.shape[0] 693 | init_latents = torch.cat( 694 | [init_latents] * additional_image_per_prompt, dim=0 695 | ) 696 | elif ( 697 | batch_size > init_latents.shape[0] 698 | and batch_size % init_latents.shape[0] != 0 699 | ): 700 | raise ValueError( 701 | f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." 
702 | ) 703 | else: 704 | init_latents = torch.cat([init_latents], dim=0) 705 | 706 | shape = init_latents.shape 707 | noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 708 | 709 | # get latents 710 | latents = self.scheduler.add_noise(init_latents, noise, timestep) 711 | 712 | return latents, init_latents, noise 713 | 714 | @torch.no_grad() 715 | @replace_example_docstring(EXAMPLE_DOC_STRING) 716 | def __call__( # type: ignore 717 | self, 718 | prompt: Union[str, List[str]] = None, # type: ignore 719 | image: Union[ 720 | torch.FloatTensor, 721 | PIL.Image.Image, 722 | np.ndarray, 723 | List[torch.FloatTensor], 724 | List[PIL.Image.Image], 725 | List[np.ndarray], 726 | ] = None, 727 | mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, 728 | strength: float = 0.8, 729 | num_inference_steps: Optional[int] = 50, 730 | guidance_scale: Optional[float] = 7.5, 731 | negative_prompt: Optional[Union[str, List[str]]] = None, 732 | num_images_per_prompt: Optional[int] = 1, 733 | eta: Optional[float] = 0.0, 734 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 735 | prompt_embeds: Optional[torch.FloatTensor] = None, 736 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 737 | output_type: Optional[str] = "pil", 738 | return_dict: bool = True, 739 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 740 | callback_steps: int = 1, 741 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 742 | ): 743 | r""" 744 | Function invoked when calling the pipeline for generation. 745 | 746 | Args: 747 | prompt (`str` or `List[str]`, *optional*): 748 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. 749 | instead. 750 | image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): 751 | `Image`, or tensor representing an image batch, that will be used as the starting point for the 752 | process. Can also accpet image latents_with_noise as `image`, if passing latents_with_noise directly, it will not be encoded 753 | again. 754 | mask_image (`torch.FloatTensor` or `PIL.Image.Image`): 755 | `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be 756 | replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a 757 | PIL image, it will be converted to a single channel (luminance) before use. If mask is a tensor, the 758 | expected shape should be either `(B, H, W, C)` or `(B, C, H, W)`, where C is 1 or 3. 759 | strength (`float`, *optional*, defaults to 0.8): 760 | Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` 761 | will be used as a starting point, adding more noise to it the larger the `strength`. The number of 762 | denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will 763 | be maximum and the denoising process will run for the full number of iterations specified in 764 | `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. 765 | num_inference_steps (`int`, *optional*, defaults to 50): 766 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 767 | expense of slower inference. This parameter will be modulated by `strength`. 
768 | guidance_scale (`float`, *optional*, defaults to 7.5): 769 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 770 | `guidance_scale` is defined as `w` of equation 2. of [Imagen 771 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 772 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 773 | usually at the expense of lower image quality. 774 | negative_prompt (`str` or `List[str]`, *optional*): 775 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 776 | `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` 777 | is less than `1`). 778 | num_images_per_prompt (`int`, *optional*, defaults to 1): 779 | The number of images to generate per prompt. 780 | eta (`float`, *optional*, defaults to 0.0): 781 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to 782 | [`schedulers.DDIMScheduler`], will be ignored for others. 783 | generator (`torch.Generator`, *optional*): 784 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) 785 | to make generation deterministic. 786 | prompt_embeds (`torch.FloatTensor`, *optional*): 787 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 788 | provided, text embeddings will be generated from `prompt` input argument. 789 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 790 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 791 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 792 | argument. 793 | output_type (`str`, *optional*, defaults to `"pil"`): 794 | The output format of the generate image. Choose between 795 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 796 | return_dict (`bool`, *optional*, defaults to `True`): 797 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a 798 | plain tuple. 799 | callback (`Callable`, *optional*): 800 | A function that will be called every `callback_steps` steps during inference. The function will be 801 | called with the following arguments: `callback(step: int, timestep: int, latents_with_noise: torch.FloatTensor)`. 802 | callback_steps (`int`, *optional*, defaults to 1): 803 | The frequency at which the `callback` function will be called. If not specified, the callback will be 804 | called at every step. 805 | cross_attention_kwargs (`dict`, *optional*): 806 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 807 | `self.processor` in 808 | [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). 809 | Examples: 810 | 811 | Returns: 812 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: 813 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. 814 | When returning a tuple, the first element is a list with the generated images, and the second element is a 815 | list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" 816 | (nsfw) content, according to the `safety_checker`. 817 | """ 818 | # 1. 
Check inputs. Raise error if not correct 819 | self.check_inputs( 820 | prompt, 821 | strength, 822 | callback_steps, 823 | negative_prompt, 824 | prompt_embeds, 825 | negative_prompt_embeds, 826 | ) 827 | if mask_image is None: 828 | raise ValueError("`mask_image` input cannot be undefined.") 829 | 830 | # 2. Define call parameters 831 | if prompt is not None and isinstance(prompt, str): 832 | batch_size = 1 833 | elif prompt is not None and isinstance(prompt, list): 834 | batch_size = len(prompt) 835 | else: 836 | batch_size = prompt_embeds.shape[0] 837 | device = self._execution_device 838 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 839 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 840 | # corresponds to doing no classifier free guidance. 841 | do_classifier_free_guidance = guidance_scale > 1.0 # type: ignore 842 | 843 | # 3. Encode input prompt 844 | text_encoder_lora_scale = ( 845 | cross_attention_kwargs.get("scale", None) 846 | if cross_attention_kwargs is not None 847 | else None 848 | ) 849 | prompt_embeds = self._encode_prompt( 850 | prompt, 851 | device, 852 | num_images_per_prompt, 853 | do_classifier_free_guidance, 854 | negative_prompt, 855 | prompt_embeds=prompt_embeds, 856 | negative_prompt_embeds=negative_prompt_embeds, 857 | lora_scale=text_encoder_lora_scale, 858 | ) 859 | 860 | # 4. Preprocess image and mask 861 | image = self.image_processor.preprocess(image) 862 | mask = prepare_mask(mask=mask_image) 863 | 864 | # 5. set timesteps 865 | self.scheduler.set_timesteps(num_inference_steps, device=device) 866 | timesteps, num_inference_steps = self.get_timesteps( 867 | num_inference_steps, strength, device 868 | ) 869 | latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # type: ignore 870 | 871 | # 6. Prepare latent variables 872 | latents_with_noise, init_latents, noise = self.prepare_latents( 873 | image, 874 | latent_timestep, 875 | batch_size, 876 | num_images_per_prompt, 877 | prompt_embeds.dtype, 878 | device, 879 | generator, 880 | ) 881 | 882 | # 6.1. Prepare mask 883 | height, width = mask.shape[-2:] 884 | mask = torch.nn.functional.interpolate( 885 | mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) 886 | ) 887 | mask = mask.to(device) 888 | mask = mask.to(prompt_embeds.dtype) 889 | 890 | # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline 891 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 892 | 893 | # 8. 
Denoising loop 894 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 895 | with self.progress_bar(total=num_inference_steps) as progress_bar: 896 | for i, t in enumerate(timesteps): 897 | # expand the latents_with_noise if we are doing classifier free guidance 898 | latent_model_input = ( 899 | torch.cat([latents_with_noise] * 2) 900 | if do_classifier_free_guidance 901 | else latents_with_noise 902 | ) 903 | latent_model_input = self.scheduler.scale_model_input( 904 | latent_model_input, t 905 | ) 906 | 907 | # predict the noise residual 908 | noise_pred = self.unet( 909 | latent_model_input, 910 | t, 911 | encoder_hidden_states=prompt_embeds, 912 | cross_attention_kwargs=cross_attention_kwargs, 913 | return_dict=False, 914 | )[0] 915 | 916 | # perform guidance 917 | if do_classifier_free_guidance: 918 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 919 | noise_pred = noise_pred_uncond + guidance_scale * ( 920 | noise_pred_text - noise_pred_uncond 921 | ) 922 | 923 | # compute the previous noisy sample x_t -> x_t-1 924 | latents_with_noise = self.scheduler.step( 925 | noise_pred, 926 | t, 927 | latents_with_noise, 928 | **extra_step_kwargs, 929 | return_dict=False, 930 | )[0] 931 | 932 | # masking process 933 | init_latents_proper = self.scheduler.add_noise( 934 | init_latents, noise, t 935 | ).to(device) 936 | 937 | mask = (mask > 0.5).to(prompt_embeds.dtype) 938 | latents_with_noise = ( 939 | mask * latents_with_noise + (1 - mask) * init_latents_proper 940 | ) 941 | 942 | # call the callback, if provided 943 | if i == len(timesteps) - 1 or ( 944 | (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 945 | ): 946 | progress_bar.update() 947 | if callback is not None and i % callback_steps == 0: 948 | callback(i, t, latents_with_noise) 949 | 950 | if not output_type == "latent": 951 | image = self.vae.decode( 952 | latents_with_noise / self.vae.config.scaling_factor, return_dict=False 953 | )[0] 954 | image, has_nsfw_concept = self.run_safety_checker( 955 | image, device, prompt_embeds.dtype 956 | ) 957 | else: 958 | image = latents_with_noise 959 | has_nsfw_concept = None 960 | 961 | if has_nsfw_concept is None: 962 | do_denormalize = [True] * image.shape[0] 963 | else: 964 | do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] 965 | 966 | image = self.image_processor.postprocess( 967 | image, output_type=output_type, do_denormalize=do_denormalize 968 | ) 969 | 970 | # Offload last model to CPU 971 | if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: 972 | self.final_offload_hook.offload() 973 | 974 | if not return_dict: 975 | return (image, has_nsfw_concept) 976 | 977 | return StableDiffusionPipelineOutput( 978 | images=image, nsfw_content_detected=has_nsfw_concept 979 | ) 980 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 88 3 | target-version = ["py38"] 4 | 5 | [tool.isort] 6 | profile = "black" 7 | multi_line_output = 3 8 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | certifi==2023.5.7 2 | charset-normalizer==3.2.0 3 | click==8.1.4 4 | colorama==0.4.6 5 | diffusers==0.18.1 6 | filelock==3.12.2 7 | fsspec==2023.6.0 8 | huggingface-hub==0.16.4 9 | idna==3.4 10 | importlib-metadata==6.8.0 
11 | Jinja2==3.1.2 12 | markdown-it-py==3.0.0 13 | MarkupSafe==2.1.3 14 | mdurl==0.1.2 15 | mpmath==1.3.0 16 | networkx==3.1 17 | numpy==1.25.0 18 | packaging==23.1 19 | Pillow==10.0.0 20 | Pygments==2.15.1 21 | PyYAML==6.0 22 | regex==2023.6.3 23 | requests==2.31.0 24 | rich==13.4.2 25 | safetensors==0.3.1 26 | shellingham==1.5.0.post1 27 | sympy==1.12 28 | tokenizers==0.13.3 29 | torch==2.0.1 30 | torchvision==0.15.2 31 | tqdm==4.65.0 32 | transformers==4.30.2 33 | typer==0.9.0 34 | typing_extensions==4.7.1 35 | urllib3==2.0.3 36 | zipp==3.15.0 37 | -------------------------------------------------------------------------------- /run_simple_inpainting.py: -------------------------------------------------------------------------------- 1 | """Script for running Simple Inpainting pipeline""" 2 | 3 | import enum 4 | from pathlib import Path 5 | 6 | import PIL.Image 7 | import torch 8 | import typer 9 | 10 | from pipelines import StableDiffusionSimpleInpaintingPipeline # type: ignore 11 | 12 | app = typer.Typer(pretty_exceptions_show_locals=False) 13 | 14 | 15 | class DTYPE(enum.Enum): 16 | FLOAT32 = "float32" 17 | FLOAT16 = "float16" 18 | 19 | 20 | @app.command() 21 | def run_inpainting( 22 | image_path: Path = typer.Option( 23 | default=Path("images/image.png"), help="Path to input image" 24 | ), 25 | mask_path: Path = typer.Option( 26 | default=Path("images/mask.png"), help="Path to input mask" 27 | ), 28 | save_path: Path = typer.Option( 29 | default=Path("images/inpainting.png"), help="Saves output image to *save_path*" 30 | ), 31 | save_concat: bool = typer.Option( 32 | default=True, help="Saves concatenated image, mask and inpaint to *save_path*" 33 | ), 34 | disable_safety_checker: bool = typer.Option( 35 | default=True, help="Disables safety checker" 36 | ), 37 | model_id: str = typer.Option( 38 | default="redstonehero/Yiffymix_Diffusers", help="Model ID in HuggingFace Hub" 39 | ), 40 | device: str = typer.Option(default="cuda", help="Device to run inference on"), 41 | dtype: DTYPE = typer.Option( 42 | default=DTYPE.FLOAT16.value, help="Data type to use for inference" 43 | ), 44 | seed: int = typer.Option(default=0, help="Random seed"), 45 | prompt: str = typer.Option(default="Cute cat", help="Prompt to run inference on"), 46 | strength: float = typer.Option( 47 | default=0.8, help="How much to transform the reference `image`" 48 | ), 49 | num_inference_steps: int = typer.Option( 50 | default=50, help="Number of inference steps" 51 | ), 52 | guidance_scale: float = typer.Option( 53 | default=7.5, 54 | help="Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).", 55 | ), 56 | max_image_side: int = typer.Option( 57 | default=512, help="Maximum width or height of the output image" 58 | ), 59 | ) -> None: 60 | 61 | if dtype == DTYPE.FLOAT32: 62 | dtype_torch = torch.float32 63 | elif dtype == DTYPE.FLOAT16: 64 | dtype_torch = torch.float16 65 | else: 66 | raise ValueError(f"Unknown dtype: {dtype}") 67 | 68 | pipeline = StableDiffusionSimpleInpaintingPipeline.from_pretrained( 69 | pretrained_model_name_or_path=model_id, 70 | torch_dtype=dtype_torch, 71 | ) 72 | pipeline = pipeline.to(device) 73 | 74 | if disable_safety_checker: 75 | pipeline.safety_checker = None 76 | 77 | image = PIL.Image.open(image_path).convert("RGB") 78 | mask = PIL.Image.open(mask_path).convert("RGB") 79 | 80 | image.thumbnail((max_image_side, max_image_side)) 81 | mask.thumbnail((max_image_side, max_image_side)) 82 | 83 | image_result = pipeline( 84 | prompt=prompt, 
85 | image=image, 86 | mask_image=mask, 87 | strength=strength, 88 | num_inference_steps=num_inference_steps, 89 | guidance_scale=guidance_scale, 90 | generator=torch.manual_seed(seed), 91 | ).images[0] 92 | 93 | if save_concat: 94 | # Concat original image, mask, and inpainted image 95 | concat = PIL.Image.new("RGB", (image.width * 3, image.height)) 96 | concat.paste(image, (0, 0)) 97 | concat.paste(mask, (image.width, 0)) 98 | concat.paste(image_result, (image.width * 2, 0)) 99 | 100 | concat.save(save_path) 101 | else: 102 | image_result.save(save_path) 103 | 104 | 105 | if __name__ == "__main__": 106 | app() 107 | --------------------------------------------------------------------------------