├── .gitignore
├── assets
│   ├── color.png
│   ├── font.png
│   ├── size.png
│   └── footnote.png
├── scripts
│   ├── models
│   │   ├── utils
│   │   │   ├── .DS_Store
│   │   │   └── richtext_utils.py
│   │   └── dual_transformer_2d.py
│   └── rich_text_settings.py
├── diffusers_official
│   ├── models
│   │   ├── activations.py
│   │   ├── __init__.py
│   │   ├── embeddings_flax.py
│   │   ├── resnet_flax.py
│   │   ├── modeling_flax_pytorch_utils.py
│   │   ├── cross_attention.py
│   │   ├── vq_model.py
│   │   ├── modeling_pytorch_flax_utils.py
│   │   ├── dual_transformer_2d.py
│   │   ├── transformer_temporal.py
│   │   └── unet_1d.py
│   ├── utils
│   │   ├── dummy_onnx_objects.py
│   │   ├── dummy_note_seq_objects.py
│   │   ├── dummy_torch_and_scipy_objects.py
│   │   ├── dummy_torch_and_torchsde_objects.py
│   │   ├── dummy_transformers_and_torch_and_note_seq_objects.py
│   │   ├── dummy_torch_and_transformers_and_k_diffusion_objects.py
│   │   ├── dummy_torch_and_librosa_objects.py
│   │   ├── dummy_torch_and_transformers_and_invisible_watermark_objects.py
│   │   ├── constants.py
│   │   ├── doc_utils.py
│   │   ├── pil_utils.py
│   │   ├── model_card_template.md
│   │   ├── accelerate_utils.py
│   │   ├── dummy_flax_and_transformers_objects.py
│   │   ├── deprecation_utils.py
│   │   ├── dummy_torch_and_transformers_and_onnx_objects.py
│   │   ├── torch_utils.py
│   │   ├── __init__.py
│   │   ├── outputs.py
│   │   ├── dummy_flax_objects.py
│   │   └── logging.py
│   ├── pipelines
│   │   ├── stable_diffusion_xl
│   │   │   ├── __init__.py
│   │   │   └── watermark.py
│   │   ├── __init__.py
│   │   └── onnx_utils.py
│   ├── pipeline_utils.py
│   ├── dependency_versions_table.py
│   ├── dependency_versions_check.py
│   ├── schedulers
│   │   ├── scheduling_sde_vp.py
│   │   ├── __init__.py
│   │   ├── scheduling_ipndm.py
│   │   ├── scheduling_utils.py
│   │   ├── scheduling_karras_ve_flax.py
│   │   └── scheduling_karras_ve.py
│   └── __init__.py
├── install.py
├── README.md
└── share_btn.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | venv
2 | __pycache__/
3 | *.pyc
4 | gradio_cached_examples/
--------------------------------------------------------------------------------
/assets/color.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/songweige/sd-webui-rich-text/HEAD/assets/color.png
--------------------------------------------------------------------------------
/assets/font.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/songweige/sd-webui-rich-text/HEAD/assets/font.png
--------------------------------------------------------------------------------
/assets/size.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/songweige/sd-webui-rich-text/HEAD/assets/size.png
--------------------------------------------------------------------------------
/assets/footnote.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/songweige/sd-webui-rich-text/HEAD/assets/footnote.png
--------------------------------------------------------------------------------
/scripts/models/utils/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/songweige/sd-webui-rich-text/HEAD/scripts/models/utils/.DS_Store
--------------------------------------------------------------------------------
/diffusers_official/models/activations.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | 
3 | 
4 | def get_activation(act_fn):
5 |     if act_fn in ["swish", "silu"]:
6 |         return nn.SiLU()
act_fn == "mish": 8 | return nn.Mish() 9 | elif act_fn == "gelu": 10 | return nn.GELU() 11 | else: 12 | raise ValueError(f"Unsupported activation function: {act_fn}") 13 | -------------------------------------------------------------------------------- /diffusers_official/utils/dummy_onnx_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 2 | from ..utils import DummyObject, requires_backends 3 | 4 | 5 | class OnnxRuntimeModel(metaclass=DummyObject): 6 | _backends = ["onnx"] 7 | 8 | def __init__(self, *args, **kwargs): 9 | requires_backends(self, ["onnx"]) 10 | 11 | @classmethod 12 | def from_config(cls, *args, **kwargs): 13 | requires_backends(cls, ["onnx"]) 14 | 15 | @classmethod 16 | def from_pretrained(cls, *args, **kwargs): 17 | requires_backends(cls, ["onnx"]) 18 | -------------------------------------------------------------------------------- /scripts/rich_text_settings.py: -------------------------------------------------------------------------------- 1 | import modules.scripts as scripts 2 | import gradio as gr 3 | import os 4 | 5 | from modules import shared 6 | from modules import script_callbacks 7 | 8 | def on_ui_settings(): 9 | section = ('template', "Rich-Text-to-Image") 10 | shared.opts.add_option( 11 | "option1", 12 | shared.OptionInfo( 13 | False, 14 | "This is a placeholder for option. It is not used yet.", 15 | gr.Checkbox, 16 | {"interactive": True}, 17 | section=section) 18 | ) 19 | 20 | script_callbacks.on_ui_settings(on_ui_settings) 21 | -------------------------------------------------------------------------------- /diffusers_official/utils/dummy_note_seq_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 2 | from ..utils import DummyObject, requires_backends 3 | 4 | 5 | class MidiProcessor(metaclass=DummyObject): 6 | _backends = ["note_seq"] 7 | 8 | def __init__(self, *args, **kwargs): 9 | requires_backends(self, ["note_seq"]) 10 | 11 | @classmethod 12 | def from_config(cls, *args, **kwargs): 13 | requires_backends(cls, ["note_seq"]) 14 | 15 | @classmethod 16 | def from_pretrained(cls, *args, **kwargs): 17 | requires_backends(cls, ["note_seq"]) 18 | -------------------------------------------------------------------------------- /diffusers_official/utils/dummy_torch_and_scipy_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 2 | from ..utils import DummyObject, requires_backends 3 | 4 | 5 | class LMSDiscreteScheduler(metaclass=DummyObject): 6 | _backends = ["torch", "scipy"] 7 | 8 | def __init__(self, *args, **kwargs): 9 | requires_backends(self, ["torch", "scipy"]) 10 | 11 | @classmethod 12 | def from_config(cls, *args, **kwargs): 13 | requires_backends(cls, ["torch", "scipy"]) 14 | 15 | @classmethod 16 | def from_pretrained(cls, *args, **kwargs): 17 | requires_backends(cls, ["torch", "scipy"]) 18 | -------------------------------------------------------------------------------- /diffusers_official/utils/dummy_torch_and_torchsde_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 
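# Note: every dummy module in this directory follows the same pattern. The
# `DummyObject` metaclass together with `requires_backends` makes constructing
# the class (or calling its `from_config` / `from_pretrained` classmethods)
# raise an informative ImportError when the optional backend, here `torchsde`,
# is not installed.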
2 | from ..utils import DummyObject, requires_backends 3 | 4 | 5 | class DPMSolverSDEScheduler(metaclass=DummyObject): 6 | _backends = ["torch", "torchsde"] 7 | 8 | def __init__(self, *args, **kwargs): 9 | requires_backends(self, ["torch", "torchsde"]) 10 | 11 | @classmethod 12 | def from_config(cls, *args, **kwargs): 13 | requires_backends(cls, ["torch", "torchsde"]) 14 | 15 | @classmethod 16 | def from_pretrained(cls, *args, **kwargs): 17 | requires_backends(cls, ["torch", "torchsde"]) 18 | -------------------------------------------------------------------------------- /diffusers_official/utils/dummy_transformers_and_torch_and_note_seq_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 2 | from ..utils import DummyObject, requires_backends 3 | 4 | 5 | class SpectrogramDiffusionPipeline(metaclass=DummyObject): 6 | _backends = ["transformers", "torch", "note_seq"] 7 | 8 | def __init__(self, *args, **kwargs): 9 | requires_backends(self, ["transformers", "torch", "note_seq"]) 10 | 11 | @classmethod 12 | def from_config(cls, *args, **kwargs): 13 | requires_backends(cls, ["transformers", "torch", "note_seq"]) 14 | 15 | @classmethod 16 | def from_pretrained(cls, *args, **kwargs): 17 | requires_backends(cls, ["transformers", "torch", "note_seq"]) 18 | -------------------------------------------------------------------------------- /diffusers_official/utils/dummy_torch_and_transformers_and_k_diffusion_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 2 | from ..utils import DummyObject, requires_backends 3 | 4 | 5 | class StableDiffusionKDiffusionPipeline(metaclass=DummyObject): 6 | _backends = ["torch", "transformers", "k_diffusion"] 7 | 8 | def __init__(self, *args, **kwargs): 9 | requires_backends(self, ["torch", "transformers", "k_diffusion"]) 10 | 11 | @classmethod 12 | def from_config(cls, *args, **kwargs): 13 | requires_backends(cls, ["torch", "transformers", "k_diffusion"]) 14 | 15 | @classmethod 16 | def from_pretrained(cls, *args, **kwargs): 17 | requires_backends(cls, ["torch", "transformers", "k_diffusion"]) 18 | -------------------------------------------------------------------------------- /install.py: -------------------------------------------------------------------------------- 1 | import launch 2 | 3 | # TODO: add pip dependency if need extra module only on extension 4 | 5 | if not launch.is_installed("diffusers"): 6 | launch.run_pip("install diffusers==0.18.2", "requirements for Rich-Text-to-Image") 7 | 8 | if not launch.is_installed("invisible-watermark"): 9 | launch.run_pip("install invisible-watermark==0.2.0", "requirements for Rich-Text-to-Image") 10 | 11 | if not launch.is_installed("accelerate"): 12 | launch.run_pip("install accelerate==0.21.0", "requirements for Rich-Text-to-Image") 13 | 14 | if not launch.is_installed("safetensors"): 15 | launch.run_pip("install safetensors==0.3.1", "requirements for Rich-Text-to-Image") 16 | 17 | if not launch.is_installed("seaborn"): 18 | launch.run_pip("install seaborn==0.12.2", "requirements for Rich-Text-to-Image") 19 | 20 | if not launch.is_installed("scikit-learn"): 21 | launch.run_pip("install scikit-learn==1.3.0", "requirements for Rich-Text-to-Image") 22 | 23 | if not launch.is_installed("threadpoolctl"): 24 | launch.run_pip("install threadpoolctl==3.1.0", "requirements for 
Rich-Text-to-Image") -------------------------------------------------------------------------------- /diffusers_official/pipelines/stable_diffusion_xl/__init__.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Optional, Union 3 | 4 | import numpy as np 5 | import PIL 6 | 7 | from ...utils import BaseOutput, is_invisible_watermark_available, is_torch_available, is_transformers_available 8 | 9 | 10 | @dataclass 11 | class StableDiffusionXLPipelineOutput(BaseOutput): 12 | """ 13 | Output class for Stable Diffusion pipelines. 14 | 15 | Args: 16 | images (`List[PIL.Image.Image]` or `np.ndarray`) 17 | List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, 18 | num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. 19 | """ 20 | 21 | images: Union[List[PIL.Image.Image], np.ndarray] 22 | 23 | 24 | if is_transformers_available() and is_torch_available() and is_invisible_watermark_available(): 25 | from .pipeline_stable_diffusion_xl import StableDiffusionXLPipeline 26 | from .pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipeline 27 | -------------------------------------------------------------------------------- /diffusers_official/utils/dummy_torch_and_librosa_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 2 | from ..utils import DummyObject, requires_backends 3 | 4 | 5 | class AudioDiffusionPipeline(metaclass=DummyObject): 6 | _backends = ["torch", "librosa"] 7 | 8 | def __init__(self, *args, **kwargs): 9 | requires_backends(self, ["torch", "librosa"]) 10 | 11 | @classmethod 12 | def from_config(cls, *args, **kwargs): 13 | requires_backends(cls, ["torch", "librosa"]) 14 | 15 | @classmethod 16 | def from_pretrained(cls, *args, **kwargs): 17 | requires_backends(cls, ["torch", "librosa"]) 18 | 19 | 20 | class Mel(metaclass=DummyObject): 21 | _backends = ["torch", "librosa"] 22 | 23 | def __init__(self, *args, **kwargs): 24 | requires_backends(self, ["torch", "librosa"]) 25 | 26 | @classmethod 27 | def from_config(cls, *args, **kwargs): 28 | requires_backends(cls, ["torch", "librosa"]) 29 | 30 | @classmethod 31 | def from_pretrained(cls, *args, **kwargs): 32 | requires_backends(cls, ["torch", "librosa"]) 33 | -------------------------------------------------------------------------------- /diffusers_official/pipeline_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | 14 | # limitations under the License. 15 | 16 | # NOTE: This file is deprecated and will be removed in a future version. 
17 | # It only exists so that temporarely `from diffusers.pipelines import DiffusionPipeline` works 18 | 19 | from .pipelines import DiffusionPipeline, ImagePipelineOutput # noqa: F401 20 | from .utils import deprecate 21 | 22 | 23 | deprecate( 24 | "pipelines_utils", 25 | "0.22.0", 26 | "Importing `DiffusionPipeline` or `ImagePipelineOutput` from diffusers.pipeline_utils is deprecated. Please import from diffusers.pipelines.pipeline_utils instead.", 27 | standard_warn=False, 28 | stacklevel=3, 29 | ) 30 | -------------------------------------------------------------------------------- /diffusers_official/pipelines/stable_diffusion_xl/watermark.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from imwatermark import WatermarkEncoder 4 | 5 | 6 | # Copied from https://github.com/Stability-AI/generative-models/blob/613af104c6b85184091d42d374fef420eddb356d/scripts/demo/streamlit_helpers.py#L66 7 | WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110 8 | # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1 9 | WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] 10 | 11 | 12 | class StableDiffusionXLWatermarker: 13 | def __init__(self): 14 | self.watermark = WATERMARK_BITS 15 | self.encoder = WatermarkEncoder() 16 | 17 | self.encoder.set_watermark("bits", self.watermark) 18 | 19 | def apply_watermark(self, images: torch.FloatTensor): 20 | # can't encode images that are smaller than 256 21 | if images.shape[-1] < 256: 22 | return images 23 | 24 | images = (255 * (images / 2 + 0.5)).cpu().permute(0, 2, 3, 1).float().numpy() 25 | 26 | images = [self.encoder.encode(image, "dwtDct") for image in images] 27 | 28 | images = torch.from_numpy(np.array(images)).permute(0, 3, 1, 2) 29 | 30 | images = torch.clamp(2 * (images / 255 - 0.5), min=-1.0, max=1.0) 31 | return images 32 | -------------------------------------------------------------------------------- /diffusers_official/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 
2 | from ..utils import DummyObject, requires_backends 3 | 4 | 5 | class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject): 6 | _backends = ["torch", "transformers", "invisible_watermark"] 7 | 8 | def __init__(self, *args, **kwargs): 9 | requires_backends(self, ["torch", "transformers", "invisible_watermark"]) 10 | 11 | @classmethod 12 | def from_config(cls, *args, **kwargs): 13 | requires_backends(cls, ["torch", "transformers", "invisible_watermark"]) 14 | 15 | @classmethod 16 | def from_pretrained(cls, *args, **kwargs): 17 | requires_backends(cls, ["torch", "transformers", "invisible_watermark"]) 18 | 19 | 20 | class StableDiffusionXLPipeline(metaclass=DummyObject): 21 | _backends = ["torch", "transformers", "invisible_watermark"] 22 | 23 | def __init__(self, *args, **kwargs): 24 | requires_backends(self, ["torch", "transformers", "invisible_watermark"]) 25 | 26 | @classmethod 27 | def from_config(cls, *args, **kwargs): 28 | requires_backends(cls, ["torch", "transformers", "invisible_watermark"]) 29 | 30 | @classmethod 31 | def from_pretrained(cls, *args, **kwargs): 32 | requires_backends(cls, ["torch", "transformers", "invisible_watermark"]) 33 | -------------------------------------------------------------------------------- /diffusers_official/utils/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import os 15 | 16 | from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE, hf_cache_home 17 | 18 | 19 | default_cache_path = HUGGINGFACE_HUB_CACHE 20 | 21 | 22 | CONFIG_NAME = "config.json" 23 | WEIGHTS_NAME = "diffusion_pytorch_model.bin" 24 | FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack" 25 | ONNX_WEIGHTS_NAME = "model.onnx" 26 | SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors" 27 | ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb" 28 | HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co" 29 | DIFFUSERS_CACHE = default_cache_path 30 | DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" 31 | HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) 32 | DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"] 33 | TEXT_ENCODER_ATTN_MODULE = ".self_attn" 34 | -------------------------------------------------------------------------------- /diffusers_official/utils/doc_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Doc utilities: Utilities related to documentation 16 | """ 17 | import re 18 | 19 | 20 | def replace_example_docstring(example_docstring): 21 | def docstring_decorator(fn): 22 | func_doc = fn.__doc__ 23 | lines = func_doc.split("\n") 24 | i = 0 25 | while i < len(lines) and re.search(r"^\s*Examples?:\s*$", lines[i]) is None: 26 | i += 1 27 | if i < len(lines): 28 | lines[i] = example_docstring 29 | func_doc = "\n".join(lines) 30 | else: 31 | raise ValueError( 32 | f"The function {fn} should have an empty 'Examples:' in its docstring as placeholder, " 33 | f"current docstring is:\n{func_doc}" 34 | ) 35 | fn.__doc__ = func_doc 36 | return fn 37 | 38 | return docstring_decorator 39 | -------------------------------------------------------------------------------- /diffusers_official/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from ..utils import is_flax_available, is_torch_available 16 | 17 | 18 | if is_torch_available(): 19 | from .autoencoder_kl import AutoencoderKL 20 | from .controlnet import ControlNetModel 21 | from .dual_transformer_2d import DualTransformer2DModel 22 | from .modeling_utils import ModelMixin 23 | from .prior_transformer import PriorTransformer 24 | from .t5_film_transformer import T5FilmDecoder 25 | from .transformer_2d import Transformer2DModel 26 | from .unet_1d import UNet1DModel 27 | from .unet_2d import UNet2DModel 28 | from .unet_2d_condition import UNet2DConditionModel 29 | from .unet_3d_condition import UNet3DConditionModel 30 | from .vq_model import VQModel 31 | 32 | if is_flax_available(): 33 | from .controlnet_flax import FlaxControlNetModel 34 | from .unet_2d_condition_flax import FlaxUNet2DConditionModel 35 | from .vae_flax import FlaxAutoencoderKL 36 | -------------------------------------------------------------------------------- /diffusers_official/utils/pil_utils.py: -------------------------------------------------------------------------------- 1 | import PIL.Image 2 | import PIL.ImageOps 3 | from packaging import version 4 | from PIL import Image 5 | 6 | 7 | if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): 8 | PIL_INTERPOLATION = { 9 | "linear": PIL.Image.Resampling.BILINEAR, 10 | "bilinear": PIL.Image.Resampling.BILINEAR, 11 | "bicubic": PIL.Image.Resampling.BICUBIC, 12 | "lanczos": PIL.Image.Resampling.LANCZOS, 13 | "nearest": PIL.Image.Resampling.NEAREST, 14 | } 15 | else: 16 | PIL_INTERPOLATION = { 17 | "linear": PIL.Image.LINEAR, 18 | "bilinear": PIL.Image.BILINEAR, 19 | "bicubic": PIL.Image.BICUBIC, 20 | "lanczos": PIL.Image.LANCZOS, 21 | "nearest": PIL.Image.NEAREST, 22 | } 23 | 24 | 25 | def pt_to_pil(images): 26 | """ 27 | Convert a torch image to a PIL image. 28 | """ 29 | images = (images / 2 + 0.5).clamp(0, 1) 30 | images = images.cpu().permute(0, 2, 3, 1).float().numpy() 31 | images = numpy_to_pil(images) 32 | return images 33 | 34 | 35 | def numpy_to_pil(images): 36 | """ 37 | Convert a numpy image or a batch of images to a PIL image. 38 | """ 39 | if images.ndim == 3: 40 | images = images[None, ...] 41 | images = (images * 255).round().astype("uint8") 42 | if images.shape[-1] == 1: 43 | # special case for grayscale (single channel) images 44 | pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] 45 | else: 46 | pil_images = [Image.fromarray(image) for image in images] 47 | 48 | return pil_images 49 | -------------------------------------------------------------------------------- /diffusers_official/dependency_versions_table.py: -------------------------------------------------------------------------------- 1 | # THIS FILE HAS BEEN AUTOGENERATED. To update: 2 | # 1. modify the `_deps` dict in setup.py 3 | # 2. 
run `make deps_table_update`
4 | deps = {
5 |     "Pillow": "Pillow",
6 |     "accelerate": "accelerate>=0.11.0",
7 |     "compel": "compel==0.1.8",
8 |     "black": "black~=23.1",
9 |     "datasets": "datasets",
10 |     "filelock": "filelock",
11 |     "flax": "flax>=0.4.1",
12 |     "hf-doc-builder": "hf-doc-builder>=0.3.0",
13 |     "huggingface-hub": "huggingface-hub>=0.13.2",
14 |     "requests-mock": "requests-mock==1.10.0",
15 |     "importlib_metadata": "importlib_metadata",
16 |     "invisible-watermark": "invisible-watermark",
17 |     "isort": "isort>=5.5.4",
18 |     "jax": "jax>=0.2.8,!=0.3.2",
19 |     "jaxlib": "jaxlib>=0.1.65",
20 |     "Jinja2": "Jinja2",
21 |     "k-diffusion": "k-diffusion>=0.0.12",
22 |     "torchsde": "torchsde",
23 |     "note_seq": "note_seq",
24 |     "librosa": "librosa",
25 |     "numpy": "numpy",
26 |     "omegaconf": "omegaconf",
27 |     "parameterized": "parameterized",
28 |     "protobuf": "protobuf>=3.20.3,<4",
29 |     "pytest": "pytest",
30 |     "pytest-timeout": "pytest-timeout",
31 |     "pytest-xdist": "pytest-xdist",
32 |     "ruff": "ruff>=0.0.241",
33 |     "safetensors": "safetensors",
34 |     "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
35 |     "scipy": "scipy",
36 |     "onnx": "onnx",
37 |     "regex": "regex!=2019.12.17",
38 |     "requests": "requests",
39 |     "tensorboard": "tensorboard",
40 |     "torch": "torch>=1.4",
41 |     "torchvision": "torchvision",
42 |     "transformers": "transformers>=4.25.1",
43 |     "urllib3": "urllib3<=2.0.0",
44 | }
--------------------------------------------------------------------------------
/diffusers_official/utils/model_card_template.md:
--------------------------------------------------------------------------------
1 | ---
2 | {{ card_data }}
3 | ---
4 | 
5 | <!-- This model card has been generated automatically according to the information the training script had access to. You
6 | should probably proofread and complete it, then remove this comment. -->
7 | 
8 | # {{ model_name | default("Diffusion Model") }}
9 | 
10 | ## Model description
11 | 
12 | This diffusion model is trained with the [🤗 Diffusers](https://github.com/huggingface/diffusers) library
13 | on the `{{ dataset_name }}` dataset.
14 | 
15 | ## Intended uses & limitations
16 | 
17 | #### How to use
18 | 
19 | ```python
20 | # TODO: add an example code snippet for running this diffusion pipeline
21 | ```
22 | 
23 | #### Limitations and bias
24 | 
25 | [TODO: provide examples of latent issues and potential remediations]
26 | 
27 | ## Training data
28 | 
29 | [TODO: describe the data used to train the model]
30 | 
31 | ### Training hyperparameters
32 | 
33 | The following hyperparameters were used during training:
34 | - learning_rate: {{ learning_rate }}
35 | - train_batch_size: {{ train_batch_size }}
36 | - eval_batch_size: {{ eval_batch_size }}
37 | - gradient_accumulation_steps: {{ gradient_accumulation_steps }}
38 | - optimizer: AdamW with betas=({{ adam_beta1 }}, {{ adam_beta2 }}), weight_decay={{ adam_weight_decay }} and epsilon={{ adam_epsilon }}
39 | - lr_scheduler: {{ lr_scheduler }}
40 | - lr_warmup_steps: {{ lr_warmup_steps }}
41 | - ema_inv_gamma: {{ ema_inv_gamma }}
42 | - ema_power: {{ ema_power }}
43 | - ema_max_decay: {{ ema_max_decay }}
44 | - mixed_precision: {{ mixed_precision }}
45 | 
46 | ### Training results
47 | 
48 | 📈 [TensorBoard logs](https://huggingface.co/{{ repo_name }}/tensorboard?#scalars)
49 | 
--------------------------------------------------------------------------------
/diffusers_official/dependency_versions_check.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import sys 15 | 16 | from .dependency_versions_table import deps 17 | from .utils.versions import require_version, require_version_core 18 | 19 | 20 | # define which module versions we always want to check at run time 21 | # (usually the ones defined in `install_requires` in setup.py) 22 | # 23 | # order specific notes: 24 | # - tqdm must be checked before tokenizers 25 | 26 | pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split() 27 | if sys.version_info < (3, 7): 28 | pkgs_to_check_at_runtime.append("dataclasses") 29 | if sys.version_info < (3, 8): 30 | pkgs_to_check_at_runtime.append("importlib_metadata") 31 | 32 | for pkg in pkgs_to_check_at_runtime: 33 | if pkg in deps: 34 | if pkg == "tokenizers": 35 | # must be loaded here, or else tqdm check may fail 36 | from .utils import is_tokenizers_available 37 | 38 | if not is_tokenizers_available(): 39 | continue # not required, check version only if installed 40 | 41 | require_version_core(deps[pkg]) 42 | else: 43 | raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py") 44 | 45 | 46 | def dep_version_check(pkg, hint=None): 47 | require_version(deps[pkg], hint) 48 | -------------------------------------------------------------------------------- /diffusers_official/utils/accelerate_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Accelerate utilities: Utilities related to accelerate 16 | """ 17 | 18 | from packaging import version 19 | 20 | from .import_utils import is_accelerate_available 21 | 22 | 23 | if is_accelerate_available(): 24 | import accelerate 25 | 26 | 27 | def apply_forward_hook(method): 28 | """ 29 | Decorator that applies a registered CpuOffload hook to an arbitrary function rather than `forward`. This is useful 30 | for cases where a PyTorch module provides functions other than `forward` that should trigger a move to the 31 | appropriate acceleration device. This is the case for `encode` and `decode` in [`AutoencoderKL`]. 32 | 33 | This decorator looks inside the internal `_hf_hook` property to find a registered offload hook. 34 | 35 | :param method: The method to decorate. This method should be a method of a PyTorch module. 
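    :return: The wrapped method when a compatible `accelerate` (>= 0.17.0) is available,
        otherwise the original method unchanged.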
36 | """ 37 | if not is_accelerate_available(): 38 | return method 39 | accelerate_version = version.parse(accelerate.__version__).base_version 40 | if version.parse(accelerate_version) < version.parse("0.17.0"): 41 | return method 42 | 43 | def wrapper(self, *args, **kwargs): 44 | if hasattr(self, "_hf_hook") and hasattr(self._hf_hook, "pre_forward"): 45 | self._hf_hook.pre_forward(self) 46 | return method(self, *args, **kwargs) 47 | 48 | return wrapper 49 | -------------------------------------------------------------------------------- /diffusers_official/utils/dummy_flax_and_transformers_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 2 | from ..utils import DummyObject, requires_backends 3 | 4 | 5 | class FlaxStableDiffusionControlNetPipeline(metaclass=DummyObject): 6 | _backends = ["flax", "transformers"] 7 | 8 | def __init__(self, *args, **kwargs): 9 | requires_backends(self, ["flax", "transformers"]) 10 | 11 | @classmethod 12 | def from_config(cls, *args, **kwargs): 13 | requires_backends(cls, ["flax", "transformers"]) 14 | 15 | @classmethod 16 | def from_pretrained(cls, *args, **kwargs): 17 | requires_backends(cls, ["flax", "transformers"]) 18 | 19 | 20 | class FlaxStableDiffusionImg2ImgPipeline(metaclass=DummyObject): 21 | _backends = ["flax", "transformers"] 22 | 23 | def __init__(self, *args, **kwargs): 24 | requires_backends(self, ["flax", "transformers"]) 25 | 26 | @classmethod 27 | def from_config(cls, *args, **kwargs): 28 | requires_backends(cls, ["flax", "transformers"]) 29 | 30 | @classmethod 31 | def from_pretrained(cls, *args, **kwargs): 32 | requires_backends(cls, ["flax", "transformers"]) 33 | 34 | 35 | class FlaxStableDiffusionInpaintPipeline(metaclass=DummyObject): 36 | _backends = ["flax", "transformers"] 37 | 38 | def __init__(self, *args, **kwargs): 39 | requires_backends(self, ["flax", "transformers"]) 40 | 41 | @classmethod 42 | def from_config(cls, *args, **kwargs): 43 | requires_backends(cls, ["flax", "transformers"]) 44 | 45 | @classmethod 46 | def from_pretrained(cls, *args, **kwargs): 47 | requires_backends(cls, ["flax", "transformers"]) 48 | 49 | 50 | class FlaxStableDiffusionPipeline(metaclass=DummyObject): 51 | _backends = ["flax", "transformers"] 52 | 53 | def __init__(self, *args, **kwargs): 54 | requires_backends(self, ["flax", "transformers"]) 55 | 56 | @classmethod 57 | def from_config(cls, *args, **kwargs): 58 | requires_backends(cls, ["flax", "transformers"]) 59 | 60 | @classmethod 61 | def from_pretrained(cls, *args, **kwargs): 62 | requires_backends(cls, ["flax", "transformers"]) 63 | -------------------------------------------------------------------------------- /diffusers_official/utils/deprecation_utils.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import warnings 3 | from typing import Any, Dict, Optional, Union 4 | 5 | from packaging import version 6 | 7 | 8 | def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True, stacklevel=2): 9 | from .. 
import __version__
10 | 
11 |     deprecated_kwargs = take_from
12 |     values = ()
13 |     if not isinstance(args[0], tuple):
14 |         args = (args,)
15 | 
16 |     for attribute, version_name, message in args:
17 |         if version.parse(version.parse(__version__).base_version) >= version.parse(version_name):
18 |             raise ValueError(
19 |                 f"The deprecation tuple {(attribute, version_name, message)} should be removed since diffusers'"
20 |                 f" version {__version__} is >= {version_name}"
21 |             )
22 | 
23 |         warning = None
24 |         if isinstance(deprecated_kwargs, dict) and attribute in deprecated_kwargs:
25 |             values += (deprecated_kwargs.pop(attribute),)
26 |             warning = f"The `{attribute}` argument is deprecated and will be removed in version {version_name}."
27 |         elif hasattr(deprecated_kwargs, attribute):
28 |             values += (getattr(deprecated_kwargs, attribute),)
29 |             warning = f"The `{attribute}` attribute is deprecated and will be removed in version {version_name}."
30 |         elif deprecated_kwargs is None:
31 |             warning = f"`{attribute}` is deprecated and will be removed in version {version_name}."
32 | 
33 |         if warning is not None:
34 |             warning = warning + " " if standard_warn else ""
35 |             warnings.warn(warning + message, FutureWarning, stacklevel=stacklevel)
36 | 
37 |     if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0:
38 |         call_frame = inspect.getouterframes(inspect.currentframe())[1]
39 |         filename = call_frame.filename
40 |         line_number = call_frame.lineno
41 |         function = call_frame.function
42 |         key, value = next(iter(deprecated_kwargs.items()))
43 |         raise TypeError(f"{function} in {filename} line {line_number-1} got an unexpected keyword argument `{key}`")
44 | 
45 |     if len(values) == 0:
46 |         return
47 |     elif len(values) == 1:
48 |         return values[0]
49 |     return values
50 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Rich-Text-to-Image for Stable Diffusion WebUI
2 | #### [Project Page](https://rich-text-to-image.github.io/) | [Paper](https://arxiv.org/abs/2304.06720) | [Code](https://github.com/songweige/rich-text-to-image) | [HuggingFace Demo](https://huggingface.co/spaces/songweig/rich-text-to-image) | [Video](https://youtu.be/ihDbAUh0LXk)
3 | 
4 | A WebUI extension that integrates a rich-text editor for text-to-image generation.
5 | 
6 | ![image](https://github.com/songweige/sd-webui-rich-text/assets/22885450/c57cf981-8332-41fb-8f47-b03238311ca4)
7 | 
8 | This extension is for [AUTOMATIC1111's Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui); it allows the Web UI to apply [rich-text-to-image](https://rich-text-to-image.github.io/) editing on top of the original Stable Diffusion model when generating images.
9 | 
10 | ## Installation
11 | 
12 | 1. Open the "Extensions" tab.
13 | 1. Open the "Install from URL" tab within it.
14 | 1. Enter the URL of this repo (https://github.com/songweige/sd-webui-rich-text) into "URL for extension's git repository".
15 | 1. Press the "Install" button.
16 | 1. Restart the Web UI.
17 | 
18 | ## Usage
19 | 
20 | The extension currently supports [SD-v1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5) (default), [SD-XL-v1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0), and [ANIMAGINE-XL](https://huggingface.co/Linaqruf/animagine-xl). The checkpoints are downloaded automatically the first time a model is selected.
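
Under the hood, the rich-text editor serializes the formatted prompt into a JSON string of text spans plus per-span attributes, which the model code then interprets. As a rough sketch of the idea (modeled on the Quill delta format the editor is built on; the exact attribute names and values here are illustrative, not the extension's precise schema):

```python
# Illustrative sketch only: a rich-text prompt in which "church" carries a
# font-color attribute. Font style, font size, and footnotes travel the same
# way, as additional per-span attributes.
rich_text_prompt = {
    "ops": [
        {"insert": "a Gothic "},
        {"insert": "church", "attributes": {"color": "#b26b00"}},
        {"insert": " in a sunset with a beautiful landscape in the background."},
    ]
}
```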
21 | 22 | 23 | #### Font Color 24 | 25 | ![color](assets/color.png) 26 | 27 | Font color is used to control the precise color of the generated objects. 28 | 29 | #### Footnote 30 | 31 | ![footnote](assets/footnote.png) 32 | 33 | Footnotes provide supplementary descriptions for selected text elements. 34 | 35 | #### Font Style 36 | 37 | ![style](assets/font.png) 38 | 39 | Just as the font style distinguishes the styles of individual text elements, it is used to define the artistic style of specific areas in the generation. 40 | 41 | #### Font Size 42 | 43 | ![size](assets/size.png) 44 | 45 | Font size indicates the weight of each token in the final generation. 46 | 47 | ## Acknowledgement 48 | 49 | The extension is built on the [extension-templates](https://github.com/udon-universe/stable-diffusion-webui-extension-templates). The rich-text editor is built on [Quill](https://quilljs.com/). The model code is built on [huggingface / diffusers](https://github.com/huggingface/diffusers#readme). 50 | -------------------------------------------------------------------------------- /diffusers_official/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from ..utils import ( 2 | OptionalDependencyNotAvailable, 3 | is_flax_available, 4 | is_invisible_watermark_available, 5 | is_k_diffusion_available, 6 | is_librosa_available, 7 | is_note_seq_available, 8 | is_onnx_available, 9 | is_torch_available, 10 | is_transformers_available, 11 | ) 12 | 13 | 14 | try: 15 | if not is_torch_available(): 16 | raise OptionalDependencyNotAvailable() 17 | except OptionalDependencyNotAvailable: 18 | from ..utils.dummy_pt_objects import * # noqa F403 19 | else: 20 | from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput 21 | 22 | try: 23 | if not (is_torch_available() and is_librosa_available()): 24 | raise OptionalDependencyNotAvailable() 25 | except OptionalDependencyNotAvailable: 26 | from ..utils.dummy_torch_and_librosa_objects import * # noqa F403 27 | 28 | try: 29 | if not (is_torch_available() and is_transformers_available()): 30 | raise OptionalDependencyNotAvailable() 31 | except OptionalDependencyNotAvailable: 32 | from ..utils.dummy_torch_and_transformers_objects import * # noqa F403 33 | 34 | 35 | try: 36 | if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()): 37 | raise OptionalDependencyNotAvailable() 38 | except OptionalDependencyNotAvailable: 39 | from ..utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403 40 | else: 41 | from .stable_diffusion_xl import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline 42 | 43 | try: 44 | if not is_onnx_available(): 45 | raise OptionalDependencyNotAvailable() 46 | except OptionalDependencyNotAvailable: 47 | from ..utils.dummy_onnx_objects import * # noqa F403 48 | else: 49 | from .onnx_utils import OnnxRuntimeModel 50 | 51 | try: 52 | if not (is_torch_available() and is_transformers_available() and is_onnx_available()): 53 | raise OptionalDependencyNotAvailable() 54 | except OptionalDependencyNotAvailable: 55 | from ..utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403 56 | 57 | 58 | try: 59 | if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): 60 | raise OptionalDependencyNotAvailable() 61 | except OptionalDependencyNotAvailable: 62 | from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403 63 | 64 
| try: 65 | if not is_flax_available(): 66 | raise OptionalDependencyNotAvailable() 67 | except OptionalDependencyNotAvailable: 68 | from ..utils.dummy_flax_objects import * # noqa F403 69 | else: 70 | from .pipeline_flax_utils import FlaxDiffusionPipeline 71 | 72 | 73 | try: 74 | if not (is_flax_available() and is_transformers_available()): 75 | raise OptionalDependencyNotAvailable() 76 | except OptionalDependencyNotAvailable: 77 | from ..utils.dummy_flax_and_transformers_objects import * # noqa F403 78 | 79 | try: 80 | if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): 81 | raise OptionalDependencyNotAvailable() 82 | except OptionalDependencyNotAvailable: 83 | from ..utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403 84 | -------------------------------------------------------------------------------- /diffusers_official/utils/dummy_torch_and_transformers_and_onnx_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 2 | from ..utils import DummyObject, requires_backends 3 | 4 | 5 | class OnnxStableDiffusionImg2ImgPipeline(metaclass=DummyObject): 6 | _backends = ["torch", "transformers", "onnx"] 7 | 8 | def __init__(self, *args, **kwargs): 9 | requires_backends(self, ["torch", "transformers", "onnx"]) 10 | 11 | @classmethod 12 | def from_config(cls, *args, **kwargs): 13 | requires_backends(cls, ["torch", "transformers", "onnx"]) 14 | 15 | @classmethod 16 | def from_pretrained(cls, *args, **kwargs): 17 | requires_backends(cls, ["torch", "transformers", "onnx"]) 18 | 19 | 20 | class OnnxStableDiffusionInpaintPipeline(metaclass=DummyObject): 21 | _backends = ["torch", "transformers", "onnx"] 22 | 23 | def __init__(self, *args, **kwargs): 24 | requires_backends(self, ["torch", "transformers", "onnx"]) 25 | 26 | @classmethod 27 | def from_config(cls, *args, **kwargs): 28 | requires_backends(cls, ["torch", "transformers", "onnx"]) 29 | 30 | @classmethod 31 | def from_pretrained(cls, *args, **kwargs): 32 | requires_backends(cls, ["torch", "transformers", "onnx"]) 33 | 34 | 35 | class OnnxStableDiffusionInpaintPipelineLegacy(metaclass=DummyObject): 36 | _backends = ["torch", "transformers", "onnx"] 37 | 38 | def __init__(self, *args, **kwargs): 39 | requires_backends(self, ["torch", "transformers", "onnx"]) 40 | 41 | @classmethod 42 | def from_config(cls, *args, **kwargs): 43 | requires_backends(cls, ["torch", "transformers", "onnx"]) 44 | 45 | @classmethod 46 | def from_pretrained(cls, *args, **kwargs): 47 | requires_backends(cls, ["torch", "transformers", "onnx"]) 48 | 49 | 50 | class OnnxStableDiffusionPipeline(metaclass=DummyObject): 51 | _backends = ["torch", "transformers", "onnx"] 52 | 53 | def __init__(self, *args, **kwargs): 54 | requires_backends(self, ["torch", "transformers", "onnx"]) 55 | 56 | @classmethod 57 | def from_config(cls, *args, **kwargs): 58 | requires_backends(cls, ["torch", "transformers", "onnx"]) 59 | 60 | @classmethod 61 | def from_pretrained(cls, *args, **kwargs): 62 | requires_backends(cls, ["torch", "transformers", "onnx"]) 63 | 64 | 65 | class OnnxStableDiffusionUpscalePipeline(metaclass=DummyObject): 66 | _backends = ["torch", "transformers", "onnx"] 67 | 68 | def __init__(self, *args, **kwargs): 69 | requires_backends(self, ["torch", "transformers", "onnx"]) 70 | 71 | @classmethod 72 | def from_config(cls, *args, **kwargs): 73 | requires_backends(cls, ["torch", "transformers", 
"onnx"]) 74 | 75 | @classmethod 76 | def from_pretrained(cls, *args, **kwargs): 77 | requires_backends(cls, ["torch", "transformers", "onnx"]) 78 | 79 | 80 | class StableDiffusionOnnxPipeline(metaclass=DummyObject): 81 | _backends = ["torch", "transformers", "onnx"] 82 | 83 | def __init__(self, *args, **kwargs): 84 | requires_backends(self, ["torch", "transformers", "onnx"]) 85 | 86 | @classmethod 87 | def from_config(cls, *args, **kwargs): 88 | requires_backends(cls, ["torch", "transformers", "onnx"]) 89 | 90 | @classmethod 91 | def from_pretrained(cls, *args, **kwargs): 92 | requires_backends(cls, ["torch", "transformers", "onnx"]) 93 | -------------------------------------------------------------------------------- /diffusers_official/utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | PyTorch utilities: Utilities related to PyTorch 16 | """ 17 | from typing import List, Optional, Tuple, Union 18 | 19 | from . import logging 20 | from .import_utils import is_torch_available, is_torch_version 21 | 22 | 23 | if is_torch_available(): 24 | import torch 25 | 26 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 27 | 28 | try: 29 | from torch._dynamo import allow_in_graph as maybe_allow_in_graph 30 | except (ImportError, ModuleNotFoundError): 31 | 32 | def maybe_allow_in_graph(cls): 33 | return cls 34 | 35 | 36 | def randn_tensor( 37 | shape: Union[Tuple, List], 38 | generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, 39 | device: Optional["torch.device"] = None, 40 | dtype: Optional["torch.dtype"] = None, 41 | layout: Optional["torch.layout"] = None, 42 | ): 43 | """A helper function to create random tensors on the desired `device` with the desired `dtype`. When 44 | passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor 45 | is always created on the CPU. 46 | """ 47 | # device on which tensor is created defaults to device 48 | rand_device = device 49 | batch_size = shape[0] 50 | 51 | layout = layout or torch.strided 52 | device = device or torch.device("cpu") 53 | 54 | if generator is not None: 55 | gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type 56 | if gen_device_type != device.type and gen_device_type == "cpu": 57 | rand_device = "cpu" 58 | if device != "mps": 59 | logger.info( 60 | f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." 61 | f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably" 62 | f" slighly speed up this function by passing a generator that was created on the {device} device." 
63 | ) 64 | elif gen_device_type != device.type and gen_device_type == "cuda": 65 | raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.") 66 | 67 | if isinstance(generator, list): 68 | shape = (1,) + shape[1:] 69 | latents = [ 70 | torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout) 71 | for i in range(batch_size) 72 | ] 73 | latents = torch.cat(latents, dim=0).to(device) 74 | else: 75 | latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) 76 | 77 | return latents 78 | 79 | 80 | def is_compiled_module(module): 81 | """Check whether the module was compiled with torch.compile()""" 82 | if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"): 83 | return False 84 | return isinstance(module, torch._dynamo.eval_frame.OptimizedModule) 85 | -------------------------------------------------------------------------------- /diffusers_official/schedulers/scheduling_sde_vp.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google Brain and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch 16 | 17 | import math 18 | from typing import Union 19 | 20 | import torch 21 | 22 | from ..configuration_utils import ConfigMixin, register_to_config 23 | from ..utils import randn_tensor 24 | from .scheduling_utils import SchedulerMixin 25 | 26 | 27 | class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): 28 | """ 29 | The variance preserving stochastic differential equation (SDE) scheduler. 30 | 31 | [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` 32 | function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. 33 | [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and 34 | [`~SchedulerMixin.from_pretrained`] functions. 
35 | 36 | For more information, see the original paper: https://arxiv.org/abs/2011.13456 37 | 38 | UNDER CONSTRUCTION 39 | 40 | """ 41 | 42 | order = 1 43 | 44 | @register_to_config 45 | def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3): 46 | self.sigmas = None 47 | self.discrete_sigmas = None 48 | self.timesteps = None 49 | 50 | def set_timesteps(self, num_inference_steps, device: Union[str, torch.device] = None): 51 | self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps, device=device) 52 | 53 | def step_pred(self, score, x, t, generator=None): 54 | if self.timesteps is None: 55 | raise ValueError( 56 | "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler" 57 | ) 58 | 59 | # TODO(Patrick) better comments + non-PyTorch 60 | # postprocess model score 61 | log_mean_coeff = ( 62 | -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min 63 | ) 64 | std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff)) 65 | std = std.flatten() 66 | while len(std.shape) < len(score.shape): 67 | std = std.unsqueeze(-1) 68 | score = -score / std 69 | 70 | # compute 71 | dt = -1.0 / len(self.timesteps) 72 | 73 | beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min) 74 | beta_t = beta_t.flatten() 75 | while len(beta_t.shape) < len(x.shape): 76 | beta_t = beta_t.unsqueeze(-1) 77 | drift = -0.5 * beta_t * x 78 | 79 | diffusion = torch.sqrt(beta_t) 80 | drift = drift - diffusion**2 * score 81 | x_mean = x + drift * dt 82 | 83 | # add noise 84 | noise = randn_tensor(x.shape, layout=x.layout, generator=generator, device=x.device, dtype=x.dtype) 85 | x = x_mean + diffusion * math.sqrt(-dt) * noise 86 | 87 | return x, x_mean 88 | 89 | def __len__(self): 90 | return self.config.num_train_timesteps 91 | -------------------------------------------------------------------------------- /diffusers_official/models/embeddings_flax.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import math 15 | 16 | import flax.linen as nn 17 | import jax.numpy as jnp 18 | 19 | 20 | def get_sinusoidal_embeddings( 21 | timesteps: jnp.ndarray, 22 | embedding_dim: int, 23 | freq_shift: float = 1, 24 | min_timescale: float = 1, 25 | max_timescale: float = 1.0e4, 26 | flip_sin_to_cos: bool = False, 27 | scale: float = 1.0, 28 | ) -> jnp.ndarray: 29 | """Returns the positional encoding (same as Tensor2Tensor). 30 | 31 | Args: 32 | timesteps: a 1-D Tensor of N indices, one per batch element. 33 | These may be fractional. 34 | embedding_dim: The number of output channels. 35 | min_timescale: The smallest time unit (should probably be 0.0). 36 | max_timescale: The largest time unit. 
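        freq_shift: An offset subtracted from the number of timescales when
            computing the frequency progression.
        flip_sin_to_cos: Whether to concatenate cosine components before sine
            ones (True) instead of sine before cosine (False).
        scale: A multiplier applied to the sinusoid arguments before taking
            sin/cos.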
37 | Returns: 38 | a Tensor of timing signals [N, num_channels] 39 | """ 40 | assert timesteps.ndim == 1, "Timesteps should be a 1d-array" 41 | assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even" 42 | num_timescales = float(embedding_dim // 2) 43 | log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift) 44 | inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment) 45 | emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0) 46 | 47 | # scale embeddings 48 | scaled_time = scale * emb 49 | 50 | if flip_sin_to_cos: 51 | signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1) 52 | else: 53 | signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1) 54 | signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim]) 55 | return signal 56 | 57 | 58 | class FlaxTimestepEmbedding(nn.Module): 59 | r""" 60 | Time step Embedding Module. Learns embeddings for input time steps. 61 | 62 | Args: 63 | time_embed_dim (`int`, *optional*, defaults to `32`): 64 | Time step embedding dimension 65 | dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): 66 | Parameters `dtype` 67 | """ 68 | time_embed_dim: int = 32 69 | dtype: jnp.dtype = jnp.float32 70 | 71 | @nn.compact 72 | def __call__(self, temb): 73 | temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb) 74 | temb = nn.silu(temb) 75 | temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb) 76 | return temb 77 | 78 | 79 | class FlaxTimesteps(nn.Module): 80 | r""" 81 | Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239 82 | 83 | Args: 84 | dim (`int`, *optional*, defaults to `32`): 85 | Time step embedding dimension 86 | """ 87 | dim: int = 32 88 | flip_sin_to_cos: bool = False 89 | freq_shift: float = 1 90 | 91 | @nn.compact 92 | def __call__(self, timesteps): 93 | return get_sinusoidal_embeddings( 94 | timesteps, embedding_dim=self.dim, flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift 95 | ) 96 | -------------------------------------------------------------------------------- /diffusers_official/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import os 17 | 18 | from packaging import version 19 | 20 | from .. 
import __version__ 21 | from .accelerate_utils import apply_forward_hook 22 | from .constants import ( 23 | CONFIG_NAME, 24 | DEPRECATED_REVISION_ARGS, 25 | DIFFUSERS_CACHE, 26 | DIFFUSERS_DYNAMIC_MODULE_NAME, 27 | FLAX_WEIGHTS_NAME, 28 | HF_MODULES_CACHE, 29 | HUGGINGFACE_CO_RESOLVE_ENDPOINT, 30 | ONNX_EXTERNAL_WEIGHTS_NAME, 31 | ONNX_WEIGHTS_NAME, 32 | SAFETENSORS_WEIGHTS_NAME, 33 | TEXT_ENCODER_ATTN_MODULE, 34 | WEIGHTS_NAME, 35 | ) 36 | from .deprecation_utils import deprecate 37 | from .doc_utils import replace_example_docstring 38 | from .dynamic_modules_utils import get_class_from_dynamic_module 39 | from .hub_utils import ( 40 | HF_HUB_OFFLINE, 41 | _add_variant, 42 | _get_model_file, 43 | extract_commit_hash, 44 | http_user_agent, 45 | ) 46 | from .import_utils import ( 47 | BACKENDS_MAPPING, 48 | ENV_VARS_TRUE_AND_AUTO_VALUES, 49 | ENV_VARS_TRUE_VALUES, 50 | USE_JAX, 51 | USE_TF, 52 | USE_TORCH, 53 | DummyObject, 54 | OptionalDependencyNotAvailable, 55 | is_accelerate_available, 56 | is_accelerate_version, 57 | is_bs4_available, 58 | is_flax_available, 59 | is_ftfy_available, 60 | is_inflect_available, 61 | is_invisible_watermark_available, 62 | is_k_diffusion_available, 63 | is_k_diffusion_version, 64 | is_librosa_available, 65 | is_note_seq_available, 66 | is_omegaconf_available, 67 | is_onnx_available, 68 | is_safetensors_available, 69 | is_scipy_available, 70 | is_tensorboard_available, 71 | is_tf_available, 72 | is_torch_available, 73 | is_torch_version, 74 | is_torchsde_available, 75 | is_transformers_available, 76 | is_transformers_version, 77 | is_unidecode_available, 78 | is_wandb_available, 79 | is_xformers_available, 80 | requires_backends, 81 | ) 82 | from .logging import get_logger 83 | from .outputs import BaseOutput 84 | from .pil_utils import PIL_INTERPOLATION, numpy_to_pil, pt_to_pil 85 | from .torch_utils import is_compiled_module, randn_tensor 86 | 87 | 88 | if is_torch_available(): 89 | from .testing_utils import ( 90 | floats_tensor, 91 | load_hf_numpy, 92 | load_image, 93 | load_numpy, 94 | load_pt, 95 | nightly, 96 | parse_flag_from_env, 97 | print_tensor_test, 98 | require_torch_2, 99 | require_torch_gpu, 100 | skip_mps, 101 | slow, 102 | torch_all_close, 103 | torch_device, 104 | ) 105 | from .torch_utils import maybe_allow_in_graph 106 | 107 | from .testing_utils import export_to_gif, export_to_video 108 | 109 | 110 | logger = get_logger(__name__) 111 | 112 | 113 | def check_min_version(min_version): 114 | if version.parse(__version__) < version.parse(min_version): 115 | if "dev" in min_version: 116 | error_message = ( 117 | "This example requires a source install from HuggingFace diffusers (see " 118 | "`https://huggingface.co/docs/diffusers/installation#install-from-source`)," 119 | ) 120 | else: 121 | error_message = f"This example requires a minimum version of {min_version}," 122 | error_message += f" but the version found is {__version__}.\n" 123 | raise ImportError(error_message) 124 | -------------------------------------------------------------------------------- /diffusers_official/utils/outputs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Generic utilities 16 | """ 17 | 18 | from collections import OrderedDict 19 | from dataclasses import fields 20 | from typing import Any, Tuple 21 | 22 | import numpy as np 23 | 24 | from .import_utils import is_torch_available 25 | 26 | 27 | def is_tensor(x): 28 | """ 29 | Tests if `x` is a `torch.Tensor` or `np.ndarray`. 30 | """ 31 | if is_torch_available(): 32 | import torch 33 | 34 | if isinstance(x, torch.Tensor): 35 | return True 36 | 37 | return isinstance(x, np.ndarray) 38 | 39 | 40 | class BaseOutput(OrderedDict): 41 | """ 42 | Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a 43 | tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular 44 | Python dictionary. 45 | 46 | 47 | 48 | You can't unpack a [`BaseOutput`] directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple 49 | first. 50 | 51 | 52 | """ 53 | 54 | def __post_init__(self): 55 | class_fields = fields(self) 56 | 57 | # Safety and consistency checks 58 | if not len(class_fields): 59 | raise ValueError(f"{self.__class__.__name__} has no fields.") 60 | 61 | first_field = getattr(self, class_fields[0].name) 62 | other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:]) 63 | 64 | if other_fields_are_none and isinstance(first_field, dict): 65 | for key, value in first_field.items(): 66 | self[key] = value 67 | else: 68 | for field in class_fields: 69 | v = getattr(self, field.name) 70 | if v is not None: 71 | self[field.name] = v 72 | 73 | def __delitem__(self, *args, **kwargs): 74 | raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") 75 | 76 | def setdefault(self, *args, **kwargs): 77 | raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") 78 | 79 | def pop(self, *args, **kwargs): 80 | raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") 81 | 82 | def update(self, *args, **kwargs): 83 | raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") 84 | 85 | def __getitem__(self, k): 86 | if isinstance(k, str): 87 | inner_dict = dict(self.items()) 88 | return inner_dict[k] 89 | else: 90 | return self.to_tuple()[k] 91 | 92 | def __setattr__(self, name, value): 93 | if name in self.keys() and value is not None: 94 | # Don't call self.__setitem__ to avoid recursion errors 95 | super().__setitem__(name, value) 96 | super().__setattr__(name, value) 97 | 98 | def __setitem__(self, key, value): 99 | # Will raise a KeyError if needed 100 | super().__setitem__(key, value) 101 | # Don't call self.__setattr__ to avoid recursion errors 102 | super().__setattr__(key, value) 103 | 104 | def to_tuple(self) -> Tuple[Any]: 105 | """ 106 | Convert self to a tuple containing all the attributes/keys that are not `None`.
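Example (illustrative sketch):

>>> from dataclasses import dataclass
>>> @dataclass
... class MyOutput(BaseOutput):
...     sample: np.ndarray = None
>>> out = MyOutput(sample=np.zeros(2))
>>> out.to_tuple()[0] is out["sample"]
True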
107 | """ 108 | return tuple(self[k] for k in self.keys()) 109 | -------------------------------------------------------------------------------- /diffusers_official/schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from ..utils import ( 17 | OptionalDependencyNotAvailable, 18 | is_flax_available, 19 | is_scipy_available, 20 | is_torch_available, 21 | is_torchsde_available, 22 | ) 23 | 24 | 25 | try: 26 | if not is_torch_available(): 27 | raise OptionalDependencyNotAvailable() 28 | except OptionalDependencyNotAvailable: 29 | from ..utils.dummy_pt_objects import * # noqa F403 30 | else: 31 | from .scheduling_consistency_models import CMStochasticIterativeScheduler 32 | from .scheduling_ddim import DDIMScheduler 33 | from .scheduling_ddim_inverse import DDIMInverseScheduler 34 | from .scheduling_ddim_parallel import DDIMParallelScheduler 35 | from .scheduling_ddpm import DDPMScheduler 36 | from .scheduling_ddpm_parallel import DDPMParallelScheduler 37 | from .scheduling_deis_multistep import DEISMultistepScheduler 38 | from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler 39 | from .scheduling_dpmsolver_multistep_inverse import DPMSolverMultistepInverseScheduler 40 | from .scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler 41 | from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler 42 | from .scheduling_euler_discrete import EulerDiscreteScheduler 43 | from .scheduling_heun_discrete import HeunDiscreteScheduler 44 | from .scheduling_ipndm import IPNDMScheduler 45 | from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler 46 | from .scheduling_k_dpm_2_discrete import KDPM2DiscreteScheduler 47 | from .scheduling_karras_ve import KarrasVeScheduler 48 | from .scheduling_pndm import PNDMScheduler 49 | from .scheduling_repaint import RePaintScheduler 50 | from .scheduling_sde_ve import ScoreSdeVeScheduler 51 | from .scheduling_sde_vp import ScoreSdeVpScheduler 52 | from .scheduling_unclip import UnCLIPScheduler 53 | from .scheduling_unipc_multistep import UniPCMultistepScheduler 54 | from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin 55 | from .scheduling_vq_diffusion import VQDiffusionScheduler 56 | 57 | try: 58 | if not is_flax_available(): 59 | raise OptionalDependencyNotAvailable() 60 | except OptionalDependencyNotAvailable: 61 | from ..utils.dummy_flax_objects import * # noqa F403 62 | else: 63 | from .scheduling_ddim_flax import FlaxDDIMScheduler 64 | from .scheduling_ddpm_flax import FlaxDDPMScheduler 65 | from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler 66 | from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler 67 | from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler 68 | from .scheduling_pndm_flax import 
FlaxPNDMScheduler 69 | from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler 70 | from .scheduling_utils_flax import ( 71 | FlaxKarrasDiffusionSchedulers, 72 | FlaxSchedulerMixin, 73 | FlaxSchedulerOutput, 74 | broadcast_to_shape_from_left, 75 | ) 76 | 77 | 78 | try: 79 | if not (is_torch_available() and is_scipy_available()): 80 | raise OptionalDependencyNotAvailable() 81 | except OptionalDependencyNotAvailable: 82 | from ..utils.dummy_torch_and_scipy_objects import * # noqa F403 83 | else: 84 | from .scheduling_lms_discrete import LMSDiscreteScheduler 85 | 86 | try: 87 | if not (is_torch_available() and is_torchsde_available()): 88 | raise OptionalDependencyNotAvailable() 89 | except OptionalDependencyNotAvailable: 90 | from ..utils.dummy_torch_and_torchsde_objects import * # noqa F403 91 | else: 92 | from .scheduling_dpmsolver_sde import DPMSolverSDEScheduler 93 | -------------------------------------------------------------------------------- /diffusers_official/models/resnet_flax.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import flax.linen as nn 15 | import jax 16 | import jax.numpy as jnp 17 | 18 | 19 | class FlaxUpsample2D(nn.Module): 20 | out_channels: int 21 | dtype: jnp.dtype = jnp.float32 22 | 23 | def setup(self): 24 | self.conv = nn.Conv( 25 | self.out_channels, 26 | kernel_size=(3, 3), 27 | strides=(1, 1), 28 | padding=((1, 1), (1, 1)), 29 | dtype=self.dtype, 30 | ) 31 | 32 | def __call__(self, hidden_states): 33 | batch, height, width, channels = hidden_states.shape 34 | hidden_states = jax.image.resize( 35 | hidden_states, 36 | shape=(batch, height * 2, width * 2, channels), 37 | method="nearest", 38 | ) 39 | hidden_states = self.conv(hidden_states) 40 | return hidden_states 41 | 42 | 43 | class FlaxDownsample2D(nn.Module): 44 | out_channels: int 45 | dtype: jnp.dtype = jnp.float32 46 | 47 | def setup(self): 48 | self.conv = nn.Conv( 49 | self.out_channels, 50 | kernel_size=(3, 3), 51 | strides=(2, 2), 52 | padding=((1, 1), (1, 1)), # padding="VALID", 53 | dtype=self.dtype, 54 | ) 55 | 56 | def __call__(self, hidden_states): 57 | # pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim 58 | # hidden_states = jnp.pad(hidden_states, pad_width=pad) 59 | hidden_states = self.conv(hidden_states) 60 | return hidden_states 61 | 62 | 63 | class FlaxResnetBlock2D(nn.Module): 64 | in_channels: int 65 | out_channels: int = None 66 | dropout_prob: float = 0.0 67 | use_nin_shortcut: bool = None 68 | dtype: jnp.dtype = jnp.float32 69 | 70 | def setup(self): 71 | out_channels = self.in_channels if self.out_channels is None else self.out_channels 72 | 73 | self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5) 74 | self.conv1 = nn.Conv( 75 | out_channels, 76 | kernel_size=(3, 3), 77 | strides=(1, 1), 78 | padding=((1, 1), (1, 1)), 79 | dtype=self.dtype, 80 | ) 81 | 82 | 
self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype) 83 | 84 | self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5) 85 | self.dropout = nn.Dropout(self.dropout_prob) 86 | self.conv2 = nn.Conv( 87 | out_channels, 88 | kernel_size=(3, 3), 89 | strides=(1, 1), 90 | padding=((1, 1), (1, 1)), 91 | dtype=self.dtype, 92 | ) 93 | 94 | use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut 95 | 96 | self.conv_shortcut = None 97 | if use_nin_shortcut: 98 | self.conv_shortcut = nn.Conv( 99 | out_channels, 100 | kernel_size=(1, 1), 101 | strides=(1, 1), 102 | padding="VALID", 103 | dtype=self.dtype, 104 | ) 105 | 106 | def __call__(self, hidden_states, temb, deterministic=True): 107 | residual = hidden_states 108 | hidden_states = self.norm1(hidden_states) 109 | hidden_states = nn.swish(hidden_states) 110 | hidden_states = self.conv1(hidden_states) 111 | 112 | temb = self.time_emb_proj(nn.swish(temb)) 113 | temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1) 114 | hidden_states = hidden_states + temb 115 | 116 | hidden_states = self.norm2(hidden_states) 117 | hidden_states = nn.swish(hidden_states) 118 | hidden_states = self.dropout(hidden_states, deterministic) 119 | hidden_states = self.conv2(hidden_states) 120 | 121 | if self.conv_shortcut is not None: 122 | residual = self.conv_shortcut(residual) 123 | 124 | return hidden_states + residual 125 | -------------------------------------------------------------------------------- /diffusers_official/models/modeling_flax_pytorch_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
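# Illustrative example of the helper defined below: `rename_key` rewrites indexed
# PyTorch module names into Flax-style ones by replacing "name.<index>" with
# "name_<index>", e.g.
#   rename_key("down_blocks.0.resnets.1.conv1.weight")
#   -> "down_blocks_0.resnets_1.conv1.weight"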
15 | """ PyTorch - Flax general utilities.""" 16 | import re 17 | 18 | import jax.numpy as jnp 19 | from flax.traverse_util import flatten_dict, unflatten_dict 20 | from jax.random import PRNGKey 21 | 22 | from ..utils import logging 23 | 24 | 25 | logger = logging.get_logger(__name__) 26 | 27 | 28 | def rename_key(key): 29 | regex = r"\w+[.]\d+" 30 | pats = re.findall(regex, key) 31 | for pat in pats: 32 | key = key.replace(pat, "_".join(pat.split("."))) 33 | return key 34 | 35 | 36 | ##################### 37 | # PyTorch => Flax # 38 | ##################### 39 | 40 | 41 | # Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69 42 | # and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py 43 | def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict): 44 | """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary""" 45 | 46 | # conv norm or layer norm 47 | renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) 48 | if ( 49 | any("norm" in str_ for str_ in pt_tuple_key) 50 | and (pt_tuple_key[-1] == "bias") 51 | and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict) 52 | and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict) 53 | ): 54 | renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) 55 | return renamed_pt_tuple_key, pt_tensor 56 | elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict: 57 | renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) 58 | return renamed_pt_tuple_key, pt_tensor 59 | 60 | # embedding 61 | if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict: 62 | pt_tuple_key = pt_tuple_key[:-1] + ("embedding",) 63 | return renamed_pt_tuple_key, pt_tensor 64 | 65 | # conv layer 66 | renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) 67 | if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4: 68 | pt_tensor = pt_tensor.transpose(2, 3, 1, 0) 69 | return renamed_pt_tuple_key, pt_tensor 70 | 71 | # linear layer 72 | renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) 73 | if pt_tuple_key[-1] == "weight": 74 | pt_tensor = pt_tensor.T 75 | return renamed_pt_tuple_key, pt_tensor 76 | 77 | # old PyTorch layer norm weight 78 | renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",) 79 | if pt_tuple_key[-1] == "gamma": 80 | return renamed_pt_tuple_key, pt_tensor 81 | 82 | # old PyTorch layer norm bias 83 | renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",) 84 | if pt_tuple_key[-1] == "beta": 85 | return renamed_pt_tuple_key, pt_tensor 86 | 87 | return pt_tuple_key, pt_tensor 88 | 89 | 90 | def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42): 91 | # Step 1: Convert pytorch tensor to numpy 92 | pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()} 93 | 94 | # Step 2: Since the model is stateless, get random Flax params 95 | random_flax_params = flax_model.init_weights(PRNGKey(init_key)) 96 | 97 | random_flax_state_dict = flatten_dict(random_flax_params) 98 | flax_state_dict = {} 99 | 100 | # Need to change some parameters name to match Flax names 101 | for pt_key, pt_tensor in pt_state_dict.items(): 102 | renamed_pt_key = rename_key(pt_key) 103 | pt_tuple_key = tuple(renamed_pt_key.split(".")) 104 | 105 | # Correctly rename weight parameters 106 | flax_key, flax_tensor = 
rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict) 107 | 108 | if flax_key in random_flax_state_dict: 109 | if flax_tensor.shape != random_flax_state_dict[flax_key].shape: 110 | raise ValueError( 111 | f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape " 112 | f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}." 113 | ) 114 | 115 | # also add unexpected weight so that warning is thrown 116 | flax_state_dict[flax_key] = jnp.asarray(flax_tensor) 117 | 118 | return unflatten_dict(flax_state_dict) 119 | -------------------------------------------------------------------------------- /diffusers_official/models/cross_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from ..utils import deprecate 15 | from .attention_processor import ( # noqa: F401 16 | Attention, 17 | AttentionProcessor, 18 | AttnAddedKVProcessor, 19 | AttnProcessor2_0, 20 | LoRAAttnProcessor, 21 | LoRALinearLayer, 22 | LoRAXFormersAttnProcessor, 23 | SlicedAttnAddedKVProcessor, 24 | SlicedAttnProcessor, 25 | XFormersAttnProcessor, 26 | ) 27 | from .attention_processor import AttnProcessor as AttnProcessorRename # noqa: F401 28 | 29 | 30 | deprecate( 31 | "cross_attention", 32 | "0.20.0", 33 | "Importing from cross_attention is deprecated. Please import from diffusers.models.attention_processor instead.", 34 | standard_warn=False, 35 | ) 36 | 37 | 38 | AttnProcessor = AttentionProcessor 39 | 40 | 41 | class CrossAttention(Attention): 42 | def __init__(self, *args, **kwargs): 43 | deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." 44 | deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False) 45 | super().__init__(*args, **kwargs) 46 | 47 | 48 | class CrossAttnProcessor(AttnProcessorRename): 49 | def __init__(self, *args, **kwargs): 50 | deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." 51 | deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False) 52 | super().__init__(*args, **kwargs) 53 | 54 | 55 | class LoRACrossAttnProcessor(LoRAAttnProcessor): 56 | def __init__(self, *args, **kwargs): 57 | deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." 
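# Note: `deprecate` (imported from ..utils above) only emits the warning; each shim
# class in this module otherwise defers entirely to its renamed parent in
# `attention_processor` via super().__init__().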
58 | deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False) 59 | super().__init__(*args, **kwargs) 60 | 61 | 62 | class CrossAttnAddedKVProcessor(AttnAddedKVProcessor): 63 | def __init__(self, *args, **kwargs): 64 | deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." 65 | deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False) 66 | super().__init__(*args, **kwargs) 67 | 68 | 69 | class XFormersCrossAttnProcessor(XFormersAttnProcessor): 70 | def __init__(self, *args, **kwargs): 71 | deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." 72 | deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False) 73 | super().__init__(*args, **kwargs) 74 | 75 | 76 | class LoRAXFormersCrossAttnProcessor(LoRAXFormersAttnProcessor): 77 | def __init__(self, *args, **kwargs): 78 | deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." 79 | deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False) 80 | super().__init__(*args, **kwargs) 81 | 82 | 83 | class SlicedCrossAttnProcessor(SlicedAttnProcessor): 84 | def __init__(self, *args, **kwargs): 85 | deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." 86 | deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False) 87 | super().__init__(*args, **kwargs) 88 | 89 | 90 | class SlicedCrossAttnAddedKVProcessor(SlicedAttnAddedKVProcessor): 91 | def __init__(self, *args, **kwargs): 92 | deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." 93 | deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False) 94 | super().__init__(*args, **kwargs) 95 | -------------------------------------------------------------------------------- /diffusers_official/utils/dummy_flax_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 
2 | from ..utils import DummyObject, requires_backends 3 | 4 | 5 | class FlaxControlNetModel(metaclass=DummyObject): 6 | _backends = ["flax"] 7 | 8 | def __init__(self, *args, **kwargs): 9 | requires_backends(self, ["flax"]) 10 | 11 | @classmethod 12 | def from_config(cls, *args, **kwargs): 13 | requires_backends(cls, ["flax"]) 14 | 15 | @classmethod 16 | def from_pretrained(cls, *args, **kwargs): 17 | requires_backends(cls, ["flax"]) 18 | 19 | 20 | class FlaxModelMixin(metaclass=DummyObject): 21 | _backends = ["flax"] 22 | 23 | def __init__(self, *args, **kwargs): 24 | requires_backends(self, ["flax"]) 25 | 26 | @classmethod 27 | def from_config(cls, *args, **kwargs): 28 | requires_backends(cls, ["flax"]) 29 | 30 | @classmethod 31 | def from_pretrained(cls, *args, **kwargs): 32 | requires_backends(cls, ["flax"]) 33 | 34 | 35 | class FlaxUNet2DConditionModel(metaclass=DummyObject): 36 | _backends = ["flax"] 37 | 38 | def __init__(self, *args, **kwargs): 39 | requires_backends(self, ["flax"]) 40 | 41 | @classmethod 42 | def from_config(cls, *args, **kwargs): 43 | requires_backends(cls, ["flax"]) 44 | 45 | @classmethod 46 | def from_pretrained(cls, *args, **kwargs): 47 | requires_backends(cls, ["flax"]) 48 | 49 | 50 | class FlaxAutoencoderKL(metaclass=DummyObject): 51 | _backends = ["flax"] 52 | 53 | def __init__(self, *args, **kwargs): 54 | requires_backends(self, ["flax"]) 55 | 56 | @classmethod 57 | def from_config(cls, *args, **kwargs): 58 | requires_backends(cls, ["flax"]) 59 | 60 | @classmethod 61 | def from_pretrained(cls, *args, **kwargs): 62 | requires_backends(cls, ["flax"]) 63 | 64 | 65 | class FlaxDiffusionPipeline(metaclass=DummyObject): 66 | _backends = ["flax"] 67 | 68 | def __init__(self, *args, **kwargs): 69 | requires_backends(self, ["flax"]) 70 | 71 | @classmethod 72 | def from_config(cls, *args, **kwargs): 73 | requires_backends(cls, ["flax"]) 74 | 75 | @classmethod 76 | def from_pretrained(cls, *args, **kwargs): 77 | requires_backends(cls, ["flax"]) 78 | 79 | 80 | class FlaxDDIMScheduler(metaclass=DummyObject): 81 | _backends = ["flax"] 82 | 83 | def __init__(self, *args, **kwargs): 84 | requires_backends(self, ["flax"]) 85 | 86 | @classmethod 87 | def from_config(cls, *args, **kwargs): 88 | requires_backends(cls, ["flax"]) 89 | 90 | @classmethod 91 | def from_pretrained(cls, *args, **kwargs): 92 | requires_backends(cls, ["flax"]) 93 | 94 | 95 | class FlaxDDPMScheduler(metaclass=DummyObject): 96 | _backends = ["flax"] 97 | 98 | def __init__(self, *args, **kwargs): 99 | requires_backends(self, ["flax"]) 100 | 101 | @classmethod 102 | def from_config(cls, *args, **kwargs): 103 | requires_backends(cls, ["flax"]) 104 | 105 | @classmethod 106 | def from_pretrained(cls, *args, **kwargs): 107 | requires_backends(cls, ["flax"]) 108 | 109 | 110 | class FlaxDPMSolverMultistepScheduler(metaclass=DummyObject): 111 | _backends = ["flax"] 112 | 113 | def __init__(self, *args, **kwargs): 114 | requires_backends(self, ["flax"]) 115 | 116 | @classmethod 117 | def from_config(cls, *args, **kwargs): 118 | requires_backends(cls, ["flax"]) 119 | 120 | @classmethod 121 | def from_pretrained(cls, *args, **kwargs): 122 | requires_backends(cls, ["flax"]) 123 | 124 | 125 | class FlaxKarrasVeScheduler(metaclass=DummyObject): 126 | _backends = ["flax"] 127 | 128 | def __init__(self, *args, **kwargs): 129 | requires_backends(self, ["flax"]) 130 | 131 | @classmethod 132 | def from_config(cls, *args, **kwargs): 133 | requires_backends(cls, ["flax"]) 134 | 135 | @classmethod 136 | def 
from_pretrained(cls, *args, **kwargs): 137 | requires_backends(cls, ["flax"]) 138 | 139 | 140 | class FlaxLMSDiscreteScheduler(metaclass=DummyObject): 141 | _backends = ["flax"] 142 | 143 | def __init__(self, *args, **kwargs): 144 | requires_backends(self, ["flax"]) 145 | 146 | @classmethod 147 | def from_config(cls, *args, **kwargs): 148 | requires_backends(cls, ["flax"]) 149 | 150 | @classmethod 151 | def from_pretrained(cls, *args, **kwargs): 152 | requires_backends(cls, ["flax"]) 153 | 154 | 155 | class FlaxPNDMScheduler(metaclass=DummyObject): 156 | _backends = ["flax"] 157 | 158 | def __init__(self, *args, **kwargs): 159 | requires_backends(self, ["flax"]) 160 | 161 | @classmethod 162 | def from_config(cls, *args, **kwargs): 163 | requires_backends(cls, ["flax"]) 164 | 165 | @classmethod 166 | def from_pretrained(cls, *args, **kwargs): 167 | requires_backends(cls, ["flax"]) 168 | 169 | 170 | class FlaxSchedulerMixin(metaclass=DummyObject): 171 | _backends = ["flax"] 172 | 173 | def __init__(self, *args, **kwargs): 174 | requires_backends(self, ["flax"]) 175 | 176 | @classmethod 177 | def from_config(cls, *args, **kwargs): 178 | requires_backends(cls, ["flax"]) 179 | 180 | @classmethod 181 | def from_pretrained(cls, *args, **kwargs): 182 | requires_backends(cls, ["flax"]) 183 | 184 | 185 | class FlaxScoreSdeVeScheduler(metaclass=DummyObject): 186 | _backends = ["flax"] 187 | 188 | def __init__(self, *args, **kwargs): 189 | requires_backends(self, ["flax"]) 190 | 191 | @classmethod 192 | def from_config(cls, *args, **kwargs): 193 | requires_backends(cls, ["flax"]) 194 | 195 | @classmethod 196 | def from_pretrained(cls, *args, **kwargs): 197 | requires_backends(cls, ["flax"]) 198 | -------------------------------------------------------------------------------- /diffusers_official/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.18.2" 2 | 3 | from .configuration_utils import ConfigMixin 4 | from .utils import ( 5 | OptionalDependencyNotAvailable, 6 | is_flax_available, 7 | is_inflect_available, 8 | is_invisible_watermark_available, 9 | is_k_diffusion_available, 10 | is_k_diffusion_version, 11 | is_librosa_available, 12 | is_note_seq_available, 13 | is_onnx_available, 14 | is_scipy_available, 15 | is_torch_available, 16 | is_torchsde_available, 17 | is_transformers_available, 18 | is_transformers_version, 19 | is_unidecode_available, 20 | logging, 21 | ) 22 | 23 | 24 | try: 25 | if not is_onnx_available(): 26 | raise OptionalDependencyNotAvailable() 27 | except OptionalDependencyNotAvailable: 28 | from .utils.dummy_onnx_objects import * # noqa F403 29 | else: 30 | from .pipelines import OnnxRuntimeModel 31 | 32 | try: 33 | if not is_torch_available(): 34 | raise OptionalDependencyNotAvailable() 35 | except OptionalDependencyNotAvailable: 36 | from .utils.dummy_pt_objects import * # noqa F403 37 | else: 38 | from .models import ( 39 | AutoencoderKL, 40 | ControlNetModel, 41 | ModelMixin, 42 | PriorTransformer, 43 | T5FilmDecoder, 44 | Transformer2DModel, 45 | UNet1DModel, 46 | UNet2DConditionModel, 47 | UNet2DModel, 48 | UNet3DConditionModel, 49 | VQModel, 50 | ) 51 | from .optimization import ( 52 | get_constant_schedule, 53 | get_constant_schedule_with_warmup, 54 | get_cosine_schedule_with_warmup, 55 | get_cosine_with_hard_restarts_schedule_with_warmup, 56 | get_linear_schedule_with_warmup, 57 | get_polynomial_decay_schedule_with_warmup, 58 | get_scheduler, 59 | ) 60 | from .schedulers import ( 61 | 
CMStochasticIterativeScheduler, 62 | DDIMInverseScheduler, 63 | DDIMParallelScheduler, 64 | DDIMScheduler, 65 | DDPMParallelScheduler, 66 | DDPMScheduler, 67 | DEISMultistepScheduler, 68 | DPMSolverMultistepInverseScheduler, 69 | DPMSolverMultistepScheduler, 70 | DPMSolverSinglestepScheduler, 71 | EulerAncestralDiscreteScheduler, 72 | EulerDiscreteScheduler, 73 | HeunDiscreteScheduler, 74 | IPNDMScheduler, 75 | KarrasVeScheduler, 76 | KDPM2AncestralDiscreteScheduler, 77 | KDPM2DiscreteScheduler, 78 | PNDMScheduler, 79 | RePaintScheduler, 80 | SchedulerMixin, 81 | ScoreSdeVeScheduler, 82 | UnCLIPScheduler, 83 | UniPCMultistepScheduler, 84 | VQDiffusionScheduler, 85 | ) 86 | from .training_utils import EMAModel 87 | 88 | try: 89 | if not (is_torch_available() and is_scipy_available()): 90 | raise OptionalDependencyNotAvailable() 91 | except OptionalDependencyNotAvailable: 92 | from .utils.dummy_torch_and_scipy_objects import * # noqa F403 93 | else: 94 | from .schedulers import LMSDiscreteScheduler 95 | 96 | try: 97 | if not (is_torch_available() and is_torchsde_available()): 98 | raise OptionalDependencyNotAvailable() 99 | except OptionalDependencyNotAvailable: 100 | from .utils.dummy_torch_and_torchsde_objects import * # noqa F403 101 | else: 102 | from .schedulers import DPMSolverSDEScheduler 103 | 104 | try: 105 | if not (is_torch_available() and is_transformers_available()): 106 | raise OptionalDependencyNotAvailable() 107 | except OptionalDependencyNotAvailable: 108 | from .utils.dummy_torch_and_transformers_objects import * # noqa F403 109 | 110 | try: 111 | if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()): 112 | raise OptionalDependencyNotAvailable() 113 | except OptionalDependencyNotAvailable: 114 | from .utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403 115 | else: 116 | from .pipelines import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline 117 | 118 | try: 119 | if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): 120 | raise OptionalDependencyNotAvailable() 121 | except OptionalDependencyNotAvailable: 122 | from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403 123 | 124 | try: 125 | if not (is_torch_available() and is_transformers_available() and is_onnx_available()): 126 | raise OptionalDependencyNotAvailable() 127 | except OptionalDependencyNotAvailable: 128 | from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403 129 | 130 | try: 131 | if not (is_torch_available() and is_librosa_available()): 132 | raise OptionalDependencyNotAvailable() 133 | except OptionalDependencyNotAvailable: 134 | from .utils.dummy_torch_and_librosa_objects import * # noqa F403 135 | 136 | try: 137 | if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): 138 | raise OptionalDependencyNotAvailable() 139 | except OptionalDependencyNotAvailable: 140 | from .utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403 141 | 142 | try: 143 | if not is_flax_available(): 144 | raise OptionalDependencyNotAvailable() 145 | except OptionalDependencyNotAvailable: 146 | from .utils.dummy_flax_objects import * # noqa F403 147 | else: 148 | from .models.controlnet_flax import FlaxControlNetModel 149 | from .models.modeling_flax_utils import FlaxModelMixin 150 | from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel 151 | from .models.vae_flax import 
FlaxAutoencoderKL 152 | from .pipelines import FlaxDiffusionPipeline 153 | from .schedulers import ( 154 | FlaxDDIMScheduler, 155 | FlaxDDPMScheduler, 156 | FlaxDPMSolverMultistepScheduler, 157 | FlaxKarrasVeScheduler, 158 | FlaxLMSDiscreteScheduler, 159 | FlaxPNDMScheduler, 160 | FlaxSchedulerMixin, 161 | FlaxScoreSdeVeScheduler, 162 | ) 163 | 164 | 165 | try: 166 | if not (is_flax_available() and is_transformers_available()): 167 | raise OptionalDependencyNotAvailable() 168 | except OptionalDependencyNotAvailable: 169 | from .utils.dummy_flax_and_transformers_objects import * # noqa F403 170 | 171 | try: 172 | if not (is_note_seq_available()): 173 | raise OptionalDependencyNotAvailable() 174 | except OptionalDependencyNotAvailable: 175 | from .utils.dummy_note_seq_objects import * # noqa F403 176 | -------------------------------------------------------------------------------- /diffusers_official/schedulers/scheduling_ipndm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Zhejiang University Team and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import math 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import numpy as np 19 | import torch 20 | 21 | from ..configuration_utils import ConfigMixin, register_to_config 22 | from .scheduling_utils import SchedulerMixin, SchedulerOutput 23 | 24 | 25 | class IPNDMScheduler(SchedulerMixin, ConfigMixin): 26 | """ 27 | Improved Pseudo numerical methods for diffusion models (iPNDM) ported from @crowsonkb's amazing k-diffusion 28 | [library](https://github.com/crowsonkb/v-diffusion-pytorch/blob/987f8985e38208345c1959b0ea767a625831cc9b/diffusion/sampling.py#L296) 29 | 30 | [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` 31 | function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. 32 | [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and 33 | [`~SchedulerMixin.from_pretrained`] functions. 34 | 35 | For more details, see the original paper: https://arxiv.org/abs/2202.09778 36 | 37 | Args: 38 | num_train_timesteps (`int`): number of diffusion steps used to train the model. 39 | """ 40 | 41 | order = 1 42 | 43 | @register_to_config 44 | def __init__( 45 | self, num_train_timesteps: int = 1000, trained_betas: Optional[Union[np.ndarray, List[float]]] = None 46 | ): 47 | # set `betas`, `alphas`, `timesteps` 48 | self.set_timesteps(num_train_timesteps) 49 | 50 | # standard deviation of the initial noise distribution 51 | self.init_noise_sigma = 1.0 52 | 53 | # For now we only support F-PNDM, i.e. the runge-kutta method 54 | # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf 55 | # mainly at formula (9), (12), (13) and the Algorithm 2. 
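# Note: order 4 corresponds to the fourth-order Adams-Bashforth-style coefficients
# (55, -59, 37, -9) / 24 used in `step` below once four previous model outputs have
# accumulated in `self.ets`.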
56 | self.pndm_order = 4 57 | 58 | # running values 59 | self.ets = [] 60 | 61 | def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): 62 | """ 63 | Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. 64 | 65 | Args: 66 | num_inference_steps (`int`): 67 | the number of diffusion steps used when generating samples with a pre-trained model. 68 | """ 69 | self.num_inference_steps = num_inference_steps 70 | steps = torch.linspace(1, 0, num_inference_steps + 1)[:-1] 71 | steps = torch.cat([steps, torch.tensor([0.0])]) 72 | 73 | if self.config.trained_betas is not None: 74 | self.betas = torch.tensor(self.config.trained_betas, dtype=torch.float32) 75 | else: 76 | self.betas = torch.sin(steps * math.pi / 2) ** 2 77 | 78 | self.alphas = (1.0 - self.betas**2) ** 0.5 79 | 80 | timesteps = (torch.atan2(self.betas, self.alphas) / math.pi * 2)[:-1] 81 | self.timesteps = timesteps.to(device) 82 | 83 | self.ets = [] 84 | 85 | def step( 86 | self, 87 | model_output: torch.FloatTensor, 88 | timestep: int, 89 | sample: torch.FloatTensor, 90 | return_dict: bool = True, 91 | ) -> Union[SchedulerOutput, Tuple]: 92 | """ 93 | Propagates the sample one step with the linear multi-step method, combining the current model output with 94 | the stored outputs from previous steps to approximate the solution. 95 | 96 | Args: 97 | model_output (`torch.FloatTensor`): direct output from learned diffusion model. 98 | timestep (`int`): current discrete timestep in the diffusion chain. 99 | sample (`torch.FloatTensor`): 100 | current instance of sample being created by diffusion process. 101 | return_dict (`bool`): option for returning tuple rather than SchedulerOutput class 102 | 103 | Returns: 104 | [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is 105 | True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. 106 | 107 | """ 108 | if self.num_inference_steps is None: 109 | raise ValueError( 110 | "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" 111 | ) 112 | 113 | timestep_index = (self.timesteps == timestep).nonzero().item() 114 | prev_timestep_index = timestep_index + 1 115 | 116 | ets = sample * self.betas[timestep_index] + model_output * self.alphas[timestep_index] 117 | self.ets.append(ets) 118 | 119 | if len(self.ets) == 1: 120 | ets = self.ets[-1] 121 | elif len(self.ets) == 2: 122 | ets = (3 * self.ets[-1] - self.ets[-2]) / 2 123 | elif len(self.ets) == 3: 124 | ets = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12 125 | else: 126 | ets = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4]) 127 | 128 | prev_sample = self._get_prev_sample(sample, timestep_index, prev_timestep_index, ets) 129 | 130 | if not return_dict: 131 | return (prev_sample,) 132 | 133 | return SchedulerOutput(prev_sample=prev_sample) 134 | 135 | def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: 136 | """ 137 | Ensures interchangeability with schedulers that need to scale the denoising model input depending on the 138 | current timestep.
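IPNDM requires no such scaling, so the sample is returned unchanged.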
139 | 140 | Args: 141 | sample (`torch.FloatTensor`): input sample 142 | 143 | Returns: 144 | `torch.FloatTensor`: scaled input sample 145 | """ 146 | return sample 147 | 148 | def _get_prev_sample(self, sample, timestep_index, prev_timestep_index, ets): 149 | alpha = self.alphas[timestep_index] 150 | sigma = self.betas[timestep_index] 151 | 152 | next_alpha = self.alphas[prev_timestep_index] 153 | next_sigma = self.betas[prev_timestep_index] 154 | 155 | pred = (sample - sigma * ets) / max(alpha, 1e-8) 156 | prev_sample = next_alpha * pred + ets * next_sigma 157 | 158 | return prev_sample 159 | 160 | def __len__(self): 161 | return self.config.num_train_timesteps 162 | -------------------------------------------------------------------------------- /diffusers_official/models/vq_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from dataclasses import dataclass 15 | from typing import Optional, Tuple, Union 16 | 17 | import torch 18 | import torch.nn as nn 19 | 20 | from ..configuration_utils import ConfigMixin, register_to_config 21 | from ..utils import BaseOutput, apply_forward_hook 22 | from .modeling_utils import ModelMixin 23 | from .vae import Decoder, DecoderOutput, Encoder, VectorQuantizer 24 | 25 | 26 | @dataclass 27 | class VQEncoderOutput(BaseOutput): 28 | """ 29 | Output of VQModel encoding method. 30 | 31 | Args: 32 | latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): 33 | The encoded output sample from the last layer of the model. 34 | """ 35 | 36 | latents: torch.FloatTensor 37 | 38 | 39 | class VQModel(ModelMixin, ConfigMixin): 40 | r""" 41 | A VQ-VAE model for decoding latent representations. 42 | 43 | This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented 44 | for all models (such as downloading or saving). 45 | 46 | Parameters: 47 | in_channels (int, *optional*, defaults to 3): Number of channels in the input image. 48 | out_channels (int, *optional*, defaults to 3): Number of channels in the output. 49 | down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): 50 | Tuple of downsample block types. 51 | up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): 52 | Tuple of upsample block types. 53 | block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): 54 | Tuple of block output channels. 55 | act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. 56 | latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space. 57 | sample_size (`int`, *optional*, defaults to `32`): Sample input size. 58 | num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
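layers_per_block (`int`, *optional*, defaults to `1`): Number of layers per encoder/decoder block.
norm_num_groups (`int`, *optional*, defaults to `32`): Number of groups for the normalization layers.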
59 | vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE. 60 | scaling_factor (`float`, *optional*, defaults to `0.18215`): 61 | The component-wise standard deviation of the trained latent space computed using the first batch of the 62 | training set. This is used to scale the latent space to have unit variance when training the diffusion 63 | model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the 64 | diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 65 | / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image 66 | Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. 67 | """ 68 | 69 | @register_to_config 70 | def __init__( 71 | self, 72 | in_channels: int = 3, 73 | out_channels: int = 3, 74 | down_block_types: Tuple[str] = ("DownEncoderBlock2D",), 75 | up_block_types: Tuple[str] = ("UpDecoderBlock2D",), 76 | block_out_channels: Tuple[int] = (64,), 77 | layers_per_block: int = 1, 78 | act_fn: str = "silu", 79 | latent_channels: int = 3, 80 | sample_size: int = 32, 81 | num_vq_embeddings: int = 256, 82 | norm_num_groups: int = 32, 83 | vq_embed_dim: Optional[int] = None, 84 | scaling_factor: float = 0.18215, 85 | norm_type: str = "group", # group, spatial 86 | ): 87 | super().__init__() 88 | 89 | # pass init params to Encoder 90 | self.encoder = Encoder( 91 | in_channels=in_channels, 92 | out_channels=latent_channels, 93 | down_block_types=down_block_types, 94 | block_out_channels=block_out_channels, 95 | layers_per_block=layers_per_block, 96 | act_fn=act_fn, 97 | norm_num_groups=norm_num_groups, 98 | double_z=False, 99 | ) 100 | 101 | vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels 102 | 103 | self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1) 104 | self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False) 105 | self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1) 106 | 107 | # pass init params to Decoder 108 | self.decoder = Decoder( 109 | in_channels=latent_channels, 110 | out_channels=out_channels, 111 | up_block_types=up_block_types, 112 | block_out_channels=block_out_channels, 113 | layers_per_block=layers_per_block, 114 | act_fn=act_fn, 115 | norm_num_groups=norm_num_groups, 116 | norm_type=norm_type, 117 | ) 118 | 119 | @apply_forward_hook 120 | def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput: 121 | h = self.encoder(x) 122 | h = self.quant_conv(h) 123 | 124 | if not return_dict: 125 | return (h,) 126 | 127 | return VQEncoderOutput(latents=h) 128 | 129 | @apply_forward_hook 130 | def decode( 131 | self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True 132 | ) -> Union[DecoderOutput, torch.FloatTensor]: 133 | # also go through quantization layer 134 | if not force_not_quantize: 135 | quant, emb_loss, info = self.quantize(h) 136 | else: 137 | quant = h 138 | quant2 = self.post_quant_conv(quant) 139 | dec = self.decoder(quant2, quant if self.config.norm_type == "spatial" else None) 140 | 141 | if not return_dict: 142 | return (dec,) 143 | 144 | return DecoderOutput(sample=dec) 145 | 146 | def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: 147 | r""" 148 | The [`VQModel`] forward method. 
149 | 150 | Args: 151 | sample (`torch.FloatTensor`): Input sample. 152 | return_dict (`bool`, *optional*, defaults to `True`): 153 | Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. 154 | 155 | Returns: 156 | [`~models.vae.DecoderOutput`] or `tuple`: 157 | If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` 158 | is returned. 159 | """ 160 | x = sample 161 | h = self.encode(x).latents 162 | dec = self.decode(h).sample 163 | 164 | if not return_dict: 165 | return (dec,) 166 | 167 | return DecoderOutput(sample=dec) 168 | -------------------------------------------------------------------------------- /diffusers_official/models/modeling_pytorch_flax_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ PyTorch - Flax general utilities.""" 16 | 17 | from pickle import UnpicklingError 18 | 19 | import jax 20 | import jax.numpy as jnp 21 | import numpy as np 22 | from flax.serialization import from_bytes 23 | from flax.traverse_util import flatten_dict 24 | 25 | from ..utils import logging 26 | 27 | 28 | logger = logging.get_logger(__name__) 29 | 30 | 31 | ##################### 32 | # Flax => PyTorch # 33 | ##################### 34 | 35 | 36 | # from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flax_pytorch_utils.py#L224-L352 37 | def load_flax_checkpoint_in_pytorch_model(pt_model, model_file): 38 | try: 39 | with open(model_file, "rb") as flax_state_f: 40 | flax_state = from_bytes(None, flax_state_f.read()) 41 | except UnpicklingError as e: 42 | try: 43 | with open(model_file) as f: 44 | if f.read().startswith("version"): 45 | raise OSError( 46 | "You seem to have cloned a repository without having git-lfs installed. Please" 47 | " install git-lfs and run `git lfs install` followed by `git lfs pull` in the" 48 | " folder you cloned." 49 | ) 50 | else: 51 | raise ValueError from e 52 | except (UnicodeDecodeError, ValueError): 53 | raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ") 54 | 55 | return load_flax_weights_in_pytorch_model(pt_model, flax_state) 56 | 57 | 58 | def load_flax_weights_in_pytorch_model(pt_model, flax_state): 59 | """Load flax checkpoints in a PyTorch model""" 60 | 61 | try: 62 | import torch # noqa: F401 63 | except ImportError: 64 | logger.error( 65 | "Loading Flax weights in PyTorch requires both PyTorch and Flax to be installed. Please see" 66 | " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation" 67 | " instructions."
68 | ) 69 | raise 70 | 71 | # check if we have bf16 weights 72 | is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values() 73 | if any(is_type_bf16): 74 | # convert all weights to fp32 if they are bf16 since torch.from_numpy can-not handle bf16 75 | 76 | # and bf16 is not fully supported in PT yet. 77 | logger.warning( 78 | "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` " 79 | "before loading those in PyTorch model." 80 | ) 81 | flax_state = jax.tree_util.tree_map( 82 | lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state 83 | ) 84 | 85 | pt_model.base_model_prefix = "" 86 | 87 | flax_state_dict = flatten_dict(flax_state, sep=".") 88 | pt_model_dict = pt_model.state_dict() 89 | 90 | # keep track of unexpected & missing keys 91 | unexpected_keys = [] 92 | missing_keys = set(pt_model_dict.keys()) 93 | 94 | for flax_key_tuple, flax_tensor in flax_state_dict.items(): 95 | flax_key_tuple_array = flax_key_tuple.split(".") 96 | 97 | if flax_key_tuple_array[-1] == "kernel" and flax_tensor.ndim == 4: 98 | flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"] 99 | flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1)) 100 | elif flax_key_tuple_array[-1] == "kernel": 101 | flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"] 102 | flax_tensor = flax_tensor.T 103 | elif flax_key_tuple_array[-1] == "scale": 104 | flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"] 105 | 106 | if "time_embedding" not in flax_key_tuple_array: 107 | for i, flax_key_tuple_string in enumerate(flax_key_tuple_array): 108 | flax_key_tuple_array[i] = ( 109 | flax_key_tuple_string.replace("_0", ".0") 110 | .replace("_1", ".1") 111 | .replace("_2", ".2") 112 | .replace("_3", ".3") 113 | .replace("_4", ".4") 114 | .replace("_5", ".5") 115 | .replace("_6", ".6") 116 | .replace("_7", ".7") 117 | .replace("_8", ".8") 118 | .replace("_9", ".9") 119 | ) 120 | 121 | flax_key = ".".join(flax_key_tuple_array) 122 | 123 | if flax_key in pt_model_dict: 124 | if flax_tensor.shape != pt_model_dict[flax_key].shape: 125 | raise ValueError( 126 | f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected " 127 | f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}." 128 | ) 129 | else: 130 | # add weight to pytorch dict 131 | flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor 132 | pt_model_dict[flax_key] = torch.from_numpy(flax_tensor) 133 | # remove from missing keys 134 | missing_keys.remove(flax_key) 135 | else: 136 | # weight is not expected by PyTorch model 137 | unexpected_keys.append(flax_key) 138 | 139 | pt_model.load_state_dict(pt_model_dict) 140 | 141 | # re-transform missing_keys to list 142 | missing_keys = list(missing_keys) 143 | 144 | if len(unexpected_keys) > 0: 145 | logger.warning( 146 | "Some weights of the Flax model were not used when initializing the PyTorch model" 147 | f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing" 148 | f" {pt_model.__class__.__name__} from a Flax model trained on another task or with another architecture" 149 | " (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This" 150 | f" IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect" 151 | " to be exactly identical (e.g. 
initializing a BertForSequenceClassification model from a"
152 |             " FlaxBertForSequenceClassification model)."
153 |         )
154 |     if len(missing_keys) > 0:
155 |         logger.warning(
156 |             f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly"
157 |             f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to"
158 |             " use it for predictions and inference."
159 |         )
160 | 
161 |     return pt_model
162 | 
--------------------------------------------------------------------------------
/diffusers_official/models/dual_transformer_2d.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import Optional
15 | 
16 | from torch import nn
17 | 
18 | from .transformer_2d import Transformer2DModel, Transformer2DModelOutput
19 | 
20 | 
21 | class DualTransformer2DModel(nn.Module):
22 |     """
23 |     Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
24 | 
25 |     Parameters:
26 |         num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
27 |         attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
28 |         in_channels (`int`, *optional*):
29 |             Pass if the input is continuous. The number of channels in the input and output.
30 |         num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
31 |         dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
32 |         cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
33 |         sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
34 |             Note that this is fixed at training time as it is used for learning a number of position embeddings. See
35 |             `ImagePositionalEmbeddings`.
36 |         num_vector_embeds (`int`, *optional*):
37 |             Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
38 |             Includes the class for the masked latent pixel.
39 |         activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
40 |         num_embeds_ada_norm (`int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
41 |             The number of diffusion steps used during training. Note that this is fixed at training time as it is used
42 |             to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
43 |             up to but not more steps than `num_embeds_ada_norm`.
44 |         attention_bias (`bool`, *optional*):
45 |             Configure if the TransformerBlocks' attention should contain a bias parameter.
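
    Example (illustrative sizes; the channel and context dimensions below are
    assumptions in the style of a Stable Diffusion cross-attention block, not
    values taken from this repo):

        model = DualTransformer2DModel(
            num_attention_heads=8,
            attention_head_dim=40,        # 8 heads * 40 channels = 320 inner channels
            in_channels=320,
            cross_attention_dim=768,
        )
        model.mix_ratio = 0.7             # pipelines may re-weight the two transformers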
46 | """ 47 | 48 | def __init__( 49 | self, 50 | num_attention_heads: int = 16, 51 | attention_head_dim: int = 88, 52 | in_channels: Optional[int] = None, 53 | num_layers: int = 1, 54 | dropout: float = 0.0, 55 | norm_num_groups: int = 32, 56 | cross_attention_dim: Optional[int] = None, 57 | attention_bias: bool = False, 58 | sample_size: Optional[int] = None, 59 | num_vector_embeds: Optional[int] = None, 60 | activation_fn: str = "geglu", 61 | num_embeds_ada_norm: Optional[int] = None, 62 | ): 63 | super().__init__() 64 | self.transformers = nn.ModuleList( 65 | [ 66 | Transformer2DModel( 67 | num_attention_heads=num_attention_heads, 68 | attention_head_dim=attention_head_dim, 69 | in_channels=in_channels, 70 | num_layers=num_layers, 71 | dropout=dropout, 72 | norm_num_groups=norm_num_groups, 73 | cross_attention_dim=cross_attention_dim, 74 | attention_bias=attention_bias, 75 | sample_size=sample_size, 76 | num_vector_embeds=num_vector_embeds, 77 | activation_fn=activation_fn, 78 | num_embeds_ada_norm=num_embeds_ada_norm, 79 | ) 80 | for _ in range(2) 81 | ] 82 | ) 83 | 84 | # Variables that can be set by a pipeline: 85 | 86 | # The ratio of transformer1 to transformer2's output states to be combined during inference 87 | self.mix_ratio = 0.5 88 | 89 | # The shape of `encoder_hidden_states` is expected to be 90 | # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)` 91 | self.condition_lengths = [77, 257] 92 | 93 | # Which transformer to use to encode which condition. 94 | # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])` 95 | self.transformer_index_for_condition = [1, 0] 96 | 97 | def forward( 98 | self, 99 | hidden_states, 100 | encoder_hidden_states, 101 | timestep=None, 102 | attention_mask=None, 103 | cross_attention_kwargs=None, 104 | return_dict: bool = True, 105 | ): 106 | """ 107 | Args: 108 | hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. 109 | When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input 110 | hidden_states 111 | encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): 112 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to 113 | self-attention. 114 | timestep ( `torch.long`, *optional*): 115 | Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. 116 | attention_mask (`torch.FloatTensor`, *optional*): 117 | Optional attention mask to be applied in Attention 118 | return_dict (`bool`, *optional*, defaults to `True`): 119 | Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. 120 | 121 | Returns: 122 | [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: 123 | [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When 124 | returning a tuple, the first element is the sample tensor. 
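
        Example of the conditioning layout consumed by this method (a sketch; the
        batch size and feature dimension are illustrative, the token counts follow
        `condition_lengths = [77, 257]`):

            # encoder_hidden_states: (2, 77 + 257, 768)
            cond_0 = encoder_hidden_states[:, :77]        # encoded by self.transformers[1]
            cond_1 = encoder_hidden_states[:, 77:334]     # encoded by self.transformers[0]
            # each pass contributes a residual; the output is
            # mix_ratio * res_0 + (1 - mix_ratio) * res_1 + hidden_states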
125 | """ 126 | input_states = hidden_states 127 | 128 | encoded_states = [] 129 | tokens_start = 0 130 | # attention_mask is not used yet 131 | for i in range(2): 132 | # for each of the two transformers, pass the corresponding condition tokens 133 | condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]] 134 | transformer_index = self.transformer_index_for_condition[i] 135 | encoded_state = self.transformers[transformer_index]( 136 | input_states, 137 | encoder_hidden_states=condition_state, 138 | timestep=timestep, 139 | cross_attention_kwargs=cross_attention_kwargs, 140 | return_dict=False, 141 | )[0] 142 | encoded_states.append(encoded_state - input_states) 143 | tokens_start += self.condition_lengths[i] 144 | 145 | output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio) 146 | output_states = output_states + input_states 147 | 148 | if not return_dict: 149 | return (output_states,) 150 | 151 | return Transformer2DModelOutput(sample=output_states) 152 | -------------------------------------------------------------------------------- /scripts/models/dual_transformer_2d.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from typing import Optional 15 | 16 | from torch import nn 17 | 18 | from scripts.models.transformer_2d import Transformer2DModel, Transformer2DModelOutput 19 | 20 | 21 | class DualTransformer2DModel(nn.Module): 22 | """ 23 | Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference. 24 | 25 | Parameters: 26 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. 27 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. 28 | in_channels (`int`, *optional*): 29 | Pass if the input is continuous. The number of channels in the input and output. 30 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. 31 | dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use. 32 | cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. 33 | sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. 34 | Note that this is fixed at training time as it is used for learning a number of position embeddings. See 35 | `ImagePositionalEmbeddings`. 36 | num_vector_embeds (`int`, *optional*): 37 | Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. 38 | Includes the class for the masked latent pixel. 39 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 40 | num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. 
41 | The number of diffusion steps used during training. Note that this is fixed at training time as it is used 42 | to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for 43 | up to but not more than steps than `num_embeds_ada_norm`. 44 | attention_bias (`bool`, *optional*): 45 | Configure if the TransformerBlocks' attention should contain a bias parameter. 46 | """ 47 | 48 | def __init__( 49 | self, 50 | num_attention_heads: int = 16, 51 | attention_head_dim: int = 88, 52 | in_channels: Optional[int] = None, 53 | num_layers: int = 1, 54 | dropout: float = 0.0, 55 | norm_num_groups: int = 32, 56 | cross_attention_dim: Optional[int] = None, 57 | attention_bias: bool = False, 58 | sample_size: Optional[int] = None, 59 | num_vector_embeds: Optional[int] = None, 60 | activation_fn: str = "geglu", 61 | num_embeds_ada_norm: Optional[int] = None, 62 | ): 63 | super().__init__() 64 | self.transformers = nn.ModuleList( 65 | [ 66 | Transformer2DModel( 67 | num_attention_heads=num_attention_heads, 68 | attention_head_dim=attention_head_dim, 69 | in_channels=in_channels, 70 | num_layers=num_layers, 71 | dropout=dropout, 72 | norm_num_groups=norm_num_groups, 73 | cross_attention_dim=cross_attention_dim, 74 | attention_bias=attention_bias, 75 | sample_size=sample_size, 76 | num_vector_embeds=num_vector_embeds, 77 | activation_fn=activation_fn, 78 | num_embeds_ada_norm=num_embeds_ada_norm, 79 | ) 80 | for _ in range(2) 81 | ] 82 | ) 83 | 84 | # Variables that can be set by a pipeline: 85 | 86 | # The ratio of transformer1 to transformer2's output states to be combined during inference 87 | self.mix_ratio = 0.5 88 | 89 | # The shape of `encoder_hidden_states` is expected to be 90 | # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)` 91 | self.condition_lengths = [77, 257] 92 | 93 | # Which transformer to use to encode which condition. 94 | # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])` 95 | self.transformer_index_for_condition = [1, 0] 96 | 97 | def forward( 98 | self, 99 | hidden_states, 100 | encoder_hidden_states, 101 | timestep=None, 102 | attention_mask=None, 103 | cross_attention_kwargs=None, 104 | return_dict: bool = True, 105 | ): 106 | """ 107 | Args: 108 | hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. 109 | When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input 110 | hidden_states 111 | encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): 112 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to 113 | self-attention. 114 | timestep ( `torch.long`, *optional*): 115 | Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. 116 | attention_mask (`torch.FloatTensor`, *optional*): 117 | Optional attention mask to be applied in Attention 118 | return_dict (`bool`, *optional*, defaults to `True`): 119 | Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. 120 | 121 | Returns: 122 | [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: 123 | [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When 124 | returning a tuple, the first element is the sample tensor. 
125 | """ 126 | input_states = hidden_states 127 | 128 | encoded_states = [] 129 | tokens_start = 0 130 | # attention_mask is not used yet 131 | for i in range(2): 132 | # for each of the two transformers, pass the corresponding condition tokens 133 | condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]] 134 | transformer_index = self.transformer_index_for_condition[i] 135 | encoded_state = self.transformers[transformer_index]( 136 | input_states, 137 | encoder_hidden_states=condition_state, 138 | timestep=timestep, 139 | cross_attention_kwargs=cross_attention_kwargs, 140 | return_dict=False, 141 | )[0] 142 | encoded_states.append(encoded_state - input_states) 143 | tokens_start += self.condition_lengths[i] 144 | 145 | output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio) 146 | output_states = output_states + input_states 147 | 148 | if not return_dict: 149 | return (output_states,) 150 | 151 | return Transformer2DModelOutput(sample=output_states) 152 | -------------------------------------------------------------------------------- /diffusers_official/models/transformer_temporal.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from dataclasses import dataclass 15 | from typing import Optional 16 | 17 | import torch 18 | from torch import nn 19 | 20 | from ..configuration_utils import ConfigMixin, register_to_config 21 | from ..utils import BaseOutput 22 | from .attention import BasicTransformerBlock 23 | from .modeling_utils import ModelMixin 24 | 25 | 26 | @dataclass 27 | class TransformerTemporalModelOutput(BaseOutput): 28 | """ 29 | The output of [`TransformerTemporalModel`]. 30 | 31 | Args: 32 | sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`): 33 | The hidden states output conditioned on `encoder_hidden_states` input. 34 | """ 35 | 36 | sample: torch.FloatTensor 37 | 38 | 39 | class TransformerTemporalModel(ModelMixin, ConfigMixin): 40 | """ 41 | A Transformer model for video-like data. 42 | 43 | Parameters: 44 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. 45 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. 46 | in_channels (`int`, *optional*): 47 | The number of channels in the input and output (specify if the input is **continuous**). 48 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. 49 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 50 | cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. 51 | sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). 
52 | This is fixed during training since it is used to learn a number of position embeddings. 53 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. 54 | attention_bias (`bool`, *optional*): 55 | Configure if the `TransformerBlock` attention should contain a bias parameter. 56 | double_self_attention (`bool`, *optional*): 57 | Configure if each `TransformerBlock` should contain two self-attention layers. 58 | """ 59 | 60 | @register_to_config 61 | def __init__( 62 | self, 63 | num_attention_heads: int = 16, 64 | attention_head_dim: int = 88, 65 | in_channels: Optional[int] = None, 66 | out_channels: Optional[int] = None, 67 | num_layers: int = 1, 68 | dropout: float = 0.0, 69 | norm_num_groups: int = 32, 70 | cross_attention_dim: Optional[int] = None, 71 | attention_bias: bool = False, 72 | sample_size: Optional[int] = None, 73 | activation_fn: str = "geglu", 74 | norm_elementwise_affine: bool = True, 75 | double_self_attention: bool = True, 76 | ): 77 | super().__init__() 78 | self.num_attention_heads = num_attention_heads 79 | self.attention_head_dim = attention_head_dim 80 | inner_dim = num_attention_heads * attention_head_dim 81 | 82 | self.in_channels = in_channels 83 | 84 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 85 | self.proj_in = nn.Linear(in_channels, inner_dim) 86 | 87 | # 3. Define transformers blocks 88 | self.transformer_blocks = nn.ModuleList( 89 | [ 90 | BasicTransformerBlock( 91 | inner_dim, 92 | num_attention_heads, 93 | attention_head_dim, 94 | dropout=dropout, 95 | cross_attention_dim=cross_attention_dim, 96 | activation_fn=activation_fn, 97 | attention_bias=attention_bias, 98 | double_self_attention=double_self_attention, 99 | norm_elementwise_affine=norm_elementwise_affine, 100 | ) 101 | for d in range(num_layers) 102 | ] 103 | ) 104 | 105 | self.proj_out = nn.Linear(inner_dim, in_channels) 106 | 107 | def forward( 108 | self, 109 | hidden_states, 110 | encoder_hidden_states=None, 111 | timestep=None, 112 | class_labels=None, 113 | num_frames=1, 114 | cross_attention_kwargs=None, 115 | return_dict: bool = True, 116 | ): 117 | """ 118 | The [`TransformerTemporal`] forward method. 119 | 120 | Args: 121 | hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): 122 | Input hidden_states. 123 | encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): 124 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to 125 | self-attention. 126 | timestep ( `torch.long`, *optional*): 127 | Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. 128 | class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): 129 | Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in 130 | `AdaLayerZeroNorm`. 131 | return_dict (`bool`, *optional*, defaults to `True`): 132 | Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain 133 | tuple. 134 | 135 | Returns: 136 | [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`: 137 | If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is 138 | returned, otherwise a `tuple` where the first element is the sample tensor. 
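
        Shape walkthrough (illustrative numbers): with a batch of 2 and `num_frames=8`,
        an input of shape `(16, 320, 32, 32)` is regrouped so that attention runs along
        the frame axis:

            # (batch*frames, C, H, W) -> (batch, C, frames, H, W)   for the GroupNorm
            # -> (batch*H*W, frames, C)                             for the transformer blocks
            # i.e. (16, 320, 32, 32) -> (2, 320, 8, 32, 32) -> (2048, 8, 320)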
139 |         """
140 |         # 1. Input
141 |         batch_frames, channel, height, width = hidden_states.shape
142 |         batch_size = batch_frames // num_frames
143 | 
144 |         residual = hidden_states
145 | 
146 |         hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
147 |         hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
148 | 
149 |         hidden_states = self.norm(hidden_states)
150 |         hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
151 | 
152 |         hidden_states = self.proj_in(hidden_states)
153 | 
154 |         # 2. Blocks
155 |         for block in self.transformer_blocks:
156 |             hidden_states = block(
157 |                 hidden_states,
158 |                 encoder_hidden_states=encoder_hidden_states,
159 |                 timestep=timestep,
160 |                 cross_attention_kwargs=cross_attention_kwargs,
161 |                 class_labels=class_labels,
162 |             )
163 | 
164 |         # 3. Output
165 |         hidden_states = self.proj_out(hidden_states)
166 |         hidden_states = (
167 |             hidden_states[None, None, :]
168 |             .reshape(batch_size, height, width, channel, num_frames)
169 |             .permute(0, 3, 4, 1, 2)
170 |             .contiguous()
171 |         )
172 |         hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
173 | 
174 |         output = hidden_states + residual
175 | 
176 |         if not return_dict:
177 |             return (output,)
178 | 
179 |         return TransformerTemporalModelOutput(sample=output)
180 | 
--------------------------------------------------------------------------------
/share_btn.py:
--------------------------------------------------------------------------------
1 | community_icon_html = """"""
5 | 
6 | loading_icon_html = """"""
7 | 
8 | share_js = """async () => {
9 |     async function uploadFile(file){
10 |         const UPLOAD_URL = 'https://huggingface.co/uploads';
11 |         const response = await fetch(UPLOAD_URL, {
12 |             method: 'POST',
13 |             headers: {
14 |                 'Content-Type': file.type,
15 |                 'X-Requested-With': 'XMLHttpRequest',
16 |             },
17 |             body: file, /// <- File inherits from Blob
18 |         });
19 |         const url = await response.text();
20 |         return url;
21 |     }
22 |     async function getInputImageFile(imageEl){
23 |         const res = await fetch(imageEl.src);
24 |         const blob = await res.blob();
25 |         const imageId = Date.now();
26 |         const fileName = `rich-text-image-${imageId}.png`;
27 |         return new File([blob], fileName, { type: 'image/png'});
28 |     }
29 |     const gradioEl = document.querySelector("gradio-app").shadowRoot || document.querySelector('body > gradio-app');
30 |     const richEl = document.getElementById("rich-text-root");
31 |     const data = richEl? 
richEl.contentDocument.body._data : {}; 32 | const text_input = JSON.stringify(data); 33 | const negative_prompt = gradioEl.querySelector('#negative_prompt input').value; 34 | const seed = gradioEl.querySelector('#seed input').value; 35 | const richTextImg = gradioEl.querySelector('#rich-text-image img'); 36 | const plainTextImg = gradioEl.querySelector('#plain-text-image img'); 37 | const text_input_obj = JSON.parse(text_input); 38 | const plain_prompt = text_input_obj.ops.map(e=> e.insert).join(''); 39 | const linkSrc = `https://huggingface.co/spaces/songweig/rich-text-to-image?prompt=${encodeURIComponent(text_input)}`; 40 | 41 | const titleTxt = `RT2I: ${plain_prompt.slice(0, 50)}...`; 42 | const shareBtnEl = gradioEl.querySelector('#share-btn'); 43 | const shareIconEl = gradioEl.querySelector('#share-btn-share-icon'); 44 | const loadingIconEl = gradioEl.querySelector('#share-btn-loading-icon'); 45 | if(!richTextImg){ 46 | return; 47 | }; 48 | shareBtnEl.style.pointerEvents = 'none'; 49 | shareIconEl.style.display = 'none'; 50 | loadingIconEl.style.removeProperty('display'); 51 | 52 | const richImgFile = await getInputImageFile(richTextImg); 53 | const plainImgFile = await getInputImageFile(plainTextImg); 54 | const richImgURL = await uploadFile(richImgFile); 55 | const plainImgURL = await uploadFile(plainImgFile); 56 | 57 | const descriptionMd = ` 58 | ### Plain Prompt 59 | ${plain_prompt} 60 | 61 | 🔗 Shareable Link + Params: [here](${linkSrc}) 62 | 63 | ### Rich Tech Image 64 | 65 | 66 | ### Plain Text Image 67 | 68 | 69 | `; 70 | const params = new URLSearchParams({ 71 | title: titleTxt, 72 | description: descriptionMd, 73 | }); 74 | const paramsStr = params.toString(); 75 | window.open(`https://huggingface.co/spaces/songweig/rich-text-to-image/discussions/new?${paramsStr}`, '_blank'); 76 | shareBtnEl.style.removeProperty('pointer-events'); 77 | shareIconEl.style.removeProperty('display'); 78 | loadingIconEl.style.display = 'none'; 79 | }""" 80 | 81 | css = """ 82 | #share-btn-container { 83 | display: flex; 84 | padding-left: 0.5rem !important; 85 | padding-right: 0.5rem !important; 86 | background-color: #000000; 87 | justify-content: center; 88 | align-items: center; 89 | border-radius: 9999px !important; 90 | width: 13rem; 91 | margin-top: 10px; 92 | margin-left: auto; 93 | flex: unset !important; 94 | } 95 | #share-btn { 96 | all: initial; 97 | color: #ffffff; 98 | font-weight: 600; 99 | cursor: pointer; 100 | font-family: 'IBM Plex Sans', sans-serif; 101 | margin-left: 0.5rem !important; 102 | padding-top: 0.25rem !important; 103 | padding-bottom: 0.25rem !important; 104 | right:0; 105 | } 106 | #share-btn * { 107 | all: unset !important; 108 | } 109 | #share-btn-container div:nth-child(-n+2){ 110 | width: auto !important; 111 | min-height: 0px !important; 112 | } 113 | #share-btn-container .wrap { 114 | display: none !important; 115 | } 116 | """ 117 | -------------------------------------------------------------------------------- /diffusers_official/schedulers/scheduling_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import importlib 15 | import os 16 | from dataclasses import dataclass 17 | from enum import Enum 18 | from typing import Any, Dict, Optional, Union 19 | 20 | import torch 21 | 22 | from ..utils import BaseOutput 23 | 24 | 25 | SCHEDULER_CONFIG_NAME = "scheduler_config.json" 26 | 27 | 28 | # NOTE: We make this type an enum because it simplifies usage in docs and prevents 29 | # circular imports when used for `_compatibles` within the schedulers module. 30 | # When it's used as a type in pipelines, it really is a Union because the actual 31 | # scheduler instance is passed in. 32 | class KarrasDiffusionSchedulers(Enum): 33 | DDIMScheduler = 1 34 | DDPMScheduler = 2 35 | PNDMScheduler = 3 36 | LMSDiscreteScheduler = 4 37 | EulerDiscreteScheduler = 5 38 | HeunDiscreteScheduler = 6 39 | EulerAncestralDiscreteScheduler = 7 40 | DPMSolverMultistepScheduler = 8 41 | DPMSolverSinglestepScheduler = 9 42 | KDPM2DiscreteScheduler = 10 43 | KDPM2AncestralDiscreteScheduler = 11 44 | DEISMultistepScheduler = 12 45 | UniPCMultistepScheduler = 13 46 | DPMSolverSDEScheduler = 14 47 | 48 | 49 | @dataclass 50 | class SchedulerOutput(BaseOutput): 51 | """ 52 | Base class for the scheduler's step function output. 53 | 54 | Args: 55 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 56 | Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the 57 | denoising loop. 58 | """ 59 | 60 | prev_sample: torch.FloatTensor 61 | 62 | 63 | class SchedulerMixin: 64 | """ 65 | Mixin containing common functions for the schedulers. 66 | 67 | Class attributes: 68 | - **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that 69 | `from_config` can be used from a class different than the one used to save the config (should be overridden 70 | by parent class). 71 | """ 72 | 73 | config_name = SCHEDULER_CONFIG_NAME 74 | _compatibles = [] 75 | has_compatibles = True 76 | 77 | @classmethod 78 | def from_pretrained( 79 | cls, 80 | pretrained_model_name_or_path: Dict[str, Any] = None, 81 | subfolder: Optional[str] = None, 82 | return_unused_kwargs=False, 83 | **kwargs, 84 | ): 85 | r""" 86 | Instantiate a Scheduler class from a pre-defined JSON configuration file inside a directory or Hub repo. 87 | 88 | Parameters: 89 | pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): 90 | Can be either: 91 | 92 | - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an 93 | organization name, like `google/ddpm-celebahq-256`. 94 | - A path to a *directory* containing the schedluer configurations saved using 95 | [`~SchedulerMixin.save_pretrained`], e.g., `./my_model_directory/`. 96 | subfolder (`str`, *optional*): 97 | In case the relevant files are located inside a subfolder of the model repo (either remote in 98 | huggingface.co or downloaded locally), you can specify the folder name here. 
99 | return_unused_kwargs (`bool`, *optional*, defaults to `False`): 100 | Whether kwargs that are not consumed by the Python class should be returned or not. 101 | cache_dir (`Union[str, os.PathLike]`, *optional*): 102 | Path to a directory in which a downloaded pretrained model configuration should be cached if the 103 | standard cache should not be used. 104 | force_download (`bool`, *optional*, defaults to `False`): 105 | Whether or not to force the (re-)download of the model weights and configuration files, overriding the 106 | cached versions if they exist. 107 | resume_download (`bool`, *optional*, defaults to `False`): 108 | Whether or not to delete incompletely received files. Will attempt to resume the download if such a 109 | file exists. 110 | proxies (`Dict[str, str]`, *optional*): 111 | A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 112 | 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. 113 | output_loading_info(`bool`, *optional*, defaults to `False`): 114 | Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. 115 | local_files_only(`bool`, *optional*, defaults to `False`): 116 | Whether or not to only look at local files (i.e., do not try to download the model). 117 | use_auth_token (`str` or *bool*, *optional*): 118 | The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated 119 | when running `transformers-cli login` (stored in `~/.huggingface`). 120 | revision (`str`, *optional*, defaults to `"main"`): 121 | The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a 122 | git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any 123 | identifier allowed by git. 124 | 125 | 126 | 127 | It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated 128 | models](https://huggingface.co/docs/hub/models-gated#gated-models). 129 | 130 | 131 | 132 | 133 | 134 | Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to 135 | use this method in a firewalled environment. 136 | 137 | 138 | 139 | """ 140 | config, kwargs, commit_hash = cls.load_config( 141 | pretrained_model_name_or_path=pretrained_model_name_or_path, 142 | subfolder=subfolder, 143 | return_unused_kwargs=True, 144 | return_commit_hash=True, 145 | **kwargs, 146 | ) 147 | return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs) 148 | 149 | def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): 150 | """ 151 | Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the 152 | [`~SchedulerMixin.from_pretrained`] class method. 153 | 154 | Args: 155 | save_directory (`str` or `os.PathLike`): 156 | Directory where the configuration JSON file will be saved (will be created if it does not exist). 
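
        Example round trip (the repository id and local directory are illustrative;
        this assumes a concrete scheduler class such as `DDIMScheduler`):

            from diffusers import DDIMScheduler

            scheduler = DDIMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
            scheduler.save_pretrained("./my_scheduler")    # writes scheduler_config.json
            scheduler = DDIMScheduler.from_pretrained("./my_scheduler")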
157 | """ 158 | self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) 159 | 160 | @property 161 | def compatibles(self): 162 | """ 163 | Returns all schedulers that are compatible with this scheduler 164 | 165 | Returns: 166 | `List[SchedulerMixin]`: List of compatible schedulers 167 | """ 168 | return self._get_compatibles() 169 | 170 | @classmethod 171 | def _get_compatibles(cls): 172 | compatible_classes_str = list(set([cls.__name__] + cls._compatibles)) 173 | diffusers_library = importlib.import_module(__name__.split(".")[0]) 174 | compatible_classes = [ 175 | getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c) 176 | ] 177 | return compatible_classes 178 | -------------------------------------------------------------------------------- /diffusers_official/pipelines/onnx_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The HuggingFace Inc. team. 3 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | 18 | import os 19 | import shutil 20 | from pathlib import Path 21 | from typing import Optional, Union 22 | 23 | import numpy as np 24 | from huggingface_hub import hf_hub_download 25 | 26 | from ..utils import ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, is_onnx_available, logging 27 | 28 | 29 | if is_onnx_available(): 30 | import onnxruntime as ort 31 | 32 | 33 | logger = logging.get_logger(__name__) 34 | 35 | ORT_TO_NP_TYPE = { 36 | "tensor(bool)": np.bool_, 37 | "tensor(int8)": np.int8, 38 | "tensor(uint8)": np.uint8, 39 | "tensor(int16)": np.int16, 40 | "tensor(uint16)": np.uint16, 41 | "tensor(int32)": np.int32, 42 | "tensor(uint32)": np.uint32, 43 | "tensor(int64)": np.int64, 44 | "tensor(uint64)": np.uint64, 45 | "tensor(float16)": np.float16, 46 | "tensor(float)": np.float32, 47 | "tensor(double)": np.float64, 48 | } 49 | 50 | 51 | class OnnxRuntimeModel: 52 | def __init__(self, model=None, **kwargs): 53 | logger.info("`diffusers.OnnxRuntimeModel` is experimental and might change in the future.") 54 | self.model = model 55 | self.model_save_dir = kwargs.get("model_save_dir", None) 56 | self.latest_model_name = kwargs.get("latest_model_name", ONNX_WEIGHTS_NAME) 57 | 58 | def __call__(self, **kwargs): 59 | inputs = {k: np.array(v) for k, v in kwargs.items()} 60 | return self.model.run(None, inputs) 61 | 62 | @staticmethod 63 | def load_model(path: Union[str, Path], provider=None, sess_options=None): 64 | """ 65 | Loads an ONNX Inference session with an ExecutionProvider. 
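For example (the model path and the input name `sample` are illustrative):

            sess = OnnxRuntimeModel.load_model("./unet/model.onnx", provider="CPUExecutionProvider")
            out = sess.run(None, {"sample": np.zeros((1, 4, 64, 64), dtype=np.float32)})
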
Default provider is `CPUExecutionProvider` 66 | 67 | Arguments: 68 | path (`str` or `Path`): 69 | Directory from which to load 70 | provider(`str`, *optional*): 71 | Onnxruntime execution provider to use for loading the model, defaults to `CPUExecutionProvider` 72 | """ 73 | if provider is None: 74 | logger.info("No onnxruntime provider specified, using CPUExecutionProvider") 75 | provider = "CPUExecutionProvider" 76 | 77 | return ort.InferenceSession(path, providers=[provider], sess_options=sess_options) 78 | 79 | def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs): 80 | """ 81 | Save a model and its configuration file to a directory, so that it can be re-loaded using the 82 | [`~optimum.onnxruntime.modeling_ort.ORTModel.from_pretrained`] class method. It will always save the 83 | latest_model_name. 84 | 85 | Arguments: 86 | save_directory (`str` or `Path`): 87 | Directory where to save the model file. 88 | file_name(`str`, *optional*): 89 | Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to save the 90 | model with a different name. 91 | """ 92 | model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME 93 | 94 | src_path = self.model_save_dir.joinpath(self.latest_model_name) 95 | dst_path = Path(save_directory).joinpath(model_file_name) 96 | try: 97 | shutil.copyfile(src_path, dst_path) 98 | except shutil.SameFileError: 99 | pass 100 | 101 | # copy external weights (for models >2GB) 102 | src_path = self.model_save_dir.joinpath(ONNX_EXTERNAL_WEIGHTS_NAME) 103 | if src_path.exists(): 104 | dst_path = Path(save_directory).joinpath(ONNX_EXTERNAL_WEIGHTS_NAME) 105 | try: 106 | shutil.copyfile(src_path, dst_path) 107 | except shutil.SameFileError: 108 | pass 109 | 110 | def save_pretrained( 111 | self, 112 | save_directory: Union[str, os.PathLike], 113 | **kwargs, 114 | ): 115 | """ 116 | Save a model to a directory, so that it can be re-loaded using the [`~OnnxModel.from_pretrained`] class 117 | method.: 118 | 119 | Arguments: 120 | save_directory (`str` or `os.PathLike`): 121 | Directory to which to save. Will be created if it doesn't exist. 122 | """ 123 | if os.path.isfile(save_directory): 124 | logger.error(f"Provided path ({save_directory}) should be a directory, not a file") 125 | return 126 | 127 | os.makedirs(save_directory, exist_ok=True) 128 | 129 | # saving model weights/files 130 | self._save_pretrained(save_directory, **kwargs) 131 | 132 | @classmethod 133 | def _from_pretrained( 134 | cls, 135 | model_id: Union[str, Path], 136 | use_auth_token: Optional[Union[bool, str, None]] = None, 137 | revision: Optional[Union[str, None]] = None, 138 | force_download: bool = False, 139 | cache_dir: Optional[str] = None, 140 | file_name: Optional[str] = None, 141 | provider: Optional[str] = None, 142 | sess_options: Optional["ort.SessionOptions"] = None, 143 | **kwargs, 144 | ): 145 | """ 146 | Load a model from a directory or the HF Hub. 147 | 148 | Arguments: 149 | model_id (`str` or `Path`): 150 | Directory from which to load 151 | use_auth_token (`str` or `bool`): 152 | Is needed to load models from a private or gated repository 153 | revision (`str`): 154 | Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id 155 | cache_dir (`Union[str, Path]`, *optional*): 156 | Path to a directory in which a downloaded pretrained model configuration should be cached if the 157 | standard cache should not be used. 
158 | force_download (`bool`, *optional*, defaults to `False`): 159 | Whether or not to force the (re-)download of the model weights and configuration files, overriding the 160 | cached versions if they exist. 161 | file_name(`str`): 162 | Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to load 163 | different model files from the same repository or directory. 164 | provider(`str`): 165 | The ONNX runtime provider, e.g. `CPUExecutionProvider` or `CUDAExecutionProvider`. 166 | kwargs (`Dict`, *optional*): 167 | kwargs will be passed to the model during initialization 168 | """ 169 | model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME 170 | # load model from local directory 171 | if os.path.isdir(model_id): 172 | model = OnnxRuntimeModel.load_model( 173 | os.path.join(model_id, model_file_name), provider=provider, sess_options=sess_options 174 | ) 175 | kwargs["model_save_dir"] = Path(model_id) 176 | # load model from hub 177 | else: 178 | # download model 179 | model_cache_path = hf_hub_download( 180 | repo_id=model_id, 181 | filename=model_file_name, 182 | use_auth_token=use_auth_token, 183 | revision=revision, 184 | cache_dir=cache_dir, 185 | force_download=force_download, 186 | ) 187 | kwargs["model_save_dir"] = Path(model_cache_path).parent 188 | kwargs["latest_model_name"] = Path(model_cache_path).name 189 | model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider, sess_options=sess_options) 190 | return cls(model=model, **kwargs) 191 | 192 | @classmethod 193 | def from_pretrained( 194 | cls, 195 | model_id: Union[str, Path], 196 | force_download: bool = True, 197 | use_auth_token: Optional[str] = None, 198 | cache_dir: Optional[str] = None, 199 | **model_kwargs, 200 | ): 201 | revision = None 202 | if len(str(model_id).split("@")) == 2: 203 | model_id, revision = model_id.split("@") 204 | 205 | return cls._from_pretrained( 206 | model_id=model_id, 207 | revision=revision, 208 | cache_dir=cache_dir, 209 | force_download=force_download, 210 | use_auth_token=use_auth_token, 211 | **model_kwargs, 212 | ) 213 | -------------------------------------------------------------------------------- /scripts/models/utils/richtext_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import random 5 | import numpy as np 6 | 7 | COLORS = { 8 | 'brown': [165, 42, 42], 9 | 'red': [255, 0, 0], 10 | 'pink': [253, 108, 158], 11 | 'orange': [255, 165, 0], 12 | 'yellow': [255, 255, 0], 13 | 'purple': [128, 0, 128], 14 | 'green': [0, 128, 0], 15 | 'blue': [0, 0, 255], 16 | 'white': [255, 255, 255], 17 | 'gray': [128, 128, 128], 18 | 'black': [0, 0, 0], 19 | } 20 | 21 | 22 | def seed_everything(seed): 23 | random.seed(seed) 24 | os.environ['PYTHONHASHSEED'] = str(seed) 25 | np.random.seed(seed) 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed(seed) 28 | 29 | 30 | def hex_to_rgb(hex_string, return_nearest_color=False): 31 | r""" 32 | Covert Hex triplet to RGB triplet. 33 | """ 34 | # Remove '#' symbol if present 35 | hex_string = hex_string.lstrip('#') 36 | # Convert hex values to integers 37 | red = int(hex_string[0:2], 16) 38 | green = int(hex_string[2:4], 16) 39 | blue = int(hex_string[4:6], 16) 40 | rgb = torch.FloatTensor((red, green, blue))[None, :, None, None]/255. 
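    # For example, '#ff0000' gives red = 255, green = 0, blue = 0, so `rgb` is a
    # (1, 3, 1, 1) tensor holding (1.0, 0.0, 0.0); the singleton axes let it
    # broadcast against (batch, channel, height, width) image tensors.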
41 | if return_nearest_color: 42 | nearest_color = find_nearest_color(rgb) 43 | return rgb.cuda(), nearest_color 44 | return rgb.cuda() 45 | 46 | 47 | def find_nearest_color(rgb): 48 | r""" 49 | Find the nearest neighbor color given the RGB value. 50 | """ 51 | if isinstance(rgb, list) or isinstance(rgb, tuple): 52 | rgb = torch.FloatTensor(rgb)[None, :, None, None]/255. 53 | color_distance = torch.FloatTensor([np.linalg.norm( 54 | rgb - torch.FloatTensor(COLORS[color])[None, :, None, None]/255.) for color in COLORS.keys()]) 55 | nearest_color = list(COLORS.keys())[torch.argmin(color_distance).item()] 56 | return nearest_color 57 | 58 | 59 | def font2style(font): 60 | r""" 61 | Convert the font name to the style name. 62 | """ 63 | return {'mirza': 'Claud Monet, impressionism, oil on canvas', 64 | 'roboto': 'Ukiyoe', 65 | 'cursive': 'Cyber Punk, futuristic, blade runner, william gibson, trending on artstation hq', 66 | 'sofia': 'Pop Art, masterpiece, andy warhol', 67 | 'slabo': 'Vincent Van Gogh', 68 | 'inconsolata': 'Pixel Art, 8 bits, 16 bits', 69 | 'ubuntu': 'Rembrandt', 70 | 'Monoton': 'neon art, colorful light, highly details, octane render', 71 | 'Akronim': 'Abstract Cubism, Pablo Picasso', }[font] 72 | 73 | 74 | def parse_json(json_str): 75 | r""" 76 | Convert the JSON string to attributes. 77 | """ 78 | # initialze region-base attributes. 79 | base_text_prompt = '' 80 | style_text_prompts = [] 81 | footnote_text_prompts = [] 82 | footnote_target_tokens = [] 83 | color_text_prompts = [] 84 | color_rgbs = [] 85 | color_names = [] 86 | size_text_prompts_and_sizes = [] 87 | 88 | # parse the attributes from JSON. 89 | prev_style = None 90 | prev_color_rgb = None 91 | use_grad_guidance = False 92 | for span in json_str['ops']: 93 | text_prompt = span['insert'].rstrip('\n') 94 | base_text_prompt += span['insert'].rstrip('\n') 95 | if text_prompt == ' ': 96 | continue 97 | if 'attributes' in span: 98 | if 'font' in span['attributes']: 99 | style = font2style(span['attributes']['font']) 100 | if prev_style == style: 101 | prev_text_prompt = style_text_prompts[-1].split('in the style of')[ 102 | 0] 103 | style_text_prompts[-1] = prev_text_prompt + \ 104 | ' ' + text_prompt + f' in the style of {style}' 105 | else: 106 | style_text_prompts.append( 107 | text_prompt + f' in the style of {style}') 108 | prev_style = style 109 | else: 110 | prev_style = None 111 | if 'link' in span['attributes']: 112 | footnote_text_prompts.append(span['attributes']['link']) 113 | footnote_target_tokens.append(text_prompt) 114 | font_size = 1 115 | if 'size' in span['attributes'] and 'strike' not in span['attributes']: 116 | font_size = float(span['attributes']['size'][:-2])/3. 117 | elif 'size' in span['attributes'] and 'strike' in span['attributes']: 118 | font_size = -float(span['attributes']['size'][:-2])/3. 
119 | elif 'size' not in span['attributes'] and 'strike' not in span['attributes']: 120 | font_size = 1 121 | if 'color' in span['attributes']: 122 | use_grad_guidance = True 123 | color_rgb, nearest_color = hex_to_rgb( 124 | span['attributes']['color'], True) 125 | if prev_color_rgb == color_rgb: 126 | prev_text_prompt = color_text_prompts[-1] 127 | color_text_prompts[-1] = prev_text_prompt + \ 128 | ' ' + text_prompt 129 | else: 130 | color_rgbs.append(color_rgb) 131 | color_names.append(nearest_color) 132 | color_text_prompts.append(text_prompt) 133 | if font_size != 1: 134 | size_text_prompts_and_sizes.append([text_prompt, font_size]) 135 | return base_text_prompt, style_text_prompts, footnote_text_prompts, footnote_target_tokens,\ 136 | color_text_prompts, color_names, color_rgbs, size_text_prompts_and_sizes, use_grad_guidance 137 | 138 | 139 | def get_region_diffusion_input(model, base_text_prompt, style_text_prompts, footnote_text_prompts, 140 | footnote_target_tokens, color_text_prompts, color_names): 141 | r""" 142 | Algorithm 1 in the paper. 143 | """ 144 | region_text_prompts = [] 145 | region_target_token_ids = [] 146 | base_tokens = model.tokenizer._tokenize(base_text_prompt) 147 | # process the style text prompt 148 | for text_prompt in style_text_prompts: 149 | region_text_prompts.append(text_prompt) 150 | region_target_token_ids.append([]) 151 | style_tokens = model.tokenizer._tokenize( 152 | text_prompt.split('in the style of')[0]) 153 | for style_token in style_tokens: 154 | region_target_token_ids[-1].append( 155 | base_tokens.index(style_token)+1) 156 | 157 | # process the complementary text prompt 158 | for footnote_text_prompt, text_prompt in zip(footnote_text_prompts, footnote_target_tokens): 159 | region_target_token_ids.append([]) 160 | region_text_prompts.append(footnote_text_prompt) 161 | style_tokens = model.tokenizer._tokenize(text_prompt) 162 | for style_token in style_tokens: 163 | region_target_token_ids[-1].append( 164 | base_tokens.index(style_token)+1) 165 | 166 | # process the color text prompt 167 | for color_text_prompt, color_name in zip(color_text_prompts, color_names): 168 | region_target_token_ids.append([]) 169 | region_text_prompts.append(color_name+' '+color_text_prompt) 170 | style_tokens = model.tokenizer._tokenize(color_text_prompt) 171 | for style_token in style_tokens: 172 | region_target_token_ids[-1].append( 173 | base_tokens.index(style_token)+1) 174 | 175 | # process the remaining tokens without any attributes 176 | region_text_prompts.append(base_text_prompt) 177 | region_target_token_ids_all = [ 178 | id for ids in region_target_token_ids for id in ids] 179 | target_token_ids_rest = [id for id in range( 180 | 1, len(base_tokens)+1) if id not in region_target_token_ids_all] 181 | region_target_token_ids.append(target_token_ids_rest) 182 | 183 | region_target_token_ids = [torch.LongTensor( 184 | obj_token_id) for obj_token_id in region_target_token_ids] 185 | return region_text_prompts, region_target_token_ids, base_tokens 186 | 187 | 188 | def get_attention_control_input(model, base_tokens, size_text_prompts_and_sizes): 189 | r""" 190 | Control the token impact using font sizes. 
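
    Example (the token position is illustrative): for a single resized span
    `[('dragon', 2.0)]` whose token sits at position 3 of the base prompt, the
    returned dict would be roughly

        {'word_pos': LongTensor([3]).cuda(), 'font_size': FloatTensor([2.0]).cuda()}

    and `{'word_pos': None, 'font_size': None}` when no span is resized.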
191 |     """
192 |     word_pos = []
193 |     font_sizes = []
194 |     for text_prompt, font_size in size_text_prompts_and_sizes:
195 |         size_tokens = model.tokenizer._tokenize(text_prompt)
196 |         for size_token in size_tokens:
197 |             word_pos.append(base_tokens.index(size_token)+1)
198 |             font_sizes.append(font_size)
199 |     if len(word_pos) > 0:
200 |         word_pos = torch.LongTensor(word_pos).cuda()
201 |         font_sizes = torch.FloatTensor(font_sizes).cuda()
202 |     else:
203 |         word_pos = None
204 |         font_sizes = None
205 |     text_format_dict = {
206 |         'word_pos': word_pos,
207 |         'font_size': font_sizes,
208 |     }
209 |     return text_format_dict
210 | 
211 | 
212 | def get_gradient_guidance_input(model, base_tokens, color_text_prompts, color_rgbs, text_format_dict,
213 |                                 guidance_start_step=999, color_guidance_weight=1):
214 |     r"""
215 |     Prepare the inputs for controlling token colors with gradient guidance.
216 |     """
217 |     color_target_token_ids = []
218 |     for text_prompt in color_text_prompts:
219 |         color_target_token_ids.append([])
220 |         color_tokens = model.tokenizer._tokenize(text_prompt)
221 |         for color_token in color_tokens:
222 |             color_target_token_ids[-1].append(base_tokens.index(color_token)+1)
223 |     color_target_token_ids_all = [
224 |         id for ids in color_target_token_ids for id in ids]
225 |     color_target_token_ids_rest = [id for id in range(
226 |         1, len(base_tokens)+1) if id not in color_target_token_ids_all]
227 |     color_target_token_ids.append(color_target_token_ids_rest)
228 |     color_target_token_ids = [torch.LongTensor(
229 |         obj_token_id) for obj_token_id in color_target_token_ids]
230 | 
231 |     text_format_dict['target_RGB'] = color_rgbs
232 |     text_format_dict['guidance_start_step'] = guidance_start_step
233 |     text_format_dict['color_guidance_weight'] = color_guidance_weight
234 |     return text_format_dict, color_target_token_ids
235 | 
--------------------------------------------------------------------------------
/diffusers_official/schedulers/scheduling_karras_ve_flax.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 NVIDIA and The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | 
16 | from dataclasses import dataclass
17 | from typing import Optional, Tuple, Union
18 | 
19 | import flax
20 | import jax.numpy as jnp
21 | from jax import random
22 | 
23 | from ..configuration_utils import ConfigMixin, register_to_config
24 | from ..utils import BaseOutput
25 | from .scheduling_utils_flax import FlaxSchedulerMixin
26 | 
27 | 
28 | @flax.struct.dataclass
29 | class KarrasVeSchedulerState:
30 |     # settable values
31 |     num_inference_steps: Optional[int] = None
32 |     timesteps: Optional[jnp.ndarray] = None
33 |     schedule: Optional[jnp.ndarray] = None  # sigma(t_i)
34 | 
35 |     @classmethod
36 |     def create(cls):
37 |         return cls()
38 | 
39 | 
40 | @dataclass
41 | class FlaxKarrasVeOutput(BaseOutput):
42 |     """
43 |     Output class for the scheduler's step function output.
44 | 45 | Args: 46 | prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images): 47 | Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the 48 | denoising loop. 49 | derivative (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images): 50 | Derivative of predicted original image sample (x_0). 51 | state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class. 52 | """ 53 | 54 | prev_sample: jnp.ndarray 55 | derivative: jnp.ndarray 56 | state: KarrasVeSchedulerState 57 | 58 | 59 | class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin): 60 | """ 61 | Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and 62 | the VE column of Table 1 from [1] for reference. 63 | 64 | [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." 65 | https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic 66 | differential equations." https://arxiv.org/abs/2011.13456 67 | 68 | [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` 69 | function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. 70 | [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and 71 | [`~SchedulerMixin.from_pretrained`] functions. 72 | 73 | For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of 74 | Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the 75 | optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper. 76 | 77 | Args: 78 | sigma_min (`float`): minimum noise magnitude 79 | sigma_max (`float`): maximum noise magnitude 80 | s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling. 81 | A reasonable range is [1.000, 1.011]. 82 | s_churn (`float`): the parameter controlling the overall amount of stochasticity. 83 | A reasonable range is [0, 100]. 84 | s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity). 85 | A reasonable range is [0, 10]. 86 | s_max (`float`): the end value of the sigma range where we add noise. 87 | A reasonable range is [0.2, 80]. 88 | """ 89 | 90 | @property 91 | def has_state(self): 92 | return True 93 | 94 | @register_to_config 95 | def __init__( 96 | self, 97 | sigma_min: float = 0.02, 98 | sigma_max: float = 100, 99 | s_noise: float = 1.007, 100 | s_churn: float = 80, 101 | s_min: float = 0.05, 102 | s_max: float = 50, 103 | ): 104 | pass 105 | 106 | def create_state(self): 107 | return KarrasVeSchedulerState.create() 108 | 109 | def set_timesteps( 110 | self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple = () 111 | ) -> KarrasVeSchedulerState: 112 | """ 113 | Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. 114 | 115 | Args: 116 | state (`KarrasVeSchedulerState`): 117 | the `FlaxKarrasVeScheduler` state data class. 118 | num_inference_steps (`int`): 119 | the number of diffusion steps used when generating samples with a pre-trained model. 
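
            The schedule computed below interpolates log-linearly between the two
            configured extremes (a sketch of the VE discretization from Karras et al.):

                t_i = sigma_max**2 * (sigma_min**2 / sigma_max**2) ** (i / (N - 1))

            so `i = 0` yields `sigma_max**2` and `i = N - 1` yields `sigma_min**2`.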
120 | 121 | """ 122 | timesteps = jnp.arange(0, num_inference_steps)[::-1].copy() 123 | schedule = [ 124 | ( 125 | self.config.sigma_max**2 126 | * (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1)) 127 | ) 128 | for i in timesteps 129 | ] 130 | 131 | return state.replace( 132 | num_inference_steps=num_inference_steps, 133 | schedule=jnp.array(schedule, dtype=jnp.float32), 134 | timesteps=timesteps, 135 | ) 136 | 137 | def add_noise_to_input( 138 | self, 139 | state: KarrasVeSchedulerState, 140 | sample: jnp.ndarray, 141 | sigma: float, 142 | key: random.KeyArray, 143 | ) -> Tuple[jnp.ndarray, float]: 144 | """ 145 | Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a 146 | higher noise level sigma_hat = sigma_i + gamma_i*sigma_i. 147 | 148 | TODO Args: 149 | """ 150 | if self.config.s_min <= sigma <= self.config.s_max: 151 | gamma = min(self.config.s_churn / state.num_inference_steps, 2**0.5 - 1) 152 | else: 153 | gamma = 0 154 | 155 | # sample eps ~ N(0, S_noise^2 * I) 156 | key = random.split(key, num=1) 157 | eps = self.config.s_noise * random.normal(key=key, shape=sample.shape) 158 | sigma_hat = sigma + gamma * sigma 159 | sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps) 160 | 161 | return sample_hat, sigma_hat 162 | 163 | def step( 164 | self, 165 | state: KarrasVeSchedulerState, 166 | model_output: jnp.ndarray, 167 | sigma_hat: float, 168 | sigma_prev: float, 169 | sample_hat: jnp.ndarray, 170 | return_dict: bool = True, 171 | ) -> Union[FlaxKarrasVeOutput, Tuple]: 172 | """ 173 | Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion 174 | process from the learned model outputs (most often the predicted noise). 175 | 176 | Args: 177 | state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class. 178 | model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. 179 | sigma_hat (`float`): TODO 180 | sigma_prev (`float`): TODO 181 | sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO 182 | return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class 183 | 184 | Returns: 185 | [`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] or `tuple`: Updated sample in the diffusion 186 | chain and derivative. [`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] if `return_dict` is 187 | True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. 188 | """ 189 | 190 | pred_original_sample = sample_hat + sigma_hat * model_output 191 | derivative = (sample_hat - pred_original_sample) / sigma_hat 192 | sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative 193 | 194 | if not return_dict: 195 | return (sample_prev, derivative, state) 196 | 197 | return FlaxKarrasVeOutput(prev_sample=sample_prev, derivative=derivative, state=state) 198 | 199 | def step_correct( 200 | self, 201 | state: KarrasVeSchedulerState, 202 | model_output: jnp.ndarray, 203 | sigma_hat: float, 204 | sigma_prev: float, 205 | sample_hat: jnp.ndarray, 206 | sample_prev: jnp.ndarray, 207 | derivative: jnp.ndarray, 208 | return_dict: bool = True, 209 | ) -> Union[FlaxKarrasVeOutput, Tuple]: 210 | """ 211 | Correct the predicted sample based on the output model_output of the network. TODO complete description 212 | 213 | Args: 214 | state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class. 
215 | model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. 216 | sigma_hat (`float`): TODO 217 | sigma_prev (`float`): TODO 218 | sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO 219 | sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO 220 | derivative (`torch.FloatTensor` or `np.ndarray`): TODO 221 | return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class 222 | 223 | Returns: 224 | prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO 225 | 226 | """ 227 | pred_original_sample = sample_prev + sigma_prev * model_output 228 | derivative_corr = (sample_prev - pred_original_sample) / sigma_prev 229 | sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr) 230 | 231 | if not return_dict: 232 | return (sample_prev, derivative, state) 233 | 234 | return FlaxKarrasVeOutput(prev_sample=sample_prev, derivative=derivative, state=state) 235 | 236 | def add_noise(self, state: KarrasVeSchedulerState, original_samples, noise, timesteps): 237 | raise NotImplementedError() 238 | -------------------------------------------------------------------------------- /diffusers_official/schedulers/scheduling_karras_ve.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 NVIDIA and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from dataclasses import dataclass 17 | from typing import Optional, Tuple, Union 18 | 19 | import numpy as np 20 | import torch 21 | 22 | from ..configuration_utils import ConfigMixin, register_to_config 23 | from ..utils import BaseOutput, randn_tensor 24 | from .scheduling_utils import SchedulerMixin 25 | 26 | 27 | @dataclass 28 | class KarrasVeOutput(BaseOutput): 29 | """ 30 | Output class for the scheduler's step function output. 31 | 32 | Args: 33 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 34 | Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the 35 | denoising loop. 36 | derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 37 | Derivative of predicted original image sample (x_0). 38 | pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 39 | The predicted denoised sample (x_{0}) based on the model output from the current timestep. 40 | `pred_original_sample` can be used to preview progress or for guidance. 41 | """ 42 | 43 | prev_sample: torch.FloatTensor 44 | derivative: torch.FloatTensor 45 | pred_original_sample: Optional[torch.FloatTensor] = None 46 | 47 | 48 | class KarrasVeScheduler(SchedulerMixin, ConfigMixin): 49 | """ 50 | Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. 
Use Algorithm 2 and 51 | the VE column of Table 1 from [1] for reference. 52 | 53 | [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." 54 | https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic 55 | differential equations." https://arxiv.org/abs/2011.13456 56 | 57 | [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` 58 | function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. 59 | [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and 60 | [`~SchedulerMixin.from_pretrained`] functions. 61 | 62 | For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of 63 | Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the 64 | optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper. 65 | 66 | Args: 67 | sigma_min (`float`): minimum noise magnitude 68 | sigma_max (`float`): maximum noise magnitude 69 | s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling. 70 | A reasonable range is [1.000, 1.011]. 71 | s_churn (`float`): the parameter controlling the overall amount of stochasticity. 72 | A reasonable range is [0, 100]. 73 | s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity). 74 | A reasonable range is [0, 10]. 75 | s_max (`float`): the end value of the sigma range where we add noise. 76 | A reasonable range is [0.2, 80]. 77 | 78 | """ 79 | 80 | order = 2 81 | 82 | @register_to_config 83 | def __init__( 84 | self, 85 | sigma_min: float = 0.02, 86 | sigma_max: float = 100, 87 | s_noise: float = 1.007, 88 | s_churn: float = 80, 89 | s_min: float = 0.05, 90 | s_max: float = 50, 91 | ): 92 | # standard deviation of the initial noise distribution 93 | self.init_noise_sigma = sigma_max 94 | 95 | # setable values, populated by `set_timesteps` 96 | self.num_inference_steps: Optional[int] = None 97 | self.timesteps: Optional[torch.Tensor] = None 98 | self.schedule: Optional[torch.FloatTensor] = None  # sigma(t_i) 99 | 100 | def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: 101 | """ 102 | Ensures interchangeability with schedulers that need to scale the denoising model input depending on the 103 | current timestep. 104 | 105 | Args: 106 | sample (`torch.FloatTensor`): input sample 107 | timestep (`int`, optional): current timestep 108 | 109 | Returns: 110 | `torch.FloatTensor`: scaled input sample 111 | """ 112 | return sample 113 | 114 | def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): 115 | """ 116 | Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. 117 | 118 | Args: 119 | num_inference_steps (`int`): 120 | the number of diffusion steps used when generating samples with a pre-trained model.
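            device (`str` or `torch.device`, *optional*):
                the device to which the timesteps and sigma schedule should be moved.

        Example (an illustrative sketch rather than upstream documentation; it only exercises methods defined in this
        class):

        ```py
        >>> scheduler = KarrasVeScheduler()
        >>> scheduler.set_timesteps(num_inference_steps=50)
        >>> len(scheduler.schedule)  # one sigma value per inference step
        50
        ```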
121 | 122 | """ 123 | self.num_inference_steps = num_inference_steps 124 | timesteps = np.arange(0, self.num_inference_steps)[::-1].copy() 125 | self.timesteps = torch.from_numpy(timesteps).to(device) 126 | schedule = [ 127 | ( 128 | self.config.sigma_max**2 129 | * (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1)) 130 | ) 131 | for i in self.timesteps 132 | ] 133 | self.schedule = torch.tensor(schedule, dtype=torch.float32, device=device) 134 | 135 | def add_noise_to_input( 136 | self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None 137 | ) -> Tuple[torch.FloatTensor, float]: 138 | """ 139 | Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a 140 | higher noise level sigma_hat = sigma_i + gamma_i*sigma_i. 141 | 142 | TODO Args: 143 | """ 144 | if self.config.s_min <= sigma <= self.config.s_max: 145 | gamma = min(self.config.s_churn / self.num_inference_steps, 2**0.5 - 1) 146 | else: 147 | gamma = 0 148 | 149 | # sample eps ~ N(0, S_noise^2 * I) 150 | eps = self.config.s_noise * randn_tensor(sample.shape, generator=generator).to(sample.device) 151 | sigma_hat = sigma + gamma * sigma 152 | sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps) 153 | 154 | return sample_hat, sigma_hat 155 | 156 | def step( 157 | self, 158 | model_output: torch.FloatTensor, 159 | sigma_hat: float, 160 | sigma_prev: float, 161 | sample_hat: torch.FloatTensor, 162 | return_dict: bool = True, 163 | ) -> Union[KarrasVeOutput, Tuple]: 164 | """ 165 | Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion 166 | process from the learned model outputs (most often the predicted noise). 167 | 168 | Args: 169 | model_output (`torch.FloatTensor`): direct output from learned diffusion model. 170 | sigma_hat (`float`): TODO 171 | sigma_prev (`float`): TODO 172 | sample_hat (`torch.FloatTensor`): TODO 173 | return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class 174 | 175 | KarrasVeOutput: updated sample in the diffusion chain and derivative (TODO double check). 176 | Returns: 177 | [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] or `tuple`: 178 | [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] if `return_dict` is True, otherwise a `tuple`. When 179 | returning a tuple, the first element is the sample tensor. 180 | 181 | """ 182 | 183 | pred_original_sample = sample_hat + sigma_hat * model_output 184 | derivative = (sample_hat - pred_original_sample) / sigma_hat 185 | sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative 186 | 187 | if not return_dict: 188 | return (sample_prev, derivative) 189 | 190 | return KarrasVeOutput( 191 | prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample 192 | ) 193 | 194 | def step_correct( 195 | self, 196 | model_output: torch.FloatTensor, 197 | sigma_hat: float, 198 | sigma_prev: float, 199 | sample_hat: torch.FloatTensor, 200 | sample_prev: torch.FloatTensor, 201 | derivative: torch.FloatTensor, 202 | return_dict: bool = True, 203 | ) -> Union[KarrasVeOutput, Tuple]: 204 | """ 205 | Correct the predicted sample based on the output model_output of the network. TODO complete description 206 | 207 | Args: 208 | model_output (`torch.FloatTensor`): direct output from learned diffusion model. 
209 | sigma_hat (`float`): TODO 210 | sigma_prev (`float`): TODO 211 | sample_hat (`torch.FloatTensor`): TODO 212 | sample_prev (`torch.FloatTensor`): TODO 213 | derivative (`torch.FloatTensor`): TODO 214 | return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class 215 | 216 | Returns: 217 | prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO 218 | 219 | """ 220 | pred_original_sample = sample_prev + sigma_prev * model_output 221 | derivative_corr = (sample_prev - pred_original_sample) / sigma_prev 222 | sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr) 223 | 224 | if not return_dict: 225 | return (sample_prev, derivative) 226 | 227 | return KarrasVeOutput( 228 | prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample 229 | ) 230 | 231 | def add_noise(self, original_samples, noise, timesteps): 232 | raise NotImplementedError() 233 | -------------------------------------------------------------------------------- /diffusers_official/utils/logging.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 Optuna, Hugging Face 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Logging utilities.""" 16 | 17 | import logging 18 | import os 19 | import sys 20 | import threading 21 | from logging import ( 22 | CRITICAL, # NOQA 23 | DEBUG, # NOQA 24 | ERROR, # NOQA 25 | FATAL, # NOQA 26 | INFO, # NOQA 27 | NOTSET, # NOQA 28 | WARN, # NOQA 29 | WARNING, # NOQA 30 | ) 31 | from typing import Optional 32 | 33 | from tqdm import auto as tqdm_lib 34 | 35 | 36 | _lock = threading.Lock() 37 | _default_handler: Optional[logging.Handler] = None 38 | 39 | log_levels = { 40 | "debug": logging.DEBUG, 41 | "info": logging.INFO, 42 | "warning": logging.WARNING, 43 | "error": logging.ERROR, 44 | "critical": logging.CRITICAL, 45 | } 46 | 47 | _default_log_level = logging.WARNING 48 | 49 | _tqdm_active = True 50 | 51 | 52 | def _get_default_logging_level(): 53 | """ 54 | If DIFFUSERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. 
If it is 55 | not - fall back to `_default_log_level` 56 | """ 57 | env_level_str = os.getenv("DIFFUSERS_VERBOSITY", None) 58 | if env_level_str: 59 | if env_level_str in log_levels: 60 | return log_levels[env_level_str] 61 | else: 62 | logging.getLogger().warning( 63 | f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, " 64 | f"has to be one of: { ', '.join(log_levels.keys()) }" 65 | ) 66 | return _default_log_level 67 | 68 | 69 | def _get_library_name() -> str: 70 | return __name__.split(".")[0] 71 | 72 | 73 | def _get_library_root_logger() -> logging.Logger: 74 | return logging.getLogger(_get_library_name()) 75 | 76 | 77 | def _configure_library_root_logger() -> None: 78 | global _default_handler 79 | 80 | with _lock: 81 | if _default_handler: 82 | # This library has already configured the library root logger. 83 | return 84 | _default_handler = logging.StreamHandler() # Set sys.stderr as stream. 85 | _default_handler.flush = sys.stderr.flush 86 | 87 | # Apply our default configuration to the library root logger. 88 | library_root_logger = _get_library_root_logger() 89 | library_root_logger.addHandler(_default_handler) 90 | library_root_logger.setLevel(_get_default_logging_level()) 91 | library_root_logger.propagate = False 92 | 93 | 94 | def _reset_library_root_logger() -> None: 95 | global _default_handler 96 | 97 | with _lock: 98 | if not _default_handler: 99 | return 100 | 101 | library_root_logger = _get_library_root_logger() 102 | library_root_logger.removeHandler(_default_handler) 103 | library_root_logger.setLevel(logging.NOTSET) 104 | _default_handler = None 105 | 106 | 107 | def get_log_levels_dict(): 108 | return log_levels 109 | 110 | 111 | def get_logger(name: Optional[str] = None) -> logging.Logger: 112 | """ 113 | Return a logger with the specified name. 114 | 115 | This function is not supposed to be directly accessed unless you are writing a custom diffusers module. 116 | """ 117 | 118 | if name is None: 119 | name = _get_library_name() 120 | 121 | _configure_library_root_logger() 122 | return logging.getLogger(name) 123 | 124 | 125 | def get_verbosity() -> int: 126 | """ 127 | Return the current level for the 🤗 Diffusers' root logger as an `int`. 128 | 129 | Returns: 130 | `int`: 131 | Logging level integers which can be one of: 132 | 133 | - `50`: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL` 134 | - `40`: `diffusers.logging.ERROR` 135 | - `30`: `diffusers.logging.WARNING` or `diffusers.logging.WARN` 136 | - `20`: `diffusers.logging.INFO` 137 | - `10`: `diffusers.logging.DEBUG` 138 | 139 | """ 140 | 141 | _configure_library_root_logger() 142 | return _get_library_root_logger().getEffectiveLevel() 143 | 144 | 145 | def set_verbosity(verbosity: int) -> None: 146 | """ 147 | Set the verbosity level for the 🤗 Diffusers' root logger. 
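        Example (illustrative; assumes the package is importable under its upstream name `diffusers`):

        ```py
        >>> from diffusers.utils import logging
        >>> logging.set_verbosity(logging.INFO)  # equivalent to `logging.set_verbosity_info()`
        ```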
148 | 149 | Args: 150 | verbosity (`int`): 151 | Logging level which can be one of: 152 | 153 | - `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL` 154 | - `diffusers.logging.ERROR` 155 | - `diffusers.logging.WARNING` or `diffusers.logging.WARN` 156 | - `diffusers.logging.INFO` 157 | - `diffusers.logging.DEBUG` 158 | """ 159 | 160 | _configure_library_root_logger() 161 | _get_library_root_logger().setLevel(verbosity) 162 | 163 | 164 | def set_verbosity_info(): 165 | """Set the verbosity to the `INFO` level.""" 166 | return set_verbosity(INFO) 167 | 168 | 169 | def set_verbosity_warning(): 170 | """Set the verbosity to the `WARNING` level.""" 171 | return set_verbosity(WARNING) 172 | 173 | 174 | def set_verbosity_debug(): 175 | """Set the verbosity to the `DEBUG` level.""" 176 | return set_verbosity(DEBUG) 177 | 178 | 179 | def set_verbosity_error(): 180 | """Set the verbosity to the `ERROR` level.""" 181 | return set_verbosity(ERROR) 182 | 183 | 184 | def disable_default_handler() -> None: 185 | """Disable the default handler of the 🤗 Diffusers' root logger.""" 186 | 187 | _configure_library_root_logger() 188 | 189 | assert _default_handler is not None 190 | _get_library_root_logger().removeHandler(_default_handler) 191 | 192 | 193 | def enable_default_handler() -> None: 194 | """Enable the default handler of the 🤗 Diffusers' root logger.""" 195 | 196 | _configure_library_root_logger() 197 | 198 | assert _default_handler is not None 199 | _get_library_root_logger().addHandler(_default_handler) 200 | 201 | 202 | def add_handler(handler: logging.Handler) -> None: 203 | """adds a handler to the HuggingFace Diffusers' root logger.""" 204 | 205 | _configure_library_root_logger() 206 | 207 | assert handler is not None 208 | _get_library_root_logger().addHandler(handler) 209 | 210 | 211 | def remove_handler(handler: logging.Handler) -> None: 212 | """removes given handler from the HuggingFace Diffusers' root logger.""" 213 | 214 | _configure_library_root_logger() 215 | 216 | assert handler is not None and handler not in _get_library_root_logger().handlers 217 | _get_library_root_logger().removeHandler(handler) 218 | 219 | 220 | def disable_propagation() -> None: 221 | """ 222 | Disable propagation of the library log outputs. Note that log propagation is disabled by default. 223 | """ 224 | 225 | _configure_library_root_logger() 226 | _get_library_root_logger().propagate = False 227 | 228 | 229 | def enable_propagation() -> None: 230 | """ 231 | Enable propagation of the library log outputs. Please disable the HuggingFace Diffusers' default handler to prevent 232 | double logging if the root logger has been configured. 233 | """ 234 | 235 | _configure_library_root_logger() 236 | _get_library_root_logger().propagate = True 237 | 238 | 239 | def enable_explicit_format() -> None: 240 | """ 241 | Enable explicit formatting for every 🤗 Diffusers' logger. The explicit formatter is as follows: 242 | ``` 243 | [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE 244 | ``` 245 | All handlers currently bound to the root logger are affected by this method. 246 | """ 247 | handlers = _get_library_root_logger().handlers 248 | 249 | for handler in handlers: 250 | formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s") 251 | handler.setFormatter(formatter) 252 | 253 | 254 | def reset_format() -> None: 255 | """ 256 | Resets the formatting for 🤗 Diffusers' loggers. 257 | 258 | All handlers currently bound to the root logger are affected by this method. 
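    Example (illustrative; assumes the package is importable under its upstream name `diffusers`):

    ```py
    >>> from diffusers.utils import logging
    >>> logging.enable_explicit_format()  # switch to [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
    >>> logging.reset_format()  # restore the default (unset) formatter on all handlers
    ```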
259 | """ 260 | handlers = _get_library_root_logger().handlers 261 | 262 | for handler in handlers: 263 | handler.setFormatter(None) 264 | 265 | 266 | def warning_advice(self, *args, **kwargs): 267 | """ 268 | This method is identical to `logger.warning()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this 269 | warning will not be printed 270 | """ 271 | no_advisory_warnings = os.getenv("DIFFUSERS_NO_ADVISORY_WARNINGS", False) 272 | if no_advisory_warnings: 273 | return 274 | self.warning(*args, **kwargs) 275 | 276 | 277 | logging.Logger.warning_advice = warning_advice 278 | 279 | 280 | class EmptyTqdm: 281 | """Dummy tqdm which doesn't do anything.""" 282 | 283 | def __init__(self, *args, **kwargs): # pylint: disable=unused-argument 284 | self._iterator = args[0] if args else None 285 | 286 | def __iter__(self): 287 | return iter(self._iterator) 288 | 289 | def __getattr__(self, _): 290 | """Return empty function.""" 291 | 292 | def empty_fn(*args, **kwargs): # pylint: disable=unused-argument 293 | return 294 | 295 | return empty_fn 296 | 297 | def __enter__(self): 298 | return self 299 | 300 | def __exit__(self, type_, value, traceback): 301 | return 302 | 303 | 304 | class _tqdm_cls: 305 | def __call__(self, *args, **kwargs): 306 | if _tqdm_active: 307 | return tqdm_lib.tqdm(*args, **kwargs) 308 | else: 309 | return EmptyTqdm(*args, **kwargs) 310 | 311 | def set_lock(self, *args, **kwargs): 312 | self._lock = None 313 | if _tqdm_active: 314 | return tqdm_lib.tqdm.set_lock(*args, **kwargs) 315 | 316 | def get_lock(self): 317 | if _tqdm_active: 318 | return tqdm_lib.tqdm.get_lock() 319 | 320 | 321 | tqdm = _tqdm_cls() 322 | 323 | 324 | def is_progress_bar_enabled() -> bool: 325 | """Return a boolean indicating whether tqdm progress bars are enabled.""" 326 | global _tqdm_active 327 | return bool(_tqdm_active) 328 | 329 | 330 | def enable_progress_bar(): 331 | """Enable tqdm progress bar.""" 332 | global _tqdm_active 333 | _tqdm_active = True 334 | 335 | 336 | def disable_progress_bar(): 337 | """Disable tqdm progress bar.""" 338 | global _tqdm_active 339 | _tqdm_active = False 340 | -------------------------------------------------------------------------------- /diffusers_official/models/unet_1d.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from dataclasses import dataclass 16 | from typing import Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | from ..configuration_utils import ConfigMixin, register_to_config 22 | from ..utils import BaseOutput 23 | from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps 24 | from .modeling_utils import ModelMixin 25 | from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block 26 | 27 | 28 | @dataclass 29 | class UNet1DOutput(BaseOutput): 30 | """ 31 | The output of [`UNet1DModel`]. 32 | 33 | Args: 34 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`): 35 | The hidden states output from the last layer of the model. 36 | """ 37 | 38 | sample: torch.FloatTensor 39 | 40 | 41 | class UNet1DModel(ModelMixin, ConfigMixin): 42 | r""" 43 | A 1D UNet model that takes a noisy sample and a timestep and returns a sample-shaped output. 44 | 45 | This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented 46 | for all models (such as downloading or saving). 47 | 48 | Parameters: 49 | sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime. 50 | in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample. 51 | out_channels (`int`, *optional*, defaults to 2): Number of channels in the output. 52 | extra_in_channels (`int`, *optional*, defaults to 0): 53 | Number of additional channels to be added to the input of the first down block. Useful for cases where the 54 | input data has more channels than what the model was initially designed for. 55 | time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use. 56 | freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for Fourier time embedding. 57 | flip_sin_to_cos (`bool`, *optional*, defaults to `True`): 58 | Whether to flip sin to cos for Fourier time embedding. 59 | down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1D", "DownBlock1DNoSkip", "AttnDownBlock1D")`): 60 | Tuple of downsample block types. 61 | up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`): 62 | Tuple of upsample block types. 63 | block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`): 64 | Tuple of block output channels. 65 | mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for middle of UNet. 66 | out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of UNet. 67 | act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks. 68 | norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization. 69 | layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block. 70 | downsample_each_block (`bool`, *optional*, defaults to `False`): 71 | Experimental feature for using a UNet without upsampling.
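    Example (an illustrative sketch of the calling convention rather than upstream documentation; `model` is assumed
    to be an already constructed or pretrained `UNet1DModel` and `noisy_sample` a tensor of shape
    `(batch_size, num_channels, sample_size)`):

    ```py
    >>> out = model(noisy_sample, timestep=10)  # returns a `UNet1DOutput`
    >>> denoised = out.sample  # shape (batch_size, out_channels, sample_size)
    >>> (denoised,) = model(noisy_sample, timestep=10, return_dict=False)  # plain tuple instead
    ```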
72 | """ 73 | 74 | @register_to_config 75 | def __init__( 76 | self, 77 | sample_size: int = 65536, 78 | sample_rate: Optional[int] = None, 79 | in_channels: int = 2, 80 | out_channels: int = 2, 81 | extra_in_channels: int = 0, 82 | time_embedding_type: str = "fourier", 83 | flip_sin_to_cos: bool = True, 84 | use_timestep_embedding: bool = False, 85 | freq_shift: float = 0.0, 86 | down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"), 87 | up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"), 88 | mid_block_type: Tuple[str] = "UNetMidBlock1D", 89 | out_block_type: str = None, 90 | block_out_channels: Tuple[int] = (32, 32, 64), 91 | act_fn: str = None, 92 | norm_num_groups: int = 8, 93 | layers_per_block: int = 1, 94 | downsample_each_block: bool = False, 95 | ): 96 | super().__init__() 97 | self.sample_size = sample_size 98 | 99 | # time 100 | if time_embedding_type == "fourier": 101 | self.time_proj = GaussianFourierProjection( 102 | embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos 103 | ) 104 | timestep_input_dim = 2 * block_out_channels[0] 105 | elif time_embedding_type == "positional": 106 | self.time_proj = Timesteps( 107 | block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift 108 | ) 109 | timestep_input_dim = block_out_channels[0] 110 | 111 | if use_timestep_embedding: 112 | time_embed_dim = block_out_channels[0] * 4 113 | self.time_mlp = TimestepEmbedding( 114 | in_channels=timestep_input_dim, 115 | time_embed_dim=time_embed_dim, 116 | act_fn=act_fn, 117 | out_dim=block_out_channels[0], 118 | ) 119 | 120 | self.down_blocks = nn.ModuleList([]) 121 | self.mid_block = None 122 | self.up_blocks = nn.ModuleList([]) 123 | self.out_block = None 124 | 125 | # down 126 | output_channel = in_channels 127 | for i, down_block_type in enumerate(down_block_types): 128 | input_channel = output_channel 129 | output_channel = block_out_channels[i] 130 | 131 | if i == 0: 132 | input_channel += extra_in_channels 133 | 134 | is_final_block = i == len(block_out_channels) - 1 135 | 136 | down_block = get_down_block( 137 | down_block_type, 138 | num_layers=layers_per_block, 139 | in_channels=input_channel, 140 | out_channels=output_channel, 141 | temb_channels=block_out_channels[0], 142 | add_downsample=not is_final_block or downsample_each_block, 143 | ) 144 | self.down_blocks.append(down_block) 145 | 146 | # mid 147 | self.mid_block = get_mid_block( 148 | mid_block_type, 149 | in_channels=block_out_channels[-1], 150 | mid_channels=block_out_channels[-1], 151 | out_channels=block_out_channels[-1], 152 | embed_dim=block_out_channels[0], 153 | num_layers=layers_per_block, 154 | add_downsample=downsample_each_block, 155 | ) 156 | 157 | # up 158 | reversed_block_out_channels = list(reversed(block_out_channels)) 159 | output_channel = reversed_block_out_channels[0] 160 | if out_block_type is None: 161 | final_upsample_channels = out_channels 162 | else: 163 | final_upsample_channels = block_out_channels[0] 164 | 165 | for i, up_block_type in enumerate(up_block_types): 166 | prev_output_channel = output_channel 167 | output_channel = ( 168 | reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels 169 | ) 170 | 171 | is_final_block = i == len(block_out_channels) - 1 172 | 173 | up_block = get_up_block( 174 | up_block_type, 175 | num_layers=layers_per_block, 176 | in_channels=prev_output_channel, 177 | out_channels=output_channel, 178 | 
temb_channels=block_out_channels[0], 179 | add_upsample=not is_final_block, 180 | ) 181 | self.up_blocks.append(up_block) 182 | prev_output_channel = output_channel 183 | 184 | # out 185 | num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32) 186 | self.out_block = get_out_block( 187 | out_block_type=out_block_type, 188 | num_groups_out=num_groups_out, 189 | embed_dim=block_out_channels[0], 190 | out_channels=out_channels, 191 | act_fn=act_fn, 192 | fc_dim=block_out_channels[-1] // 4, 193 | ) 194 | 195 | def forward( 196 | self, 197 | sample: torch.FloatTensor, 198 | timestep: Union[torch.Tensor, float, int], 199 | return_dict: bool = True, 200 | ) -> Union[UNet1DOutput, Tuple]: 201 | r""" 202 | The [`UNet1DModel`] forward method. 203 | 204 | Args: 205 | sample (`torch.FloatTensor`): 206 | The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`. 207 | timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. 208 | return_dict (`bool`, *optional*, defaults to `True`): 209 | Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple. 210 | 211 | Returns: 212 | [`~models.unet_1d.UNet1DOutput`] or `tuple`: 213 | If `return_dict` is True, an [`~models.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is 214 | returned where the first element is the sample tensor. 215 | """ 216 | 217 | # 1. time 218 | timesteps = timestep 219 | if not torch.is_tensor(timesteps): 220 | timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) 221 | elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: 222 | timesteps = timesteps[None].to(sample.device) 223 | 224 | timestep_embed = self.time_proj(timesteps) 225 | if self.config.use_timestep_embedding: 226 | timestep_embed = self.time_mlp(timestep_embed) 227 | else: 228 | timestep_embed = timestep_embed[..., None] 229 | timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype) 230 | timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:])) 231 | 232 | # 2. down 233 | down_block_res_samples = () 234 | for downsample_block in self.down_blocks: 235 | sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed) 236 | down_block_res_samples += res_samples 237 | 238 | # 3. mid 239 | if self.mid_block: 240 | sample = self.mid_block(sample, timestep_embed) 241 | 242 | # 4. up 243 | for i, upsample_block in enumerate(self.up_blocks): 244 | res_samples = down_block_res_samples[-1:] 245 | down_block_res_samples = down_block_res_samples[:-1] 246 | sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed) 247 | 248 | # 5. post-process 249 | if self.out_block: 250 | sample = self.out_block(sample, timestep_embed) 251 | 252 | if not return_dict: 253 | return (sample,) 254 | 255 | return UNet1DOutput(sample=sample) 256 | --------------------------------------------------------------------------------
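Taken together, `add_noise_to_input`, `step`, and `step_correct` in `scheduling_karras_ve.py` implement Algorithm 2 of Karras et al. The sketch below wires them into a minimal sampling loop. It is illustrative only: it assumes the vendored package is importable as `diffusers_official` (upstream: `diffusers`), `model` is a zero-returning placeholder for a trained denoiser whose output follows the sign convention used by `step` (i.e. `x_0 = sample_hat + sigma_hat * model_output`), and the image shape and step count are arbitrary.

```py
import torch

from diffusers_official.schedulers.scheduling_karras_ve import KarrasVeScheduler


def model(sample: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    """Placeholder denoiser: a trained network would be called here."""
    return torch.zeros_like(sample)


scheduler = KarrasVeScheduler()
scheduler.set_timesteps(num_inference_steps=50)

generator = torch.Generator().manual_seed(0)
# Start from pure noise at the initial noise level (sigma_max).
sample = torch.randn(1, 3, 64, 64, generator=generator) * scheduler.init_noise_sigma

for t in scheduler.timesteps:
    # `schedule` is built over the reversed timesteps, so indexing by `t` yields a
    # noise level that decreases as `t` counts down; the final step targets 0.
    sigma = scheduler.schedule[t]
    sigma_prev = scheduler.schedule[t - 1] if t > 0 else 0

    # 1. "Churn": raise the noise level from sigma to sigma_hat.
    sample_hat, sigma_hat = scheduler.add_noise_to_input(sample, sigma, generator=generator)

    # 2. Euler step from sigma_hat down to sigma_prev.
    out = scheduler.step(model(sample_hat, sigma_hat), sigma_hat, sigma_prev, sample_hat)

    # 3. Second-order correction, skipped on the final step (sigma_prev == 0).
    if sigma_prev != 0:
        out = scheduler.step_correct(
            model(out.prev_sample, sigma_prev),
            sigma_hat,
            sigma_prev,
            sample_hat,
            out.prev_sample,
            out.derivative,
        )

    sample = out.prev_sample
```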