├── .gitignore
├── assets
│   ├── color.png
│   ├── font.png
│   ├── size.png
│   └── footnote.png
├── scripts
│   ├── models
│   │   ├── utils
│   │   │   ├── .DS_Store
│   │   │   └── richtext_utils.py
│   │   └── dual_transformer_2d.py
│   └── rich_text_settings.py
├── diffusers_official
│   ├── models
│   │   ├── activations.py
│   │   ├── __init__.py
│   │   ├── embeddings_flax.py
│   │   ├── resnet_flax.py
│   │   ├── modeling_flax_pytorch_utils.py
│   │   ├── cross_attention.py
│   │   ├── vq_model.py
│   │   ├── modeling_pytorch_flax_utils.py
│   │   ├── dual_transformer_2d.py
│   │   ├── transformer_temporal.py
│   │   └── unet_1d.py
│   ├── utils
│   │   ├── dummy_onnx_objects.py
│   │   ├── dummy_note_seq_objects.py
│   │   ├── dummy_torch_and_scipy_objects.py
│   │   ├── dummy_torch_and_torchsde_objects.py
│   │   ├── dummy_transformers_and_torch_and_note_seq_objects.py
│   │   ├── dummy_torch_and_transformers_and_k_diffusion_objects.py
│   │   ├── dummy_torch_and_librosa_objects.py
│   │   ├── dummy_torch_and_transformers_and_invisible_watermark_objects.py
│   │   ├── constants.py
│   │   ├── doc_utils.py
│   │   ├── pil_utils.py
│   │   ├── model_card_template.md
│   │   ├── accelerate_utils.py
│   │   ├── dummy_flax_and_transformers_objects.py
│   │   ├── deprecation_utils.py
│   │   ├── dummy_torch_and_transformers_and_onnx_objects.py
│   │   ├── torch_utils.py
│   │   ├── __init__.py
│   │   ├── outputs.py
│   │   ├── dummy_flax_objects.py
│   │   └── logging.py
│   ├── pipelines
│   │   ├── stable_diffusion_xl
│   │   │   ├── __init__.py
│   │   │   └── watermark.py
│   │   ├── __init__.py
│   │   └── onnx_utils.py
│   ├── pipeline_utils.py
│   ├── dependency_versions_table.py
│   ├── dependency_versions_check.py
│   ├── schedulers
│   │   ├── scheduling_sde_vp.py
│   │   ├── __init__.py
│   │   ├── scheduling_ipndm.py
│   │   ├── scheduling_utils.py
│   │   ├── scheduling_karras_ve_flax.py
│   │   └── scheduling_karras_ve.py
│   └── __init__.py
├── install.py
├── README.md
└── share_btn.py
/.gitignore:
--------------------------------------------------------------------------------
1 | venv
2 | __pycache__/
3 | *.pyc
4 | gradio_cached_examples/
--------------------------------------------------------------------------------
/assets/color.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/songweige/sd-webui-rich-text/HEAD/assets/color.png
--------------------------------------------------------------------------------
/assets/font.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/songweige/sd-webui-rich-text/HEAD/assets/font.png
--------------------------------------------------------------------------------
/assets/size.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/songweige/sd-webui-rich-text/HEAD/assets/size.png
--------------------------------------------------------------------------------
/assets/footnote.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/songweige/sd-webui-rich-text/HEAD/assets/footnote.png
--------------------------------------------------------------------------------
/scripts/models/utils/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/songweige/sd-webui-rich-text/HEAD/scripts/models/utils/.DS_Store
--------------------------------------------------------------------------------
/diffusers_official/models/activations.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 |
4 | def get_activation(act_fn):
5 | if act_fn in ["swish", "silu"]:
6 | return nn.SiLU()
7 | elif act_fn == "mish":
8 | return nn.Mish()
9 | elif act_fn == "gelu":
10 | return nn.GELU()
11 | else:
12 | raise ValueError(f"Unsupported activation function: {act_fn}")
13 |
--------------------------------------------------------------------------------
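
A quick usage sketch for `get_activation` (assuming PyTorch is installed, the repo root is on `sys.path`, and the package imports cleanly; the import path below follows this repo's layout):

```python
import torch

from diffusers_official.models.activations import get_activation

act = get_activation("silu")   # "swish" and "silu" both map to nn.SiLU()
x = torch.randn(2, 4)
print(act(x).shape)            # torch.Size([2, 4])

try:
    get_activation("tanh")     # any unsupported name raises
except ValueError as err:
    print(err)                 # Unsupported activation function: tanh
```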
/diffusers_official/utils/dummy_onnx_objects.py:
--------------------------------------------------------------------------------
1 | # This file is autogenerated by the command `make fix-copies`, do not edit.
2 | from ..utils import DummyObject, requires_backends
3 |
4 |
5 | class OnnxRuntimeModel(metaclass=DummyObject):
6 | _backends = ["onnx"]
7 |
8 | def __init__(self, *args, **kwargs):
9 | requires_backends(self, ["onnx"])
10 |
11 | @classmethod
12 | def from_config(cls, *args, **kwargs):
13 | requires_backends(cls, ["onnx"])
14 |
15 | @classmethod
16 | def from_pretrained(cls, *args, **kwargs):
17 | requires_backends(cls, ["onnx"])
18 |
--------------------------------------------------------------------------------
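
The `dummy_*` modules above and below all share one mechanism: when an optional backend is missing, the real class is replaced by a placeholder whose constructor and loader classmethods raise an informative `ImportError`. Here is a stripped-down sketch of that mechanism for illustration only; the real `DummyObject` and `requires_backends` live in diffusers' `utils/import_utils.py` and check which backends are actually importable rather than assuming they are all missing:

```python
class DummyObject(type):
    """Metaclass: attribute access on the class itself reports missing backends."""

    def __getattr__(cls, key):
        if key.startswith("_"):
            raise AttributeError(key)
        requires_backends(cls, cls._backends)


def requires_backends(obj, backends):
    # Simplified: pretend every listed backend is missing.
    name = getattr(obj, "__name__", type(obj).__name__)
    raise ImportError(f"{name} requires the following backends: {', '.join(backends)}")


class OnnxRuntimeModel(metaclass=DummyObject):
    _backends = ["onnx"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["onnx"])


try:
    OnnxRuntimeModel()  # fails fast with a clear message instead of a late NameError
except ImportError as err:
    print(err)  # OnnxRuntimeModel requires the following backends: onnx
```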
/scripts/rich_text_settings.py:
--------------------------------------------------------------------------------
1 | import modules.scripts as scripts
2 | import gradio as gr
3 | import os
4 |
5 | from modules import shared
6 | from modules import script_callbacks
7 |
8 | def on_ui_settings():
9 | section = ('template', "Rich-Text-to-Image")
10 | shared.opts.add_option(
11 | "option1",
12 | shared.OptionInfo(
13 | False,
14 | "This is a placeholder option. It is not used yet.",
15 | gr.Checkbox,
16 | {"interactive": True},
17 | section=section)
18 | )
19 |
20 | script_callbacks.on_ui_settings(on_ui_settings)
21 |
--------------------------------------------------------------------------------
/diffusers_official/utils/dummy_note_seq_objects.py:
--------------------------------------------------------------------------------
1 | # This file is autogenerated by the command `make fix-copies`, do not edit.
2 | from ..utils import DummyObject, requires_backends
3 |
4 |
5 | class MidiProcessor(metaclass=DummyObject):
6 | _backends = ["note_seq"]
7 |
8 | def __init__(self, *args, **kwargs):
9 | requires_backends(self, ["note_seq"])
10 |
11 | @classmethod
12 | def from_config(cls, *args, **kwargs):
13 | requires_backends(cls, ["note_seq"])
14 |
15 | @classmethod
16 | def from_pretrained(cls, *args, **kwargs):
17 | requires_backends(cls, ["note_seq"])
18 |
--------------------------------------------------------------------------------
/diffusers_official/utils/dummy_torch_and_scipy_objects.py:
--------------------------------------------------------------------------------
1 | # This file is autogenerated by the command `make fix-copies`, do not edit.
2 | from ..utils import DummyObject, requires_backends
3 |
4 |
5 | class LMSDiscreteScheduler(metaclass=DummyObject):
6 | _backends = ["torch", "scipy"]
7 |
8 | def __init__(self, *args, **kwargs):
9 | requires_backends(self, ["torch", "scipy"])
10 |
11 | @classmethod
12 | def from_config(cls, *args, **kwargs):
13 | requires_backends(cls, ["torch", "scipy"])
14 |
15 | @classmethod
16 | def from_pretrained(cls, *args, **kwargs):
17 | requires_backends(cls, ["torch", "scipy"])
18 |
--------------------------------------------------------------------------------
/diffusers_official/utils/dummy_torch_and_torchsde_objects.py:
--------------------------------------------------------------------------------
1 | # This file is autogenerated by the command `make fix-copies`, do not edit.
2 | from ..utils import DummyObject, requires_backends
3 |
4 |
5 | class DPMSolverSDEScheduler(metaclass=DummyObject):
6 | _backends = ["torch", "torchsde"]
7 |
8 | def __init__(self, *args, **kwargs):
9 | requires_backends(self, ["torch", "torchsde"])
10 |
11 | @classmethod
12 | def from_config(cls, *args, **kwargs):
13 | requires_backends(cls, ["torch", "torchsde"])
14 |
15 | @classmethod
16 | def from_pretrained(cls, *args, **kwargs):
17 | requires_backends(cls, ["torch", "torchsde"])
18 |
--------------------------------------------------------------------------------
/diffusers_official/utils/dummy_transformers_and_torch_and_note_seq_objects.py:
--------------------------------------------------------------------------------
1 | # This file is autogenerated by the command `make fix-copies`, do not edit.
2 | from ..utils import DummyObject, requires_backends
3 |
4 |
5 | class SpectrogramDiffusionPipeline(metaclass=DummyObject):
6 | _backends = ["transformers", "torch", "note_seq"]
7 |
8 | def __init__(self, *args, **kwargs):
9 | requires_backends(self, ["transformers", "torch", "note_seq"])
10 |
11 | @classmethod
12 | def from_config(cls, *args, **kwargs):
13 | requires_backends(cls, ["transformers", "torch", "note_seq"])
14 |
15 | @classmethod
16 | def from_pretrained(cls, *args, **kwargs):
17 | requires_backends(cls, ["transformers", "torch", "note_seq"])
18 |
--------------------------------------------------------------------------------
/diffusers_official/utils/dummy_torch_and_transformers_and_k_diffusion_objects.py:
--------------------------------------------------------------------------------
1 | # This file is autogenerated by the command `make fix-copies`, do not edit.
2 | from ..utils import DummyObject, requires_backends
3 |
4 |
5 | class StableDiffusionKDiffusionPipeline(metaclass=DummyObject):
6 | _backends = ["torch", "transformers", "k_diffusion"]
7 |
8 | def __init__(self, *args, **kwargs):
9 | requires_backends(self, ["torch", "transformers", "k_diffusion"])
10 |
11 | @classmethod
12 | def from_config(cls, *args, **kwargs):
13 | requires_backends(cls, ["torch", "transformers", "k_diffusion"])
14 |
15 | @classmethod
16 | def from_pretrained(cls, *args, **kwargs):
17 | requires_backends(cls, ["torch", "transformers", "k_diffusion"])
18 |
--------------------------------------------------------------------------------
/install.py:
--------------------------------------------------------------------------------
1 | import launch
2 |
3 | # TODO: add pip dependencies here if the extension needs extra modules
4 |
5 | if not launch.is_installed("diffusers"):
6 | launch.run_pip("install diffusers==0.18.2", "requirements for Rich-Text-to-Image")
7 |
8 | if not launch.is_installed("invisible-watermark"):
9 | launch.run_pip("install invisible-watermark==0.2.0", "requirements for Rich-Text-to-Image")
10 |
11 | if not launch.is_installed("accelerate"):
12 | launch.run_pip("install accelerate==0.21.0", "requirements for Rich-Text-to-Image")
13 |
14 | if not launch.is_installed("safetensors"):
15 | launch.run_pip("install safetensors==0.3.1", "requirements for Rich-Text-to-Image")
16 |
17 | if not launch.is_installed("seaborn"):
18 | launch.run_pip("install seaborn==0.12.2", "requirements for Rich-Text-to-Image")
19 |
20 | if not launch.is_installed("scikit-learn"):
21 | launch.run_pip("install scikit-learn==1.3.0", "requirements for Rich-Text-to-Image")
22 |
23 | if not launch.is_installed("threadpoolctl"):
24 | launch.run_pip("install threadpoolctl==3.1.0", "requirements for Rich-Text-to-Image")
--------------------------------------------------------------------------------
/diffusers_official/pipelines/stable_diffusion_xl/__init__.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List, Optional, Union
3 |
4 | import numpy as np
5 | import PIL
6 |
7 | from ...utils import BaseOutput, is_invisible_watermark_available, is_torch_available, is_transformers_available
8 |
9 |
10 | @dataclass
11 | class StableDiffusionXLPipelineOutput(BaseOutput):
12 | """
13 | Output class for Stable Diffusion pipelines.
14 |
15 | Args:
16 | images (`List[PIL.Image.Image]` or `np.ndarray`):
17 | List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
18 | num_channels)`. PIL images or numpy array represent the denoised images of the diffusion pipeline.
19 | """
20 |
21 | images: Union[List[PIL.Image.Image], np.ndarray]
22 |
23 |
24 | if is_transformers_available() and is_torch_available() and is_invisible_watermark_available():
25 | from .pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
26 | from .pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipeline
27 |
--------------------------------------------------------------------------------
/diffusers_official/utils/dummy_torch_and_librosa_objects.py:
--------------------------------------------------------------------------------
1 | # This file is autogenerated by the command `make fix-copies`, do not edit.
2 | from ..utils import DummyObject, requires_backends
3 |
4 |
5 | class AudioDiffusionPipeline(metaclass=DummyObject):
6 | _backends = ["torch", "librosa"]
7 |
8 | def __init__(self, *args, **kwargs):
9 | requires_backends(self, ["torch", "librosa"])
10 |
11 | @classmethod
12 | def from_config(cls, *args, **kwargs):
13 | requires_backends(cls, ["torch", "librosa"])
14 |
15 | @classmethod
16 | def from_pretrained(cls, *args, **kwargs):
17 | requires_backends(cls, ["torch", "librosa"])
18 |
19 |
20 | class Mel(metaclass=DummyObject):
21 | _backends = ["torch", "librosa"]
22 |
23 | def __init__(self, *args, **kwargs):
24 | requires_backends(self, ["torch", "librosa"])
25 |
26 | @classmethod
27 | def from_config(cls, *args, **kwargs):
28 | requires_backends(cls, ["torch", "librosa"])
29 |
30 | @classmethod
31 | def from_pretrained(cls, *args, **kwargs):
32 | requires_backends(cls, ["torch", "librosa"])
33 |
--------------------------------------------------------------------------------
/diffusers_official/pipeline_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | # NOTE: This file is deprecated and will be removed in a future version.
17 | # It only exists so that `from diffusers.pipelines import DiffusionPipeline` temporarily keeps working
18 |
19 | from .pipelines import DiffusionPipeline, ImagePipelineOutput # noqa: F401
20 | from .utils import deprecate
21 |
22 |
23 | deprecate(
24 | "pipelines_utils",
25 | "0.22.0",
26 | "Importing `DiffusionPipeline` or `ImagePipelineOutput` from diffusers.pipeline_utils is deprecated. Please import from diffusers.pipelines.pipeline_utils instead.",
27 | standard_warn=False,
28 | stacklevel=3,
29 | )
30 |
--------------------------------------------------------------------------------
/diffusers_official/pipelines/stable_diffusion_xl/watermark.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from imwatermark import WatermarkEncoder
4 |
5 |
6 | # Copied from https://github.com/Stability-AI/generative-models/blob/613af104c6b85184091d42d374fef420eddb356d/scripts/demo/streamlit_helpers.py#L66
7 | WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
8 | # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
9 | WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
10 |
11 |
12 | class StableDiffusionXLWatermarker:
13 | def __init__(self):
14 | self.watermark = WATERMARK_BITS
15 | self.encoder = WatermarkEncoder()
16 |
17 | self.encoder.set_watermark("bits", self.watermark)
18 |
19 | def apply_watermark(self, images: torch.FloatTensor):
20 | # can't encode images that are smaller than 256
21 | if images.shape[-1] < 256:
22 | return images
23 |
24 | images = (255 * (images / 2 + 0.5)).cpu().permute(0, 2, 3, 1).float().numpy()
25 |
26 | images = [self.encoder.encode(image, "dwtDct") for image in images]
27 |
28 | images = torch.from_numpy(np.array(images)).permute(0, 3, 1, 2)
29 |
30 | images = torch.clamp(2 * (images / 255 - 0.5), min=-1.0, max=1.0)
31 | return images
32 |
--------------------------------------------------------------------------------
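
A usage sketch under the pipeline's tensor conventions (NCHW float tensors in `[-1, 1]`); this assumes torch, numpy, and `invisible-watermark` (which provides the `imwatermark` module) are installed:

```python
import torch

from diffusers_official.pipelines.stable_diffusion_xl.watermark import (
    StableDiffusionXLWatermarker,
)

watermarker = StableDiffusionXLWatermarker()

images = torch.rand(1, 3, 512, 512) * 2 - 1   # stand-in batch in [-1, 1]
marked = watermarker.apply_watermark(images)
print(marked.shape)                           # torch.Size([1, 3, 512, 512])

# Images narrower than 256 px are returned untouched:
small = torch.rand(1, 3, 128, 128) * 2 - 1
assert watermarker.apply_watermark(small) is small
```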
/diffusers_official/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py:
--------------------------------------------------------------------------------
1 | # This file is autogenerated by the command `make fix-copies`, do not edit.
2 | from ..utils import DummyObject, requires_backends
3 |
4 |
5 | class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject):
6 | _backends = ["torch", "transformers", "invisible_watermark"]
7 |
8 | def __init__(self, *args, **kwargs):
9 | requires_backends(self, ["torch", "transformers", "invisible_watermark"])
10 |
11 | @classmethod
12 | def from_config(cls, *args, **kwargs):
13 | requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
14 |
15 | @classmethod
16 | def from_pretrained(cls, *args, **kwargs):
17 | requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
18 |
19 |
20 | class StableDiffusionXLPipeline(metaclass=DummyObject):
21 | _backends = ["torch", "transformers", "invisible_watermark"]
22 |
23 | def __init__(self, *args, **kwargs):
24 | requires_backends(self, ["torch", "transformers", "invisible_watermark"])
25 |
26 | @classmethod
27 | def from_config(cls, *args, **kwargs):
28 | requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
29 |
30 | @classmethod
31 | def from_pretrained(cls, *args, **kwargs):
32 | requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
33 |
--------------------------------------------------------------------------------
/diffusers_official/utils/constants.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import os
15 |
16 | from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE, hf_cache_home
17 |
18 |
19 | default_cache_path = HUGGINGFACE_HUB_CACHE
20 |
21 |
22 | CONFIG_NAME = "config.json"
23 | WEIGHTS_NAME = "diffusion_pytorch_model.bin"
24 | FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
25 | ONNX_WEIGHTS_NAME = "model.onnx"
26 | SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
27 | ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
28 | HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
29 | DIFFUSERS_CACHE = default_cache_path
30 | DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
31 | HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
32 | DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
33 | TEXT_ENCODER_ATTN_MODULE = ".self_attn"
34 |
--------------------------------------------------------------------------------
/diffusers_official/utils/doc_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Doc utilities: Utilities related to documentation
16 | """
17 | import re
18 |
19 |
20 | def replace_example_docstring(example_docstring):
21 | def docstring_decorator(fn):
22 | func_doc = fn.__doc__
23 | lines = func_doc.split("\n")
24 | i = 0
25 | while i < len(lines) and re.search(r"^\s*Examples?:\s*$", lines[i]) is None:
26 | i += 1
27 | if i < len(lines):
28 | lines[i] = example_docstring
29 | func_doc = "\n".join(lines)
30 | else:
31 | raise ValueError(
32 | f"The function {fn} should have an empty 'Examples:' line in its docstring as a placeholder, "
33 | f"current docstring is:\n{func_doc}"
34 | )
35 | fn.__doc__ = func_doc
36 | return fn
37 |
38 | return docstring_decorator
39 |
--------------------------------------------------------------------------------
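
A usage sketch: the decorated function must carry a bare `Examples:` line in its docstring, which the decorator swaps for the snippet. The function name and snippet below are made up for illustration:

```python
from diffusers_official.utils.doc_utils import replace_example_docstring

EXAMPLE_DOC_STRING = """    Examples:
        >>> square(3)
        9
"""


@replace_example_docstring(EXAMPLE_DOC_STRING)
def square(x):
    """Square a number.

    Examples:
    """
    return x * x


print(square.__doc__)  # the "Examples:" placeholder now holds the snippet above
```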
/diffusers_official/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from ..utils import is_flax_available, is_torch_available
16 |
17 |
18 | if is_torch_available():
19 | from .autoencoder_kl import AutoencoderKL
20 | from .controlnet import ControlNetModel
21 | from .dual_transformer_2d import DualTransformer2DModel
22 | from .modeling_utils import ModelMixin
23 | from .prior_transformer import PriorTransformer
24 | from .t5_film_transformer import T5FilmDecoder
25 | from .transformer_2d import Transformer2DModel
26 | from .unet_1d import UNet1DModel
27 | from .unet_2d import UNet2DModel
28 | from .unet_2d_condition import UNet2DConditionModel
29 | from .unet_3d_condition import UNet3DConditionModel
30 | from .vq_model import VQModel
31 |
32 | if is_flax_available():
33 | from .controlnet_flax import FlaxControlNetModel
34 | from .unet_2d_condition_flax import FlaxUNet2DConditionModel
35 | from .vae_flax import FlaxAutoencoderKL
36 |
--------------------------------------------------------------------------------
/diffusers_official/utils/pil_utils.py:
--------------------------------------------------------------------------------
1 | import PIL.Image
2 | import PIL.ImageOps
3 | from packaging import version
4 | from PIL import Image
5 |
6 |
7 | if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
8 | PIL_INTERPOLATION = {
9 | "linear": PIL.Image.Resampling.BILINEAR,
10 | "bilinear": PIL.Image.Resampling.BILINEAR,
11 | "bicubic": PIL.Image.Resampling.BICUBIC,
12 | "lanczos": PIL.Image.Resampling.LANCZOS,
13 | "nearest": PIL.Image.Resampling.NEAREST,
14 | }
15 | else:
16 | PIL_INTERPOLATION = {
17 | "linear": PIL.Image.LINEAR,
18 | "bilinear": PIL.Image.BILINEAR,
19 | "bicubic": PIL.Image.BICUBIC,
20 | "lanczos": PIL.Image.LANCZOS,
21 | "nearest": PIL.Image.NEAREST,
22 | }
23 |
24 |
25 | def pt_to_pil(images):
26 | """
27 | Convert a batch of torch images to a list of PIL images.
28 | """
29 | images = (images / 2 + 0.5).clamp(0, 1)
30 | images = images.cpu().permute(0, 2, 3, 1).float().numpy()
31 | images = numpy_to_pil(images)
32 | return images
33 |
34 |
35 | def numpy_to_pil(images):
36 | """
37 | Convert a numpy image or a batch of numpy images to a list of PIL images.
38 | """
39 | if images.ndim == 3:
40 | images = images[None, ...]
41 | images = (images * 255).round().astype("uint8")
42 | if images.shape[-1] == 1:
43 | # special case for grayscale (single channel) images
44 | pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
45 | else:
46 | pil_images = [Image.fromarray(image) for image in images]
47 |
48 | return pil_images
49 |
--------------------------------------------------------------------------------
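
A quick round-trip sketch (assumes torch, numpy, and Pillow are installed): `pt_to_pil` rescales the pipeline's `[-1, 1]` NCHW tensors to `[0, 1]`, moves channels last, and hands the array to `numpy_to_pil`:

```python
import numpy as np
import torch

from diffusers_official.utils.pil_utils import numpy_to_pil, pt_to_pil

batch = torch.rand(2, 3, 64, 64) * 2 - 1    # stand-in decoder output in [-1, 1]
pil_images = pt_to_pil(batch)
print(len(pil_images), pil_images[0].size)  # 2 (64, 64)

# Grayscale arrays (trailing channel of 1) come back as mode "L" images:
gray = numpy_to_pil(np.random.rand(4, 4, 1))
print(gray[0].mode)  # L
```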
/diffusers_official/dependency_versions_table.py:
--------------------------------------------------------------------------------
1 | # THIS FILE HAS BEEN AUTOGENERATED. To update:
2 | # 1. modify the `_deps` dict in setup.py
3 | # 2. run `make deps_table_update`
4 | deps = {
5 | "Pillow": "Pillow",
6 | "accelerate": "accelerate>=0.11.0",
7 | "compel": "compel==0.1.8",
8 | "black": "black~=23.1",
9 | "datasets": "datasets",
10 | "filelock": "filelock",
11 | "flax": "flax>=0.4.1",
12 | "hf-doc-builder": "hf-doc-builder>=0.3.0",
13 | "huggingface-hub": "huggingface-hub>=0.13.2",
14 | "requests-mock": "requests-mock==1.10.0",
15 | "importlib_metadata": "importlib_metadata",
16 | "invisible-watermark": "invisible-watermark",
17 | "isort": "isort>=5.5.4",
18 | "jax": "jax>=0.2.8,!=0.3.2",
19 | "jaxlib": "jaxlib>=0.1.65",
20 | "Jinja2": "Jinja2",
21 | "k-diffusion": "k-diffusion>=0.0.12",
22 | "torchsde": "torchsde",
23 | "note_seq": "note_seq",
24 | "librosa": "librosa",
25 | "numpy": "numpy",
26 | "omegaconf": "omegaconf",
27 | "parameterized": "parameterized",
28 | "protobuf": "protobuf>=3.20.3,<4",
29 | "pytest": "pytest",
30 | "pytest-timeout": "pytest-timeout",
31 | "pytest-xdist": "pytest-xdist",
32 | "ruff": "ruff>=0.0.241",
33 | "safetensors": "safetensors",
34 | "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
35 | "scipy": "scipy",
36 | "onnx": "onnx",
37 | "regex": "regex!=2019.12.17",
38 | "requests": "requests",
39 | "tensorboard": "tensorboard",
40 | "torch": "torch>=1.4",
41 | "torchvision": "torchvision",
42 | "transformers": "transformers>=4.25.1",
43 | "urllib3": "urllib3<=2.0.0",
44 | }
45 |
--------------------------------------------------------------------------------
/diffusers_official/utils/model_card_template.md:
--------------------------------------------------------------------------------
1 | ---
2 | {{ card_data }}
3 | ---
4 |
5 |
7 |
8 | # {{ model_name | default("Diffusion Model") }}
9 |
10 | ## Model description
11 |
12 | This diffusion model is trained with the [🤗 Diffusers](https://github.com/huggingface/diffusers) library
13 | on the `{{ dataset_name }}` dataset.
14 |
15 | ## Intended uses & limitations
16 |
17 | #### How to use
18 |
19 | ```python
20 | # TODO: add an example code snippet for running this diffusion pipeline
21 | ```
22 |
23 | #### Limitations and bias
24 |
25 | [TODO: provide examples of latent issues and potential remediations]
26 |
27 | ## Training data
28 |
29 | [TODO: describe the data used to train the model]
30 |
31 | ### Training hyperparameters
32 |
33 | The following hyperparameters were used during training:
34 | - learning_rate: {{ learning_rate }}
35 | - train_batch_size: {{ train_batch_size }}
36 | - eval_batch_size: {{ eval_batch_size }}
37 | - gradient_accumulation_steps: {{ gradient_accumulation_steps }}
38 | - optimizer: AdamW with betas=({{ adam_beta1 }}, {{ adam_beta2 }}), weight_decay={{ adam_weight_decay }} and epsilon={{ adam_epsilon }}
39 | - lr_scheduler: {{ lr_scheduler }}
40 | - lr_warmup_steps: {{ lr_warmup_steps }}
41 | - ema_inv_gamma: {{ ema_inv_gamma }}
42 | - ema_power: {{ ema_power }}
43 | - ema_max_decay: {{ ema_max_decay }}
44 | - mixed_precision: {{ mixed_precision }}
45 |
46 | ### Training results
47 |
48 | 📈 [TensorBoard logs](https://huggingface.co/{{ repo_name }}/tensorboard?#scalars)
49 |
50 |
51 |
--------------------------------------------------------------------------------
/diffusers_official/dependency_versions_check.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import sys
15 |
16 | from .dependency_versions_table import deps
17 | from .utils.versions import require_version, require_version_core
18 |
19 |
20 | # define which module versions we always want to check at run time
21 | # (usually the ones defined in `install_requires` in setup.py)
22 | #
23 | # order specific notes:
24 | # - tqdm must be checked before tokenizers
25 |
26 | pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split()
27 | if sys.version_info < (3, 7):
28 | pkgs_to_check_at_runtime.append("dataclasses")
29 | if sys.version_info < (3, 8):
30 | pkgs_to_check_at_runtime.append("importlib_metadata")
31 |
32 | for pkg in pkgs_to_check_at_runtime:
33 | if pkg in deps:
34 | if pkg == "tokenizers":
35 | # must be loaded here, or else tqdm check may fail
36 | from .utils import is_tokenizers_available
37 |
38 | if not is_tokenizers_available():
39 | continue # not required, check version only if installed
40 |
41 | require_version_core(deps[pkg])
42 | else:
43 | raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
44 |
45 |
46 | def dep_version_check(pkg, hint=None):
47 | require_version(deps[pkg], hint)
48 |
--------------------------------------------------------------------------------
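
A usage sketch for the `dep_version_check` helper (assuming the package is imported as a whole so the relative imports resolve): it looks up the pin in `deps` and delegates to `require_version`, which raises with the optional hint if the installed version does not satisfy the pin:

```python
from diffusers_official.dependency_versions_check import dep_version_check

dep_version_check("numpy")  # passes silently if the installed numpy satisfies the pin
dep_version_check("torch", hint="Try: pip install 'torch>=1.4'")  # hint surfaces on failure
```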
/diffusers_official/utils/accelerate_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Accelerate utilities: Utilities related to accelerate
16 | """
17 |
18 | from packaging import version
19 |
20 | from .import_utils import is_accelerate_available
21 |
22 |
23 | if is_accelerate_available():
24 | import accelerate
25 |
26 |
27 | def apply_forward_hook(method):
28 | """
29 | Decorator that applies a registered CpuOffload hook to an arbitrary function rather than `forward`. This is useful
30 | for cases where a PyTorch module provides functions other than `forward` that should trigger a move to the
31 | appropriate acceleration device. This is the case for `encode` and `decode` in [`AutoencoderKL`].
32 |
33 | This decorator looks inside the internal `_hf_hook` property to find a registered offload hook.
34 |
35 | :param method: The method to decorate. This method should be a method of a PyTorch module.
36 | """
37 | if not is_accelerate_available():
38 | return method
39 | accelerate_version = version.parse(accelerate.__version__).base_version
40 | if version.parse(accelerate_version) < version.parse("0.17.0"):
41 | return method
42 |
43 | def wrapper(self, *args, **kwargs):
44 | if hasattr(self, "_hf_hook") and hasattr(self._hf_hook, "pre_forward"):
45 | self._hf_hook.pre_forward(self)
46 | return method(self, *args, **kwargs)
47 |
48 | return wrapper
49 |
--------------------------------------------------------------------------------
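
A sketch of the decorator on a module method other than `forward` (assumes torch is installed; accelerate is optional, and without it, or below version 0.17.0, the decorator is a no-op). The `TinyAutoencoder` below is hypothetical, standing in for a model like `AutoencoderKL`:

```python
import torch
from torch import nn

from diffusers_official.utils.accelerate_utils import apply_forward_hook


class TinyAutoencoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    @apply_forward_hook
    def encode(self, x):
        # With accelerate >= 0.17.0 and a CpuOffload hook registered, the
        # module is moved to its execution device before this body runs.
        return self.proj(x)


model = TinyAutoencoder()
print(model.encode(torch.randn(1, 8)).shape)  # torch.Size([1, 8])
```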
/diffusers_official/utils/dummy_flax_and_transformers_objects.py:
--------------------------------------------------------------------------------
1 | # This file is autogenerated by the command `make fix-copies`, do not edit.
2 | from ..utils import DummyObject, requires_backends
3 |
4 |
5 | class FlaxStableDiffusionControlNetPipeline(metaclass=DummyObject):
6 | _backends = ["flax", "transformers"]
7 |
8 | def __init__(self, *args, **kwargs):
9 | requires_backends(self, ["flax", "transformers"])
10 |
11 | @classmethod
12 | def from_config(cls, *args, **kwargs):
13 | requires_backends(cls, ["flax", "transformers"])
14 |
15 | @classmethod
16 | def from_pretrained(cls, *args, **kwargs):
17 | requires_backends(cls, ["flax", "transformers"])
18 |
19 |
20 | class FlaxStableDiffusionImg2ImgPipeline(metaclass=DummyObject):
21 | _backends = ["flax", "transformers"]
22 |
23 | def __init__(self, *args, **kwargs):
24 | requires_backends(self, ["flax", "transformers"])
25 |
26 | @classmethod
27 | def from_config(cls, *args, **kwargs):
28 | requires_backends(cls, ["flax", "transformers"])
29 |
30 | @classmethod
31 | def from_pretrained(cls, *args, **kwargs):
32 | requires_backends(cls, ["flax", "transformers"])
33 |
34 |
35 | class FlaxStableDiffusionInpaintPipeline(metaclass=DummyObject):
36 | _backends = ["flax", "transformers"]
37 |
38 | def __init__(self, *args, **kwargs):
39 | requires_backends(self, ["flax", "transformers"])
40 |
41 | @classmethod
42 | def from_config(cls, *args, **kwargs):
43 | requires_backends(cls, ["flax", "transformers"])
44 |
45 | @classmethod
46 | def from_pretrained(cls, *args, **kwargs):
47 | requires_backends(cls, ["flax", "transformers"])
48 |
49 |
50 | class FlaxStableDiffusionPipeline(metaclass=DummyObject):
51 | _backends = ["flax", "transformers"]
52 |
53 | def __init__(self, *args, **kwargs):
54 | requires_backends(self, ["flax", "transformers"])
55 |
56 | @classmethod
57 | def from_config(cls, *args, **kwargs):
58 | requires_backends(cls, ["flax", "transformers"])
59 |
60 | @classmethod
61 | def from_pretrained(cls, *args, **kwargs):
62 | requires_backends(cls, ["flax", "transformers"])
63 |
--------------------------------------------------------------------------------
/diffusers_official/utils/deprecation_utils.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import warnings
3 | from typing import Any, Dict, Optional, Union
4 |
5 | from packaging import version
6 |
7 |
8 | def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True, stacklevel=2):
9 | from .. import __version__
10 |
11 | deprecated_kwargs = take_from
12 | values = ()
13 | if not isinstance(args[0], tuple):
14 | args = (args,)
15 |
16 | for attribute, version_name, message in args:
17 | if version.parse(version.parse(__version__).base_version) >= version.parse(version_name):
18 | raise ValueError(
19 | f"The deprecation tuple {(attribute, version_name, message)} should be removed since diffusers'"
20 | f" version {__version__} is >= {version_name}"
21 | )
22 |
23 | warning = None
24 | if isinstance(deprecated_kwargs, dict) and attribute in deprecated_kwargs:
25 | values += (deprecated_kwargs.pop(attribute),)
26 | warning = f"The `{attribute}` argument is deprecated and will be removed in version {version_name}."
27 | elif hasattr(deprecated_kwargs, attribute):
28 | values += (getattr(deprecated_kwargs, attribute),)
29 | warning = f"The `{attribute}` attribute is deprecated and will be removed in version {version_name}."
30 | elif deprecated_kwargs is None:
31 | warning = f"`{attribute}` is deprecated and will be removed in version {version_name}."
32 |
33 | if warning is not None:
34 | warning = warning + " " if standard_warn else ""
35 | warnings.warn(warning + message, FutureWarning, stacklevel=stacklevel)
36 |
37 | if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0:
38 | call_frame = inspect.getouterframes(inspect.currentframe())[1]
39 | filename = call_frame.filename
40 | line_number = call_frame.lineno
41 | function = call_frame.function
42 | key, value = next(iter(deprecated_kwargs.items()))
43 | raise TypeError(f"{function} in {filename} line {line_number-1} got an unexpected keyword argument `{key}`")
44 |
45 | if len(values) == 0:
46 | return
47 | elif len(values) == 1:
48 | return values[0]
49 | return values
50 |
--------------------------------------------------------------------------------
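
A usage sketch of the common pattern `deprecate` supports: retiring a keyword argument while still honoring it for a few releases. The function and argument names are made up for illustration, and the far-future removal version keeps the internal version check from raising:

```python
from diffusers_official.utils.deprecation_utils import deprecate


def resize(image, scale_factor=1.0, **kwargs):
    # Pops `scale` from kwargs if a caller still passes it, emits a
    # FutureWarning, and returns the old value so it can be honored.
    scale = deprecate(
        "scale", "999.0.0", "Please use `scale_factor` instead.", take_from=kwargs
    )
    if scale is not None:
        scale_factor = scale
    return scale_factor


print(resize(None, scale=2.0))         # warns, then prints 2.0
print(resize(None, scale_factor=3.0))  # no warning, prints 3.0
```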
/README.md:
--------------------------------------------------------------------------------
1 | # Rich-Text-to-Image for Stable Diffusion WebUI
2 | #### [Project Page](https://rich-text-to-image.github.io/) | [Paper](https://arxiv.org/abs/2304.06720) | [Code](https://github.com/songweige/rich-text-to-image) | [HuggingFace Demo](https://huggingface.co/spaces/songweig/rich-text-to-image) | [Video](https://youtu.be/ihDbAUh0LXk)
3 |
4 | A Web UI extension that integrates a rich-text editor for text-to-image generation.
5 |
6 | 
7 |
8 | This extension is for [AUTOMATIC1111's Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui). It lets the Web UI use [rich-text-to-image](https://rich-text-to-image.github.io/) editing on top of the original Stable Diffusion model to generate images.
9 |
10 | ## Installation
11 |
12 | 1. Open the "Extensions" tab.
13 | 1. Open the "Install from URL" tab within it.
14 | 1. Enter the URL of this repo (https://github.com/songweige/sd-webui-rich-text) into "URL for extension's git repository".
15 | 1. Press the "Install" button.
16 | 1. Restart the Web UI.
17 |
18 | ## Usage
19 |
20 | The extension now supports [SD-v1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5) (default), [SD-XL-v1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0), and [ANIMAGINE-XL](https://huggingface.co/Linaqruf/animagine-xl). The checkpoints will be automatically downloaded when the model is first selected.
21 |
22 |
23 | #### Font Color
24 |
25 | 
26 |
27 | Font color is used to control the precise color of the generated objects.
28 |
29 | #### Footnote
30 |
31 | 
32 |
33 | Footnotes provide supplementary descriptions for selected text elements.
34 |
35 | #### Font Style
36 |
37 | 
38 |
39 | Just as font style distinguishes the look of individual text elements, it is used here to define the artistic style of specific regions in the generated image.
40 |
41 | #### Font Size
42 |
43 | 
44 |
45 | Font size indicates the weight of each token in the final generation.
46 |
47 | ## Acknowledgement
48 |
49 | The extension is built on the [extension-templates](https://github.com/udon-universe/stable-diffusion-webui-extension-templates). The rich-text editor is built on [Quill](https://quilljs.com/). The model code is built on [huggingface / diffusers](https://github.com/huggingface/diffusers#readme).
50 |
--------------------------------------------------------------------------------
/diffusers_official/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 | from ..utils import (
2 | OptionalDependencyNotAvailable,
3 | is_flax_available,
4 | is_invisible_watermark_available,
5 | is_k_diffusion_available,
6 | is_librosa_available,
7 | is_note_seq_available,
8 | is_onnx_available,
9 | is_torch_available,
10 | is_transformers_available,
11 | )
12 |
13 |
14 | try:
15 | if not is_torch_available():
16 | raise OptionalDependencyNotAvailable()
17 | except OptionalDependencyNotAvailable:
18 | from ..utils.dummy_pt_objects import * # noqa F403
19 | else:
20 | from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput
21 |
22 | try:
23 | if not (is_torch_available() and is_librosa_available()):
24 | raise OptionalDependencyNotAvailable()
25 | except OptionalDependencyNotAvailable:
26 | from ..utils.dummy_torch_and_librosa_objects import * # noqa F403
27 |
28 | try:
29 | if not (is_torch_available() and is_transformers_available()):
30 | raise OptionalDependencyNotAvailable()
31 | except OptionalDependencyNotAvailable:
32 | from ..utils.dummy_torch_and_transformers_objects import * # noqa F403
33 |
34 |
35 | try:
36 | if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()):
37 | raise OptionalDependencyNotAvailable()
38 | except OptionalDependencyNotAvailable:
39 | from ..utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403
40 | else:
41 | from .stable_diffusion_xl import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline
42 |
43 | try:
44 | if not is_onnx_available():
45 | raise OptionalDependencyNotAvailable()
46 | except OptionalDependencyNotAvailable:
47 | from ..utils.dummy_onnx_objects import * # noqa F403
48 | else:
49 | from .onnx_utils import OnnxRuntimeModel
50 |
51 | try:
52 | if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
53 | raise OptionalDependencyNotAvailable()
54 | except OptionalDependencyNotAvailable:
55 | from ..utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
56 |
57 |
58 | try:
59 | if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
60 | raise OptionalDependencyNotAvailable()
61 | except OptionalDependencyNotAvailable:
62 | from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
63 |
64 | try:
65 | if not is_flax_available():
66 | raise OptionalDependencyNotAvailable()
67 | except OptionalDependencyNotAvailable:
68 | from ..utils.dummy_flax_objects import * # noqa F403
69 | else:
70 | from .pipeline_flax_utils import FlaxDiffusionPipeline
71 |
72 |
73 | try:
74 | if not (is_flax_available() and is_transformers_available()):
75 | raise OptionalDependencyNotAvailable()
76 | except OptionalDependencyNotAvailable:
77 | from ..utils.dummy_flax_and_transformers_objects import * # noqa F403
78 |
79 | try:
80 | if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
81 | raise OptionalDependencyNotAvailable()
82 | except OptionalDependencyNotAvailable:
83 | from ..utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
84 |
--------------------------------------------------------------------------------
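
The file above repeats one idiom: probe for an optional backend combination, raise `OptionalDependencyNotAvailable` if anything is missing, and import either the real classes or their dummy placeholders. Here is the idiom in isolation as a simplified self-contained sketch (the real `OptionalDependencyNotAvailable` and `is_torch_available` live in diffusers' utils):

```python
class OptionalDependencyNotAvailable(BaseException):
    """Signals that an optional backend combination is not installed."""


def is_torch_available():
    try:
        import torch  # noqa: F401
        return True
    except ImportError:
        return False


try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    print("torch missing: export dummy placeholders that raise on first use")
else:
    print("torch present: export the real pipeline classes")
```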
/diffusers_official/utils/dummy_torch_and_transformers_and_onnx_objects.py:
--------------------------------------------------------------------------------
1 | # This file is autogenerated by the command `make fix-copies`, do not edit.
2 | from ..utils import DummyObject, requires_backends
3 |
4 |
5 | class OnnxStableDiffusionImg2ImgPipeline(metaclass=DummyObject):
6 | _backends = ["torch", "transformers", "onnx"]
7 |
8 | def __init__(self, *args, **kwargs):
9 | requires_backends(self, ["torch", "transformers", "onnx"])
10 |
11 | @classmethod
12 | def from_config(cls, *args, **kwargs):
13 | requires_backends(cls, ["torch", "transformers", "onnx"])
14 |
15 | @classmethod
16 | def from_pretrained(cls, *args, **kwargs):
17 | requires_backends(cls, ["torch", "transformers", "onnx"])
18 |
19 |
20 | class OnnxStableDiffusionInpaintPipeline(metaclass=DummyObject):
21 | _backends = ["torch", "transformers", "onnx"]
22 |
23 | def __init__(self, *args, **kwargs):
24 | requires_backends(self, ["torch", "transformers", "onnx"])
25 |
26 | @classmethod
27 | def from_config(cls, *args, **kwargs):
28 | requires_backends(cls, ["torch", "transformers", "onnx"])
29 |
30 | @classmethod
31 | def from_pretrained(cls, *args, **kwargs):
32 | requires_backends(cls, ["torch", "transformers", "onnx"])
33 |
34 |
35 | class OnnxStableDiffusionInpaintPipelineLegacy(metaclass=DummyObject):
36 | _backends = ["torch", "transformers", "onnx"]
37 |
38 | def __init__(self, *args, **kwargs):
39 | requires_backends(self, ["torch", "transformers", "onnx"])
40 |
41 | @classmethod
42 | def from_config(cls, *args, **kwargs):
43 | requires_backends(cls, ["torch", "transformers", "onnx"])
44 |
45 | @classmethod
46 | def from_pretrained(cls, *args, **kwargs):
47 | requires_backends(cls, ["torch", "transformers", "onnx"])
48 |
49 |
50 | class OnnxStableDiffusionPipeline(metaclass=DummyObject):
51 | _backends = ["torch", "transformers", "onnx"]
52 |
53 | def __init__(self, *args, **kwargs):
54 | requires_backends(self, ["torch", "transformers", "onnx"])
55 |
56 | @classmethod
57 | def from_config(cls, *args, **kwargs):
58 | requires_backends(cls, ["torch", "transformers", "onnx"])
59 |
60 | @classmethod
61 | def from_pretrained(cls, *args, **kwargs):
62 | requires_backends(cls, ["torch", "transformers", "onnx"])
63 |
64 |
65 | class OnnxStableDiffusionUpscalePipeline(metaclass=DummyObject):
66 | _backends = ["torch", "transformers", "onnx"]
67 |
68 | def __init__(self, *args, **kwargs):
69 | requires_backends(self, ["torch", "transformers", "onnx"])
70 |
71 | @classmethod
72 | def from_config(cls, *args, **kwargs):
73 | requires_backends(cls, ["torch", "transformers", "onnx"])
74 |
75 | @classmethod
76 | def from_pretrained(cls, *args, **kwargs):
77 | requires_backends(cls, ["torch", "transformers", "onnx"])
78 |
79 |
80 | class StableDiffusionOnnxPipeline(metaclass=DummyObject):
81 | _backends = ["torch", "transformers", "onnx"]
82 |
83 | def __init__(self, *args, **kwargs):
84 | requires_backends(self, ["torch", "transformers", "onnx"])
85 |
86 | @classmethod
87 | def from_config(cls, *args, **kwargs):
88 | requires_backends(cls, ["torch", "transformers", "onnx"])
89 |
90 | @classmethod
91 | def from_pretrained(cls, *args, **kwargs):
92 | requires_backends(cls, ["torch", "transformers", "onnx"])
93 |
--------------------------------------------------------------------------------
/diffusers_official/utils/torch_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | PyTorch utilities: Utilities related to PyTorch
16 | """
17 | from typing import List, Optional, Tuple, Union
18 |
19 | from . import logging
20 | from .import_utils import is_torch_available, is_torch_version
21 |
22 |
23 | if is_torch_available():
24 | import torch
25 |
26 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
27 |
28 | try:
29 | from torch._dynamo import allow_in_graph as maybe_allow_in_graph
30 | except (ImportError, ModuleNotFoundError):
31 |
32 | def maybe_allow_in_graph(cls):
33 | return cls
34 |
35 |
36 | def randn_tensor(
37 | shape: Union[Tuple, List],
38 | generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
39 | device: Optional["torch.device"] = None,
40 | dtype: Optional["torch.dtype"] = None,
41 | layout: Optional["torch.layout"] = None,
42 | ):
43 | """A helper function to create random tensors on the desired `device` with the desired `dtype`. When
44 | passing a list of generators, you can seed each batch item individually. If CPU generators are passed, the tensor
45 | is always created on the CPU.
46 | """
47 | # the device on which the tensor is created defaults to the requested `device`
48 | rand_device = device
49 | batch_size = shape[0]
50 |
51 | layout = layout or torch.strided
52 | device = device or torch.device("cpu")
53 |
54 | if generator is not None:
55 | gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
56 | if gen_device_type != device.type and gen_device_type == "cpu":
57 | rand_device = "cpu"
58 | if device != "mps":
59 | logger.info(
60 | f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
61 | f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
62 | f" slightly speed up this function by passing a generator that was created on the {device} device."
63 | )
64 | elif gen_device_type != device.type and gen_device_type == "cuda":
65 | raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
66 |
67 | if isinstance(generator, list):
68 | shape = (1,) + shape[1:]
69 | latents = [
70 | torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
71 | for i in range(batch_size)
72 | ]
73 | latents = torch.cat(latents, dim=0).to(device)
74 | else:
75 | latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)
76 |
77 | return latents
78 |
79 |
80 | def is_compiled_module(module):
81 | """Check whether the module was compiled with torch.compile()"""
82 | if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"):
83 | return False
84 | return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
85 |
--------------------------------------------------------------------------------
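
A reproducibility sketch for `randn_tensor` (assumes torch is installed): passing a list of generators seeds each batch item independently, so sample `i` is identical no matter what batch it is drawn in:

```python
import torch

from diffusers_official.utils.torch_utils import randn_tensor

generators = [torch.Generator("cpu").manual_seed(seed) for seed in (0, 1, 2)]
latents = randn_tensor((3, 4, 64, 64), generator=generators)

# The first sample matches a batch-of-one draw under the same seed:
single = randn_tensor((1, 4, 64, 64), generator=torch.Generator("cpu").manual_seed(0))
print(torch.allclose(latents[:1], single))  # True
```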
/diffusers_official/schedulers/scheduling_sde_vp.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google Brain and The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
16 |
17 | import math
18 | from typing import Union
19 |
20 | import torch
21 |
22 | from ..configuration_utils import ConfigMixin, register_to_config
23 | from ..utils import randn_tensor
24 | from .scheduling_utils import SchedulerMixin
25 |
26 |
27 | class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
28 | """
29 | The variance preserving stochastic differential equation (SDE) scheduler.
30 |
31 | [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
32 | function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
33 | [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
34 | [`~SchedulerMixin.from_pretrained`] functions.
35 |
36 | For more information, see the original paper: https://arxiv.org/abs/2011.13456
37 |
38 | UNDER CONSTRUCTION
39 |
40 | """
41 |
42 | order = 1
43 |
44 | @register_to_config
45 | def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3):
46 | self.sigmas = None
47 | self.discrete_sigmas = None
48 | self.timesteps = None
49 |
50 | def set_timesteps(self, num_inference_steps, device: Union[str, torch.device] = None):
51 | self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps, device=device)
52 |
53 | def step_pred(self, score, x, t, generator=None):
54 | if self.timesteps is None:
55 | raise ValueError(
56 | "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
57 | )
58 |
59 | # TODO(Patrick) better comments + non-PyTorch
60 | # postprocess model score
61 | log_mean_coeff = (
62 | -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min
63 | )
64 | std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
65 | std = std.flatten()
66 | while len(std.shape) < len(score.shape):
67 | std = std.unsqueeze(-1)
68 | score = -score / std
69 |
70 | # compute
71 | dt = -1.0 / len(self.timesteps)
72 |
73 | beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min)
74 | beta_t = beta_t.flatten()
75 | while len(beta_t.shape) < len(x.shape):
76 | beta_t = beta_t.unsqueeze(-1)
77 | drift = -0.5 * beta_t * x
78 |
79 | diffusion = torch.sqrt(beta_t)
80 | drift = drift - diffusion**2 * score
81 | x_mean = x + drift * dt
82 |
83 | # add noise
84 | noise = randn_tensor(x.shape, layout=x.layout, generator=generator, device=x.device, dtype=x.dtype)
85 | x = x_mean + diffusion * math.sqrt(-dt) * noise
86 |
87 | return x, x_mean
88 |
89 | def __len__(self):
90 | return self.config.num_train_timesteps
91 |
--------------------------------------------------------------------------------
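
A sampling sketch (assumes torch is installed): `set_timesteps` builds the reverse-time grid, and each `step_pred` performs one Euler-Maruyama update of the reverse SDE given a score estimate. The constant stand-in score below takes the place of a trained score network:

```python
import torch

from diffusers_official.schedulers.scheduling_sde_vp import ScoreSdeVpScheduler

scheduler = ScoreSdeVpScheduler()
scheduler.set_timesteps(num_inference_steps=10)

x = torch.randn(1, 3, 8, 8)
for t in scheduler.timesteps:
    score = -x  # placeholder for score_model(x, t)
    x, x_mean = scheduler.step_pred(score, x, t)

print(x_mean.shape)  # torch.Size([1, 3, 8, 8]); x_mean is the noise-free estimate
```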
/diffusers_official/models/embeddings_flax.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import math
15 |
16 | import flax.linen as nn
17 | import jax.numpy as jnp
18 |
19 |
20 | def get_sinusoidal_embeddings(
21 | timesteps: jnp.ndarray,
22 | embedding_dim: int,
23 | freq_shift: float = 1,
24 | min_timescale: float = 1,
25 | max_timescale: float = 1.0e4,
26 | flip_sin_to_cos: bool = False,
27 | scale: float = 1.0,
28 | ) -> jnp.ndarray:
29 | """Returns the positional encoding (same as Tensor2Tensor).
30 |
31 | Args:
32 | timesteps: a 1-D Tensor of N indices, one per batch element.
33 | These may be fractional.
34 | embedding_dim: The number of output channels.
35 | min_timescale: The smallest time unit (should probably be 0.0).
36 | max_timescale: The largest time unit.
37 | Returns:
38 | a Tensor of timing signals [N, num_channels]
39 | """
40 | assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
41 | assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"
42 | num_timescales = float(embedding_dim // 2)
43 | log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift)
44 | inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment)
45 | emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0)
46 |
47 | # scale embeddings
48 | scaled_time = scale * emb
49 |
50 | if flip_sin_to_cos:
51 | signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1)
52 | else:
53 | signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1)
54 | signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
55 | return signal
56 |
57 |
58 | class FlaxTimestepEmbedding(nn.Module):
59 | r"""
60 | Time step Embedding Module. Learns embeddings for input time steps.
61 |
62 | Args:
63 | time_embed_dim (`int`, *optional*, defaults to `32`):
64 | Time step embedding dimension
65 | dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
66 | Parameters `dtype`
67 | """
68 | time_embed_dim: int = 32
69 | dtype: jnp.dtype = jnp.float32
70 |
71 | @nn.compact
72 | def __call__(self, temb):
73 | temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb)
74 | temb = nn.silu(temb)
75 | temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb)
76 | return temb
77 |
78 |
79 | class FlaxTimesteps(nn.Module):
80 | r"""
81 | Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239
82 |
83 | Args:
84 | dim (`int`, *optional*, defaults to `32`):
85 | Time step embedding dimension
86 | """
87 | dim: int = 32
88 | flip_sin_to_cos: bool = False
89 | freq_shift: float = 1
90 |
91 | @nn.compact
92 | def __call__(self, timesteps):
93 | return get_sinusoidal_embeddings(
94 | timesteps, embedding_dim=self.dim, flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift
95 | )
96 |
--------------------------------------------------------------------------------
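
A usage sketch (assumes jax and flax are installed): embed a batch of timesteps with the sinusoidal helper, then project it with `FlaxTimestepEmbedding`:

```python
import jax
import jax.numpy as jnp

from diffusers_official.models.embeddings_flax import (
    FlaxTimestepEmbedding,
    get_sinusoidal_embeddings,
)

timesteps = jnp.array([0.0, 10.0, 500.0])
sinusoid = get_sinusoidal_embeddings(timesteps, embedding_dim=32)
print(sinusoid.shape)  # (3, 32)

mlp = FlaxTimestepEmbedding(time_embed_dim=128)
params = mlp.init(jax.random.PRNGKey(0), sinusoid)
print(mlp.apply(params, sinusoid).shape)  # (3, 128)
```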
/diffusers_official/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import os
17 |
18 | from packaging import version
19 |
20 | from .. import __version__
21 | from .accelerate_utils import apply_forward_hook
22 | from .constants import (
23 | CONFIG_NAME,
24 | DEPRECATED_REVISION_ARGS,
25 | DIFFUSERS_CACHE,
26 | DIFFUSERS_DYNAMIC_MODULE_NAME,
27 | FLAX_WEIGHTS_NAME,
28 | HF_MODULES_CACHE,
29 | HUGGINGFACE_CO_RESOLVE_ENDPOINT,
30 | ONNX_EXTERNAL_WEIGHTS_NAME,
31 | ONNX_WEIGHTS_NAME,
32 | SAFETENSORS_WEIGHTS_NAME,
33 | TEXT_ENCODER_ATTN_MODULE,
34 | WEIGHTS_NAME,
35 | )
36 | from .deprecation_utils import deprecate
37 | from .doc_utils import replace_example_docstring
38 | from .dynamic_modules_utils import get_class_from_dynamic_module
39 | from .hub_utils import (
40 | HF_HUB_OFFLINE,
41 | _add_variant,
42 | _get_model_file,
43 | extract_commit_hash,
44 | http_user_agent,
45 | )
46 | from .import_utils import (
47 | BACKENDS_MAPPING,
48 | ENV_VARS_TRUE_AND_AUTO_VALUES,
49 | ENV_VARS_TRUE_VALUES,
50 | USE_JAX,
51 | USE_TF,
52 | USE_TORCH,
53 | DummyObject,
54 | OptionalDependencyNotAvailable,
55 | is_accelerate_available,
56 | is_accelerate_version,
57 | is_bs4_available,
58 | is_flax_available,
59 | is_ftfy_available,
60 | is_inflect_available,
61 | is_invisible_watermark_available,
62 | is_k_diffusion_available,
63 | is_k_diffusion_version,
64 | is_librosa_available,
65 | is_note_seq_available,
66 | is_omegaconf_available,
67 | is_onnx_available,
68 | is_safetensors_available,
69 | is_scipy_available,
70 | is_tensorboard_available,
71 | is_tf_available,
72 | is_torch_available,
73 | is_torch_version,
74 | is_torchsde_available,
75 | is_transformers_available,
76 | is_transformers_version,
77 | is_unidecode_available,
78 | is_wandb_available,
79 | is_xformers_available,
80 | requires_backends,
81 | )
82 | from .logging import get_logger
83 | from .outputs import BaseOutput
84 | from .pil_utils import PIL_INTERPOLATION, numpy_to_pil, pt_to_pil
85 | from .torch_utils import is_compiled_module, randn_tensor
86 |
87 |
88 | if is_torch_available():
89 | from .testing_utils import (
90 | floats_tensor,
91 | load_hf_numpy,
92 | load_image,
93 | load_numpy,
94 | load_pt,
95 | nightly,
96 | parse_flag_from_env,
97 | print_tensor_test,
98 | require_torch_2,
99 | require_torch_gpu,
100 | skip_mps,
101 | slow,
102 | torch_all_close,
103 | torch_device,
104 | )
105 | from .torch_utils import maybe_allow_in_graph
106 |
107 | from .testing_utils import export_to_gif, export_to_video
108 |
109 |
110 | logger = get_logger(__name__)
111 |
112 |
113 | def check_min_version(min_version):
114 | if version.parse(__version__) < version.parse(min_version):
115 | if "dev" in min_version:
116 | error_message = (
117 | "This example requires a source install from HuggingFace diffusers (see "
118 | "`https://huggingface.co/docs/diffusers/installation#install-from-source`),"
119 | )
120 | else:
121 | error_message = f"This example requires a minimum version of {min_version},"
122 | error_message += f" but the version found is {__version__}.\n"
123 | raise ImportError(error_message)
124 |
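125 |
126 | # Editor's note (illustrative usage, not part of the original file): example
127 | # scripts typically call `check_min_version("0.18.0")` up front so that an
128 | # ImportError is raised early when the installed diffusers is too old; the
129 | # version string here is an assumed placeholder.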
--------------------------------------------------------------------------------
/diffusers_official/utils/outputs.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Generic utilities
16 | """
17 |
18 | from collections import OrderedDict
19 | from dataclasses import fields
20 | from typing import Any, Tuple
21 |
22 | import numpy as np
23 |
24 | from .import_utils import is_torch_available
25 |
26 |
27 | def is_tensor(x):
28 | """
29 | Tests if `x` is a `torch.Tensor` or `np.ndarray`.
30 | """
31 | if is_torch_available():
32 | import torch
33 |
34 | if isinstance(x, torch.Tensor):
35 | return True
36 |
37 | return isinstance(x, np.ndarray)
38 |
39 |
40 | class BaseOutput(OrderedDict):
41 | """
42 | Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a
43 | tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
44 | Python dictionary.
45 |
46 |     <Tip warning={true}>
47 |
48 | You can't unpack a [`BaseOutput`] directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple
49 | first.
50 |
51 |     </Tip>
52 | """
53 |
54 | def __post_init__(self):
55 | class_fields = fields(self)
56 |
57 | # Safety and consistency checks
58 | if not len(class_fields):
59 | raise ValueError(f"{self.__class__.__name__} has no fields.")
60 |
61 | first_field = getattr(self, class_fields[0].name)
62 | other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
63 |
64 | if other_fields_are_none and isinstance(first_field, dict):
65 | for key, value in first_field.items():
66 | self[key] = value
67 | else:
68 | for field in class_fields:
69 | v = getattr(self, field.name)
70 | if v is not None:
71 | self[field.name] = v
72 |
73 | def __delitem__(self, *args, **kwargs):
74 | raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
75 |
76 | def setdefault(self, *args, **kwargs):
77 | raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
78 |
79 | def pop(self, *args, **kwargs):
80 | raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
81 |
82 | def update(self, *args, **kwargs):
83 | raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
84 |
85 | def __getitem__(self, k):
86 | if isinstance(k, str):
87 | inner_dict = dict(self.items())
88 | return inner_dict[k]
89 | else:
90 | return self.to_tuple()[k]
91 |
92 | def __setattr__(self, name, value):
93 | if name in self.keys() and value is not None:
94 | # Don't call self.__setitem__ to avoid recursion errors
95 | super().__setitem__(name, value)
96 | super().__setattr__(name, value)
97 |
98 | def __setitem__(self, key, value):
99 |         # Will raise a KeyError if needed
100 | super().__setitem__(key, value)
101 | # Don't call self.__setattr__ to avoid recursion errors
102 | super().__setattr__(key, value)
103 |
104 | def to_tuple(self) -> Tuple[Any]:
105 | """
106 | Convert self to a tuple containing all the attributes/keys that are not `None`.
107 | """
108 | return tuple(self[k] for k in self.keys())
109 |
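110 |
111 | # Editor's note: a minimal sketch of how `BaseOutput` subclasses behave (not
112 | # part of the original module; `DemoOutput` and its fields are hypothetical).
113 | if __name__ == "__main__":
114 |     from dataclasses import dataclass
115 |     from typing import Optional
116 |
117 |     @dataclass
118 |     class DemoOutput(BaseOutput):
119 |         sample: Optional[int] = None
120 |         extra: Optional[str] = None
121 |
122 |     out = DemoOutput(sample=1)  # `extra` is None, so it is dropped from the dict
123 |     assert out["sample"] == out[0] == out.sample == 1
124 |     assert out.to_tuple() == (1,)  # unpack via to_tuple(), not tuple-unpacking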
--------------------------------------------------------------------------------
/diffusers_official/schedulers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from ..utils import (
17 | OptionalDependencyNotAvailable,
18 | is_flax_available,
19 | is_scipy_available,
20 | is_torch_available,
21 | is_torchsde_available,
22 | )
23 |
24 |
25 | try:
26 | if not is_torch_available():
27 | raise OptionalDependencyNotAvailable()
28 | except OptionalDependencyNotAvailable:
29 | from ..utils.dummy_pt_objects import * # noqa F403
30 | else:
31 | from .scheduling_consistency_models import CMStochasticIterativeScheduler
32 | from .scheduling_ddim import DDIMScheduler
33 | from .scheduling_ddim_inverse import DDIMInverseScheduler
34 | from .scheduling_ddim_parallel import DDIMParallelScheduler
35 | from .scheduling_ddpm import DDPMScheduler
36 | from .scheduling_ddpm_parallel import DDPMParallelScheduler
37 | from .scheduling_deis_multistep import DEISMultistepScheduler
38 | from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
39 | from .scheduling_dpmsolver_multistep_inverse import DPMSolverMultistepInverseScheduler
40 | from .scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler
41 | from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
42 | from .scheduling_euler_discrete import EulerDiscreteScheduler
43 | from .scheduling_heun_discrete import HeunDiscreteScheduler
44 | from .scheduling_ipndm import IPNDMScheduler
45 | from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler
46 | from .scheduling_k_dpm_2_discrete import KDPM2DiscreteScheduler
47 | from .scheduling_karras_ve import KarrasVeScheduler
48 | from .scheduling_pndm import PNDMScheduler
49 | from .scheduling_repaint import RePaintScheduler
50 | from .scheduling_sde_ve import ScoreSdeVeScheduler
51 | from .scheduling_sde_vp import ScoreSdeVpScheduler
52 | from .scheduling_unclip import UnCLIPScheduler
53 | from .scheduling_unipc_multistep import UniPCMultistepScheduler
54 | from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
55 | from .scheduling_vq_diffusion import VQDiffusionScheduler
56 |
57 | try:
58 | if not is_flax_available():
59 | raise OptionalDependencyNotAvailable()
60 | except OptionalDependencyNotAvailable:
61 | from ..utils.dummy_flax_objects import * # noqa F403
62 | else:
63 | from .scheduling_ddim_flax import FlaxDDIMScheduler
64 | from .scheduling_ddpm_flax import FlaxDDPMScheduler
65 | from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler
66 | from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler
67 | from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
68 | from .scheduling_pndm_flax import FlaxPNDMScheduler
69 | from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
70 | from .scheduling_utils_flax import (
71 | FlaxKarrasDiffusionSchedulers,
72 | FlaxSchedulerMixin,
73 | FlaxSchedulerOutput,
74 | broadcast_to_shape_from_left,
75 | )
76 |
77 |
78 | try:
79 | if not (is_torch_available() and is_scipy_available()):
80 | raise OptionalDependencyNotAvailable()
81 | except OptionalDependencyNotAvailable:
82 | from ..utils.dummy_torch_and_scipy_objects import * # noqa F403
83 | else:
84 | from .scheduling_lms_discrete import LMSDiscreteScheduler
85 |
86 | try:
87 | if not (is_torch_available() and is_torchsde_available()):
88 | raise OptionalDependencyNotAvailable()
89 | except OptionalDependencyNotAvailable:
90 | from ..utils.dummy_torch_and_torchsde_objects import * # noqa F403
91 | else:
92 | from .scheduling_dpmsolver_sde import DPMSolverSDEScheduler
93 |
--------------------------------------------------------------------------------
/diffusers_official/models/resnet_flax.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import flax.linen as nn
15 | import jax
16 | import jax.numpy as jnp
17 |
18 |
19 | class FlaxUpsample2D(nn.Module):
20 | out_channels: int
21 | dtype: jnp.dtype = jnp.float32
22 |
23 | def setup(self):
24 | self.conv = nn.Conv(
25 | self.out_channels,
26 | kernel_size=(3, 3),
27 | strides=(1, 1),
28 | padding=((1, 1), (1, 1)),
29 | dtype=self.dtype,
30 | )
31 |
32 | def __call__(self, hidden_states):
33 | batch, height, width, channels = hidden_states.shape
34 | hidden_states = jax.image.resize(
35 | hidden_states,
36 | shape=(batch, height * 2, width * 2, channels),
37 | method="nearest",
38 | )
39 | hidden_states = self.conv(hidden_states)
40 | return hidden_states
41 |
42 |
43 | class FlaxDownsample2D(nn.Module):
44 | out_channels: int
45 | dtype: jnp.dtype = jnp.float32
46 |
47 | def setup(self):
48 | self.conv = nn.Conv(
49 | self.out_channels,
50 | kernel_size=(3, 3),
51 | strides=(2, 2),
52 | padding=((1, 1), (1, 1)), # padding="VALID",
53 | dtype=self.dtype,
54 | )
55 |
56 | def __call__(self, hidden_states):
57 | # pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim
58 | # hidden_states = jnp.pad(hidden_states, pad_width=pad)
59 | hidden_states = self.conv(hidden_states)
60 | return hidden_states
61 |
62 |
63 | class FlaxResnetBlock2D(nn.Module):
64 | in_channels: int
65 | out_channels: int = None
66 | dropout_prob: float = 0.0
67 | use_nin_shortcut: bool = None
68 | dtype: jnp.dtype = jnp.float32
69 |
70 | def setup(self):
71 | out_channels = self.in_channels if self.out_channels is None else self.out_channels
72 |
73 | self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
74 | self.conv1 = nn.Conv(
75 | out_channels,
76 | kernel_size=(3, 3),
77 | strides=(1, 1),
78 | padding=((1, 1), (1, 1)),
79 | dtype=self.dtype,
80 | )
81 |
82 | self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype)
83 |
84 | self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
85 | self.dropout = nn.Dropout(self.dropout_prob)
86 | self.conv2 = nn.Conv(
87 | out_channels,
88 | kernel_size=(3, 3),
89 | strides=(1, 1),
90 | padding=((1, 1), (1, 1)),
91 | dtype=self.dtype,
92 | )
93 |
94 | use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut
95 |
96 | self.conv_shortcut = None
97 | if use_nin_shortcut:
98 | self.conv_shortcut = nn.Conv(
99 | out_channels,
100 | kernel_size=(1, 1),
101 | strides=(1, 1),
102 | padding="VALID",
103 | dtype=self.dtype,
104 | )
105 |
106 | def __call__(self, hidden_states, temb, deterministic=True):
107 | residual = hidden_states
108 | hidden_states = self.norm1(hidden_states)
109 | hidden_states = nn.swish(hidden_states)
110 | hidden_states = self.conv1(hidden_states)
111 |
112 | temb = self.time_emb_proj(nn.swish(temb))
113 | temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1)
114 | hidden_states = hidden_states + temb
115 |
116 | hidden_states = self.norm2(hidden_states)
117 | hidden_states = nn.swish(hidden_states)
118 | hidden_states = self.dropout(hidden_states, deterministic)
119 | hidden_states = self.conv2(hidden_states)
120 |
121 | if self.conv_shortcut is not None:
122 | residual = self.conv_shortcut(residual)
123 |
124 | return hidden_states + residual
125 |
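126 |
127 | # Editor's note: a minimal init/apply sketch (not part of the original module;
128 | # the NHWC input shape and 128-dim time embedding are illustrative choices).
129 | if __name__ == "__main__":
130 |     block = FlaxResnetBlock2D(in_channels=32, out_channels=64)
131 |     x = jnp.ones((1, 8, 8, 32))  # Flax convolutions expect NHWC layout
132 |     temb = jnp.ones((1, 128))
133 |     params = block.init(jax.random.PRNGKey(0), x, temb)
134 |     y = block.apply(params, x, temb)
135 |     assert y.shape == (1, 8, 8, 64)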
--------------------------------------------------------------------------------
/diffusers_official/models/modeling_flax_pytorch_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ PyTorch - Flax general utilities."""
16 | import re
17 |
18 | import jax.numpy as jnp
19 | from flax.traverse_util import flatten_dict, unflatten_dict
20 | from jax.random import PRNGKey
21 |
22 | from ..utils import logging
23 |
24 |
25 | logger = logging.get_logger(__name__)
26 |
27 |
28 | def rename_key(key):
29 | regex = r"\w+[.]\d+"
30 | pats = re.findall(regex, key)
31 | for pat in pats:
32 | key = key.replace(pat, "_".join(pat.split(".")))
33 | return key
34 |
35 |
36 | #####################
37 | # PyTorch => Flax #
38 | #####################
39 |
40 |
41 | # Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
42 | # and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
43 | def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
44 | """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""
45 |
46 | # conv norm or layer norm
47 | renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
48 | if (
49 | any("norm" in str_ for str_ in pt_tuple_key)
50 | and (pt_tuple_key[-1] == "bias")
51 | and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict)
52 | and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
53 | ):
54 | renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
55 | return renamed_pt_tuple_key, pt_tensor
56 | elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
57 | renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
58 | return renamed_pt_tuple_key, pt_tensor
59 |
60 | # embedding
61 |     if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
62 |         renamed_pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
63 |         return renamed_pt_tuple_key, pt_tensor
64 |
65 | # conv layer
66 | renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
67 | if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4:
68 | pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
69 | return renamed_pt_tuple_key, pt_tensor
70 |
71 | # linear layer
72 | renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
73 | if pt_tuple_key[-1] == "weight":
74 | pt_tensor = pt_tensor.T
75 | return renamed_pt_tuple_key, pt_tensor
76 |
77 | # old PyTorch layer norm weight
78 | renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
79 | if pt_tuple_key[-1] == "gamma":
80 | return renamed_pt_tuple_key, pt_tensor
81 |
82 | # old PyTorch layer norm bias
83 | renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
84 | if pt_tuple_key[-1] == "beta":
85 | return renamed_pt_tuple_key, pt_tensor
86 |
87 | return pt_tuple_key, pt_tensor
88 |
89 |
90 | def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42):
91 | # Step 1: Convert pytorch tensor to numpy
92 | pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
93 |
94 | # Step 2: Since the model is stateless, get random Flax params
95 | random_flax_params = flax_model.init_weights(PRNGKey(init_key))
96 |
97 | random_flax_state_dict = flatten_dict(random_flax_params)
98 | flax_state_dict = {}
99 |
100 | # Need to change some parameters name to match Flax names
101 | for pt_key, pt_tensor in pt_state_dict.items():
102 | renamed_pt_key = rename_key(pt_key)
103 | pt_tuple_key = tuple(renamed_pt_key.split("."))
104 |
105 | # Correctly rename weight parameters
106 | flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict)
107 |
108 | if flax_key in random_flax_state_dict:
109 | if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
110 | raise ValueError(
111 | f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
112 | f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
113 | )
114 |
115 | # also add unexpected weight so that warning is thrown
116 | flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
117 |
118 | return unflatten_dict(flax_state_dict)
119 |
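120 |
121 | # Editor's note: an illustrative check of the key renaming alone (not part of
122 | # the original module; the dotted name below is a made-up PyTorch-style key).
123 | if __name__ == "__main__":
124 |     assert rename_key("down_blocks.0.resnets.1.conv1.weight") == "down_blocks_0.resnets_1.conv1.weight"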
--------------------------------------------------------------------------------
/diffusers_official/models/cross_attention.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from ..utils import deprecate
15 | from .attention_processor import ( # noqa: F401
16 | Attention,
17 | AttentionProcessor,
18 | AttnAddedKVProcessor,
19 | AttnProcessor2_0,
20 | LoRAAttnProcessor,
21 | LoRALinearLayer,
22 | LoRAXFormersAttnProcessor,
23 | SlicedAttnAddedKVProcessor,
24 | SlicedAttnProcessor,
25 | XFormersAttnProcessor,
26 | )
27 | from .attention_processor import AttnProcessor as AttnProcessorRename # noqa: F401
28 |
29 |
30 | deprecate(
31 | "cross_attention",
32 | "0.20.0",
33 | "Importing from cross_attention is deprecated. Please import from diffusers.models.attention_processor instead.",
34 | standard_warn=False,
35 | )
36 |
37 |
38 | AttnProcessor = AttentionProcessor
39 |
40 |
41 | class CrossAttention(Attention):
42 | def __init__(self, *args, **kwargs):
43 |         deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))}` instead."
44 | deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
45 | super().__init__(*args, **kwargs)
46 |
47 |
48 | class CrossAttnProcessor(AttnProcessorRename):
49 | def __init__(self, *args, **kwargs):
50 |         deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))}` instead."
51 | deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
52 | super().__init__(*args, **kwargs)
53 |
54 |
55 | class LoRACrossAttnProcessor(LoRAAttnProcessor):
56 | def __init__(self, *args, **kwargs):
57 |         deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))}` instead."
58 | deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
59 | super().__init__(*args, **kwargs)
60 |
61 |
62 | class CrossAttnAddedKVProcessor(AttnAddedKVProcessor):
63 | def __init__(self, *args, **kwargs):
64 |         deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))}` instead."
65 | deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
66 | super().__init__(*args, **kwargs)
67 |
68 |
69 | class XFormersCrossAttnProcessor(XFormersAttnProcessor):
70 | def __init__(self, *args, **kwargs):
71 |         deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))}` instead."
72 | deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
73 | super().__init__(*args, **kwargs)
74 |
75 |
76 | class LoRAXFormersCrossAttnProcessor(LoRAXFormersAttnProcessor):
77 | def __init__(self, *args, **kwargs):
78 |         deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))}` instead."
79 | deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
80 | super().__init__(*args, **kwargs)
81 |
82 |
83 | class SlicedCrossAttnProcessor(SlicedAttnProcessor):
84 | def __init__(self, *args, **kwargs):
85 |         deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))}` instead."
86 | deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
87 | super().__init__(*args, **kwargs)
88 |
89 |
90 | class SlicedCrossAttnAddedKVProcessor(SlicedAttnAddedKVProcessor):
91 | def __init__(self, *args, **kwargs):
92 |         deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))}` instead."
93 | deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
94 | super().__init__(*args, **kwargs)
95 |
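96 |
97 | # Editor's note (illustrative): per the deprecation messages above, new code
98 | # should import the renamed classes directly, e.g.
99 | #     from diffusers.models.attention_processor import Attention, AttnProcessor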
--------------------------------------------------------------------------------
/diffusers_official/utils/dummy_flax_objects.py:
--------------------------------------------------------------------------------
1 | # This file is autogenerated by the command `make fix-copies`, do not edit.
2 | from ..utils import DummyObject, requires_backends
3 |
4 |
5 | class FlaxControlNetModel(metaclass=DummyObject):
6 | _backends = ["flax"]
7 |
8 | def __init__(self, *args, **kwargs):
9 | requires_backends(self, ["flax"])
10 |
11 | @classmethod
12 | def from_config(cls, *args, **kwargs):
13 | requires_backends(cls, ["flax"])
14 |
15 | @classmethod
16 | def from_pretrained(cls, *args, **kwargs):
17 | requires_backends(cls, ["flax"])
18 |
19 |
20 | class FlaxModelMixin(metaclass=DummyObject):
21 | _backends = ["flax"]
22 |
23 | def __init__(self, *args, **kwargs):
24 | requires_backends(self, ["flax"])
25 |
26 | @classmethod
27 | def from_config(cls, *args, **kwargs):
28 | requires_backends(cls, ["flax"])
29 |
30 | @classmethod
31 | def from_pretrained(cls, *args, **kwargs):
32 | requires_backends(cls, ["flax"])
33 |
34 |
35 | class FlaxUNet2DConditionModel(metaclass=DummyObject):
36 | _backends = ["flax"]
37 |
38 | def __init__(self, *args, **kwargs):
39 | requires_backends(self, ["flax"])
40 |
41 | @classmethod
42 | def from_config(cls, *args, **kwargs):
43 | requires_backends(cls, ["flax"])
44 |
45 | @classmethod
46 | def from_pretrained(cls, *args, **kwargs):
47 | requires_backends(cls, ["flax"])
48 |
49 |
50 | class FlaxAutoencoderKL(metaclass=DummyObject):
51 | _backends = ["flax"]
52 |
53 | def __init__(self, *args, **kwargs):
54 | requires_backends(self, ["flax"])
55 |
56 | @classmethod
57 | def from_config(cls, *args, **kwargs):
58 | requires_backends(cls, ["flax"])
59 |
60 | @classmethod
61 | def from_pretrained(cls, *args, **kwargs):
62 | requires_backends(cls, ["flax"])
63 |
64 |
65 | class FlaxDiffusionPipeline(metaclass=DummyObject):
66 | _backends = ["flax"]
67 |
68 | def __init__(self, *args, **kwargs):
69 | requires_backends(self, ["flax"])
70 |
71 | @classmethod
72 | def from_config(cls, *args, **kwargs):
73 | requires_backends(cls, ["flax"])
74 |
75 | @classmethod
76 | def from_pretrained(cls, *args, **kwargs):
77 | requires_backends(cls, ["flax"])
78 |
79 |
80 | class FlaxDDIMScheduler(metaclass=DummyObject):
81 | _backends = ["flax"]
82 |
83 | def __init__(self, *args, **kwargs):
84 | requires_backends(self, ["flax"])
85 |
86 | @classmethod
87 | def from_config(cls, *args, **kwargs):
88 | requires_backends(cls, ["flax"])
89 |
90 | @classmethod
91 | def from_pretrained(cls, *args, **kwargs):
92 | requires_backends(cls, ["flax"])
93 |
94 |
95 | class FlaxDDPMScheduler(metaclass=DummyObject):
96 | _backends = ["flax"]
97 |
98 | def __init__(self, *args, **kwargs):
99 | requires_backends(self, ["flax"])
100 |
101 | @classmethod
102 | def from_config(cls, *args, **kwargs):
103 | requires_backends(cls, ["flax"])
104 |
105 | @classmethod
106 | def from_pretrained(cls, *args, **kwargs):
107 | requires_backends(cls, ["flax"])
108 |
109 |
110 | class FlaxDPMSolverMultistepScheduler(metaclass=DummyObject):
111 | _backends = ["flax"]
112 |
113 | def __init__(self, *args, **kwargs):
114 | requires_backends(self, ["flax"])
115 |
116 | @classmethod
117 | def from_config(cls, *args, **kwargs):
118 | requires_backends(cls, ["flax"])
119 |
120 | @classmethod
121 | def from_pretrained(cls, *args, **kwargs):
122 | requires_backends(cls, ["flax"])
123 |
124 |
125 | class FlaxKarrasVeScheduler(metaclass=DummyObject):
126 | _backends = ["flax"]
127 |
128 | def __init__(self, *args, **kwargs):
129 | requires_backends(self, ["flax"])
130 |
131 | @classmethod
132 | def from_config(cls, *args, **kwargs):
133 | requires_backends(cls, ["flax"])
134 |
135 | @classmethod
136 | def from_pretrained(cls, *args, **kwargs):
137 | requires_backends(cls, ["flax"])
138 |
139 |
140 | class FlaxLMSDiscreteScheduler(metaclass=DummyObject):
141 | _backends = ["flax"]
142 |
143 | def __init__(self, *args, **kwargs):
144 | requires_backends(self, ["flax"])
145 |
146 | @classmethod
147 | def from_config(cls, *args, **kwargs):
148 | requires_backends(cls, ["flax"])
149 |
150 | @classmethod
151 | def from_pretrained(cls, *args, **kwargs):
152 | requires_backends(cls, ["flax"])
153 |
154 |
155 | class FlaxPNDMScheduler(metaclass=DummyObject):
156 | _backends = ["flax"]
157 |
158 | def __init__(self, *args, **kwargs):
159 | requires_backends(self, ["flax"])
160 |
161 | @classmethod
162 | def from_config(cls, *args, **kwargs):
163 | requires_backends(cls, ["flax"])
164 |
165 | @classmethod
166 | def from_pretrained(cls, *args, **kwargs):
167 | requires_backends(cls, ["flax"])
168 |
169 |
170 | class FlaxSchedulerMixin(metaclass=DummyObject):
171 | _backends = ["flax"]
172 |
173 | def __init__(self, *args, **kwargs):
174 | requires_backends(self, ["flax"])
175 |
176 | @classmethod
177 | def from_config(cls, *args, **kwargs):
178 | requires_backends(cls, ["flax"])
179 |
180 | @classmethod
181 | def from_pretrained(cls, *args, **kwargs):
182 | requires_backends(cls, ["flax"])
183 |
184 |
185 | class FlaxScoreSdeVeScheduler(metaclass=DummyObject):
186 | _backends = ["flax"]
187 |
188 | def __init__(self, *args, **kwargs):
189 | requires_backends(self, ["flax"])
190 |
191 | @classmethod
192 | def from_config(cls, *args, **kwargs):
193 | requires_backends(cls, ["flax"])
194 |
195 | @classmethod
196 | def from_pretrained(cls, *args, **kwargs):
197 | requires_backends(cls, ["flax"])
198 |
--------------------------------------------------------------------------------
/diffusers_official/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.18.2"
2 |
3 | from .configuration_utils import ConfigMixin
4 | from .utils import (
5 | OptionalDependencyNotAvailable,
6 | is_flax_available,
7 | is_inflect_available,
8 | is_invisible_watermark_available,
9 | is_k_diffusion_available,
10 | is_k_diffusion_version,
11 | is_librosa_available,
12 | is_note_seq_available,
13 | is_onnx_available,
14 | is_scipy_available,
15 | is_torch_available,
16 | is_torchsde_available,
17 | is_transformers_available,
18 | is_transformers_version,
19 | is_unidecode_available,
20 | logging,
21 | )
22 |
23 |
24 | try:
25 | if not is_onnx_available():
26 | raise OptionalDependencyNotAvailable()
27 | except OptionalDependencyNotAvailable:
28 | from .utils.dummy_onnx_objects import * # noqa F403
29 | else:
30 | from .pipelines import OnnxRuntimeModel
31 |
32 | try:
33 | if not is_torch_available():
34 | raise OptionalDependencyNotAvailable()
35 | except OptionalDependencyNotAvailable:
36 | from .utils.dummy_pt_objects import * # noqa F403
37 | else:
38 | from .models import (
39 | AutoencoderKL,
40 | ControlNetModel,
41 | ModelMixin,
42 | PriorTransformer,
43 | T5FilmDecoder,
44 | Transformer2DModel,
45 | UNet1DModel,
46 | UNet2DConditionModel,
47 | UNet2DModel,
48 | UNet3DConditionModel,
49 | VQModel,
50 | )
51 | from .optimization import (
52 | get_constant_schedule,
53 | get_constant_schedule_with_warmup,
54 | get_cosine_schedule_with_warmup,
55 | get_cosine_with_hard_restarts_schedule_with_warmup,
56 | get_linear_schedule_with_warmup,
57 | get_polynomial_decay_schedule_with_warmup,
58 | get_scheduler,
59 | )
60 | from .schedulers import (
61 | CMStochasticIterativeScheduler,
62 | DDIMInverseScheduler,
63 | DDIMParallelScheduler,
64 | DDIMScheduler,
65 | DDPMParallelScheduler,
66 | DDPMScheduler,
67 | DEISMultistepScheduler,
68 | DPMSolverMultistepInverseScheduler,
69 | DPMSolverMultistepScheduler,
70 | DPMSolverSinglestepScheduler,
71 | EulerAncestralDiscreteScheduler,
72 | EulerDiscreteScheduler,
73 | HeunDiscreteScheduler,
74 | IPNDMScheduler,
75 | KarrasVeScheduler,
76 | KDPM2AncestralDiscreteScheduler,
77 | KDPM2DiscreteScheduler,
78 | PNDMScheduler,
79 | RePaintScheduler,
80 | SchedulerMixin,
81 | ScoreSdeVeScheduler,
82 | UnCLIPScheduler,
83 | UniPCMultistepScheduler,
84 | VQDiffusionScheduler,
85 | )
86 | from .training_utils import EMAModel
87 |
88 | try:
89 | if not (is_torch_available() and is_scipy_available()):
90 | raise OptionalDependencyNotAvailable()
91 | except OptionalDependencyNotAvailable:
92 | from .utils.dummy_torch_and_scipy_objects import * # noqa F403
93 | else:
94 | from .schedulers import LMSDiscreteScheduler
95 |
96 | try:
97 | if not (is_torch_available() and is_torchsde_available()):
98 | raise OptionalDependencyNotAvailable()
99 | except OptionalDependencyNotAvailable:
100 | from .utils.dummy_torch_and_torchsde_objects import * # noqa F403
101 | else:
102 | from .schedulers import DPMSolverSDEScheduler
103 |
104 | try:
105 | if not (is_torch_available() and is_transformers_available()):
106 | raise OptionalDependencyNotAvailable()
107 | except OptionalDependencyNotAvailable:
108 | from .utils.dummy_torch_and_transformers_objects import * # noqa F403
109 |
110 | try:
111 | if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()):
112 | raise OptionalDependencyNotAvailable()
113 | except OptionalDependencyNotAvailable:
114 | from .utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403
115 | else:
116 | from .pipelines import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline
117 |
118 | try:
119 | if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
120 | raise OptionalDependencyNotAvailable()
121 | except OptionalDependencyNotAvailable:
122 | from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
123 |
124 | try:
125 | if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
126 | raise OptionalDependencyNotAvailable()
127 | except OptionalDependencyNotAvailable:
128 | from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
129 |
130 | try:
131 | if not (is_torch_available() and is_librosa_available()):
132 | raise OptionalDependencyNotAvailable()
133 | except OptionalDependencyNotAvailable:
134 | from .utils.dummy_torch_and_librosa_objects import * # noqa F403
135 |
136 | try:
137 | if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
138 | raise OptionalDependencyNotAvailable()
139 | except OptionalDependencyNotAvailable:
140 | from .utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
141 |
142 | try:
143 | if not is_flax_available():
144 | raise OptionalDependencyNotAvailable()
145 | except OptionalDependencyNotAvailable:
146 | from .utils.dummy_flax_objects import * # noqa F403
147 | else:
148 | from .models.controlnet_flax import FlaxControlNetModel
149 | from .models.modeling_flax_utils import FlaxModelMixin
150 | from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
151 | from .models.vae_flax import FlaxAutoencoderKL
152 | from .pipelines import FlaxDiffusionPipeline
153 | from .schedulers import (
154 | FlaxDDIMScheduler,
155 | FlaxDDPMScheduler,
156 | FlaxDPMSolverMultistepScheduler,
157 | FlaxKarrasVeScheduler,
158 | FlaxLMSDiscreteScheduler,
159 | FlaxPNDMScheduler,
160 | FlaxSchedulerMixin,
161 | FlaxScoreSdeVeScheduler,
162 | )
163 |
164 |
165 | try:
166 | if not (is_flax_available() and is_transformers_available()):
167 | raise OptionalDependencyNotAvailable()
168 | except OptionalDependencyNotAvailable:
169 | from .utils.dummy_flax_and_transformers_objects import * # noqa F403
170 |
171 | try:
172 | if not (is_note_seq_available()):
173 | raise OptionalDependencyNotAvailable()
174 | except OptionalDependencyNotAvailable:
175 | from .utils.dummy_note_seq_objects import * # noqa F403
176 |
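177 |
178 | # Editor's note (illustrative): the try/except blocks above make imports
179 | # degrade gracefully — e.g. `from diffusers_official import DDIMScheduler`
180 | # yields the real class when torch is installed, and otherwise a DummyObject
181 | # that raises a helpful error only when it is actually used.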
--------------------------------------------------------------------------------
/diffusers_official/schedulers/scheduling_ipndm.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Zhejiang University Team and The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import math
16 | from typing import List, Optional, Tuple, Union
17 |
18 | import numpy as np
19 | import torch
20 |
21 | from ..configuration_utils import ConfigMixin, register_to_config
22 | from .scheduling_utils import SchedulerMixin, SchedulerOutput
23 |
24 |
25 | class IPNDMScheduler(SchedulerMixin, ConfigMixin):
26 | """
27 | Improved Pseudo numerical methods for diffusion models (iPNDM) ported from @crowsonkb's amazing k-diffusion
28 | [library](https://github.com/crowsonkb/v-diffusion-pytorch/blob/987f8985e38208345c1959b0ea767a625831cc9b/diffusion/sampling.py#L296)
29 |
30 | [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
31 | function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
32 |     [`SchedulerMixin`] provides general loading and saving functionality via the [`~SchedulerMixin.save_pretrained`] and
33 | [`~SchedulerMixin.from_pretrained`] functions.
34 |
35 | For more details, see the original paper: https://arxiv.org/abs/2202.09778
36 |
37 | Args:
38 | num_train_timesteps (`int`): number of diffusion steps used to train the model.
39 | """
40 |
41 | order = 1
42 |
43 | @register_to_config
44 | def __init__(
45 | self, num_train_timesteps: int = 1000, trained_betas: Optional[Union[np.ndarray, List[float]]] = None
46 | ):
47 | # set `betas`, `alphas`, `timesteps`
48 | self.set_timesteps(num_train_timesteps)
49 |
50 | # standard deviation of the initial noise distribution
51 | self.init_noise_sigma = 1.0
52 |
53 | # For now we only support F-PNDM, i.e. the runge-kutta method
54 | # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
55 | # mainly at formula (9), (12), (13) and the Algorithm 2.
56 | self.pndm_order = 4
57 |
58 | # running values
59 | self.ets = []
60 |
61 | def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
62 | """
63 | Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
64 |
65 | Args:
66 | num_inference_steps (`int`):
67 | the number of diffusion steps used when generating samples with a pre-trained model.
68 | """
69 | self.num_inference_steps = num_inference_steps
70 | steps = torch.linspace(1, 0, num_inference_steps + 1)[:-1]
71 | steps = torch.cat([steps, torch.tensor([0.0])])
72 |
73 | if self.config.trained_betas is not None:
74 | self.betas = torch.tensor(self.config.trained_betas, dtype=torch.float32)
75 | else:
76 | self.betas = torch.sin(steps * math.pi / 2) ** 2
77 |
78 | self.alphas = (1.0 - self.betas**2) ** 0.5
79 |
80 | timesteps = (torch.atan2(self.betas, self.alphas) / math.pi * 2)[:-1]
81 | self.timesteps = timesteps.to(device)
82 |
83 | self.ets = []
84 |
85 | def step(
86 | self,
87 | model_output: torch.FloatTensor,
88 | timestep: int,
89 | sample: torch.FloatTensor,
90 | return_dict: bool = True,
91 | ) -> Union[SchedulerOutput, Tuple]:
92 | """
93 |         Propagates the sample one step with the linear multistep method, reusing up to four previous model
94 |         outputs (`self.ets`) to approximate the solution.
95 |
96 | Args:
97 | model_output (`torch.FloatTensor`): direct output from learned diffusion model.
98 | timestep (`int`): current discrete timestep in the diffusion chain.
99 | sample (`torch.FloatTensor`):
100 | current instance of sample being created by diffusion process.
101 | return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
102 |
103 | Returns:
104 | [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
105 | True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
106 |
107 | """
108 | if self.num_inference_steps is None:
109 | raise ValueError(
110 | "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
111 | )
112 |
113 | timestep_index = (self.timesteps == timestep).nonzero().item()
114 | prev_timestep_index = timestep_index + 1
115 |
116 | ets = sample * self.betas[timestep_index] + model_output * self.alphas[timestep_index]
117 | self.ets.append(ets)
118 |
119 | if len(self.ets) == 1:
120 | ets = self.ets[-1]
121 | elif len(self.ets) == 2:
122 | ets = (3 * self.ets[-1] - self.ets[-2]) / 2
123 | elif len(self.ets) == 3:
124 | ets = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12
125 | else:
126 | ets = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
127 |
128 | prev_sample = self._get_prev_sample(sample, timestep_index, prev_timestep_index, ets)
129 |
130 | if not return_dict:
131 | return (prev_sample,)
132 |
133 | return SchedulerOutput(prev_sample=prev_sample)
134 |
135 | def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
136 | """
137 | Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
138 | current timestep.
139 |
140 | Args:
141 | sample (`torch.FloatTensor`): input sample
142 |
143 | Returns:
144 | `torch.FloatTensor`: scaled input sample
145 | """
146 | return sample
147 |
148 | def _get_prev_sample(self, sample, timestep_index, prev_timestep_index, ets):
149 | alpha = self.alphas[timestep_index]
150 | sigma = self.betas[timestep_index]
151 |
152 | next_alpha = self.alphas[prev_timestep_index]
153 | next_sigma = self.betas[prev_timestep_index]
154 |
155 | pred = (sample - sigma * ets) / max(alpha, 1e-8)
156 | prev_sample = next_alpha * pred + ets * next_sigma
157 |
158 | return prev_sample
159 |
160 | def __len__(self):
161 | return self.config.num_train_timesteps
162 |
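163 |
164 | # Editor's note: a minimal denoising-loop sketch (not part of the original
165 | # module). The zero tensor stands in for a trained UNet's output; shapes and
166 | # step count are illustrative.
167 | if __name__ == "__main__":
168 |     scheduler = IPNDMScheduler()
169 |     scheduler.set_timesteps(10)
170 |     sample = torch.randn(1, 3, 8, 8)
171 |     for t in scheduler.timesteps:
172 |         model_output = torch.zeros_like(sample)  # stand-in for a model call
173 |         sample = scheduler.step(model_output, t, sample).prev_sample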
--------------------------------------------------------------------------------
/diffusers_official/models/vq_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from dataclasses import dataclass
15 | from typing import Optional, Tuple, Union
16 |
17 | import torch
18 | import torch.nn as nn
19 |
20 | from ..configuration_utils import ConfigMixin, register_to_config
21 | from ..utils import BaseOutput, apply_forward_hook
22 | from .modeling_utils import ModelMixin
23 | from .vae import Decoder, DecoderOutput, Encoder, VectorQuantizer
24 |
25 |
26 | @dataclass
27 | class VQEncoderOutput(BaseOutput):
28 | """
29 | Output of VQModel encoding method.
30 |
31 | Args:
32 | latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
33 | The encoded output sample from the last layer of the model.
34 | """
35 |
36 | latents: torch.FloatTensor
37 |
38 |
39 | class VQModel(ModelMixin, ConfigMixin):
40 | r"""
41 | A VQ-VAE model for decoding latent representations.
42 |
43 |     This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
44 | for all models (such as downloading or saving).
45 |
46 | Parameters:
47 | in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
48 | out_channels (int, *optional*, defaults to 3): Number of channels in the output.
49 | down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
50 | Tuple of downsample block types.
51 | up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
52 | Tuple of upsample block types.
53 | block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
54 | Tuple of block output channels.
55 | act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
56 | latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
57 | sample_size (`int`, *optional*, defaults to `32`): Sample input size.
58 | num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
59 | vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE.
60 | scaling_factor (`float`, *optional*, defaults to `0.18215`):
61 | The component-wise standard deviation of the trained latent space computed using the first batch of the
62 | training set. This is used to scale the latent space to have unit variance when training the diffusion
63 | model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
64 | diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
65 | / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
66 | Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
67 | """
68 |
69 | @register_to_config
70 | def __init__(
71 | self,
72 | in_channels: int = 3,
73 | out_channels: int = 3,
74 | down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
75 | up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
76 | block_out_channels: Tuple[int] = (64,),
77 | layers_per_block: int = 1,
78 | act_fn: str = "silu",
79 | latent_channels: int = 3,
80 | sample_size: int = 32,
81 | num_vq_embeddings: int = 256,
82 | norm_num_groups: int = 32,
83 | vq_embed_dim: Optional[int] = None,
84 | scaling_factor: float = 0.18215,
85 | norm_type: str = "group", # group, spatial
86 | ):
87 | super().__init__()
88 |
89 | # pass init params to Encoder
90 | self.encoder = Encoder(
91 | in_channels=in_channels,
92 | out_channels=latent_channels,
93 | down_block_types=down_block_types,
94 | block_out_channels=block_out_channels,
95 | layers_per_block=layers_per_block,
96 | act_fn=act_fn,
97 | norm_num_groups=norm_num_groups,
98 | double_z=False,
99 | )
100 |
101 | vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels
102 |
103 | self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1)
104 | self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False)
105 | self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1)
106 |
107 | # pass init params to Decoder
108 | self.decoder = Decoder(
109 | in_channels=latent_channels,
110 | out_channels=out_channels,
111 | up_block_types=up_block_types,
112 | block_out_channels=block_out_channels,
113 | layers_per_block=layers_per_block,
114 | act_fn=act_fn,
115 | norm_num_groups=norm_num_groups,
116 | norm_type=norm_type,
117 | )
118 |
119 | @apply_forward_hook
120 | def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
121 | h = self.encoder(x)
122 | h = self.quant_conv(h)
123 |
124 | if not return_dict:
125 | return (h,)
126 |
127 | return VQEncoderOutput(latents=h)
128 |
129 | @apply_forward_hook
130 | def decode(
131 | self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True
132 | ) -> Union[DecoderOutput, torch.FloatTensor]:
133 | # also go through quantization layer
134 | if not force_not_quantize:
135 | quant, emb_loss, info = self.quantize(h)
136 | else:
137 | quant = h
138 | quant2 = self.post_quant_conv(quant)
139 | dec = self.decoder(quant2, quant if self.config.norm_type == "spatial" else None)
140 |
141 | if not return_dict:
142 | return (dec,)
143 |
144 | return DecoderOutput(sample=dec)
145 |
146 | def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
147 | r"""
148 | The [`VQModel`] forward method.
149 |
150 | Args:
151 | sample (`torch.FloatTensor`): Input sample.
152 | return_dict (`bool`, *optional*, defaults to `True`):
153 |                 Whether or not to return a [`models.vae.DecoderOutput`] instead of a plain tuple.
154 |
155 | Returns:
156 |             [`~models.vae.DecoderOutput`] or `tuple`:
157 |                 If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple`
158 |                 is returned.
159 | """
160 | x = sample
161 | h = self.encode(x).latents
162 | dec = self.decode(h).sample
163 |
164 | if not return_dict:
165 | return (dec,)
166 |
167 | return DecoderOutput(sample=dec)
168 |
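169 |
170 | # Editor's note: a minimal round-trip sketch using the default (tiny) config
171 | # (not part of the original module; the 32x32 input is an illustrative size).
172 | if __name__ == "__main__":
173 |     model = VQModel()
174 |     x = torch.randn(1, 3, 32, 32)
175 |     with torch.no_grad():
176 |         recon = model(x).sample  # encode -> quantize -> decode
177 |     assert recon.shape == x.shape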
--------------------------------------------------------------------------------
/diffusers_official/models/modeling_pytorch_flax_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ PyTorch - Flax general utilities."""
16 |
17 | from pickle import UnpicklingError
18 |
19 | import jax
20 | import jax.numpy as jnp
21 | import numpy as np
22 | from flax.serialization import from_bytes
23 | from flax.traverse_util import flatten_dict
24 |
25 | from ..utils import logging
26 |
27 |
28 | logger = logging.get_logger(__name__)
29 |
30 |
31 | #####################
32 | # Flax => PyTorch #
33 | #####################
34 |
35 |
36 | # from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flax_pytorch_utils.py#L224-L352
37 | def load_flax_checkpoint_in_pytorch_model(pt_model, model_file):
38 | try:
39 | with open(model_file, "rb") as flax_state_f:
40 | flax_state = from_bytes(None, flax_state_f.read())
41 | except UnpicklingError as e:
42 | try:
43 | with open(model_file) as f:
44 | if f.read().startswith("version"):
45 | raise OSError(
46 | "You seem to have cloned a repository without having git-lfs installed. Please"
47 | " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
48 | " folder you cloned."
49 | )
50 | else:
51 | raise ValueError from e
52 | except (UnicodeDecodeError, ValueError):
53 | raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
54 |
55 | return load_flax_weights_in_pytorch_model(pt_model, flax_state)
56 |
57 |
58 | def load_flax_weights_in_pytorch_model(pt_model, flax_state):
59 | """Load flax checkpoints in a PyTorch model"""
60 |
61 | try:
62 | import torch # noqa: F401
63 | except ImportError:
64 | logger.error(
65 | "Loading Flax weights in PyTorch requires both PyTorch and Flax to be installed. Please see"
66 | " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
67 | " instructions."
68 | )
69 | raise
70 |
71 | # check if we have bf16 weights
72 | is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values()
73 | if any(is_type_bf16):
74 |         # convert all weights to fp32 if they are bf16, since torch.from_numpy
75 |         # cannot handle bf16 and bf16 is not yet fully supported in PyTorch
76 |
77 | logger.warning(
78 | "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` "
79 | "before loading those in PyTorch model."
80 | )
81 | flax_state = jax.tree_util.tree_map(
82 | lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state
83 | )
84 |
85 | pt_model.base_model_prefix = ""
86 |
87 | flax_state_dict = flatten_dict(flax_state, sep=".")
88 | pt_model_dict = pt_model.state_dict()
89 |
90 | # keep track of unexpected & missing keys
91 | unexpected_keys = []
92 | missing_keys = set(pt_model_dict.keys())
93 |
94 | for flax_key_tuple, flax_tensor in flax_state_dict.items():
95 | flax_key_tuple_array = flax_key_tuple.split(".")
96 |
97 | if flax_key_tuple_array[-1] == "kernel" and flax_tensor.ndim == 4:
98 | flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
99 | flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1))
100 | elif flax_key_tuple_array[-1] == "kernel":
101 | flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
102 | flax_tensor = flax_tensor.T
103 | elif flax_key_tuple_array[-1] == "scale":
104 | flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
105 |
106 | if "time_embedding" not in flax_key_tuple_array:
107 | for i, flax_key_tuple_string in enumerate(flax_key_tuple_array):
108 | flax_key_tuple_array[i] = (
109 | flax_key_tuple_string.replace("_0", ".0")
110 | .replace("_1", ".1")
111 | .replace("_2", ".2")
112 | .replace("_3", ".3")
113 | .replace("_4", ".4")
114 | .replace("_5", ".5")
115 | .replace("_6", ".6")
116 | .replace("_7", ".7")
117 | .replace("_8", ".8")
118 | .replace("_9", ".9")
119 | )
120 |
121 | flax_key = ".".join(flax_key_tuple_array)
122 |
123 | if flax_key in pt_model_dict:
124 | if flax_tensor.shape != pt_model_dict[flax_key].shape:
125 | raise ValueError(
126 | f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected "
127 | f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}."
128 | )
129 | else:
130 | # add weight to pytorch dict
131 | flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor
132 | pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)
133 | # remove from missing keys
134 | missing_keys.remove(flax_key)
135 | else:
136 | # weight is not expected by PyTorch model
137 | unexpected_keys.append(flax_key)
138 |
139 | pt_model.load_state_dict(pt_model_dict)
140 |
141 | # re-transform missing_keys to list
142 | missing_keys = list(missing_keys)
143 |
144 | if len(unexpected_keys) > 0:
145 | logger.warning(
146 | "Some weights of the Flax model were not used when initializing the PyTorch model"
147 | f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
148 | f" {pt_model.__class__.__name__} from a Flax model trained on another task or with another architecture"
149 | " (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This"
150 | f" IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect"
151 | " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a"
152 | " FlaxBertForSequenceClassification model)."
153 | )
154 | if len(missing_keys) > 0:
155 | logger.warning(
156 | f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly"
157 | f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to"
158 | " use it for predictions and inference."
159 | )
160 |
161 | return pt_model
162 |
--------------------------------------------------------------------------------
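A minimal sketch of the weight-layout conversions that `load_flax_weights_in_pytorch_model` applies above, using standalone NumPy arrays (all shapes here are made up for illustration):

import numpy as np

# Flax stores conv "kernel" weights as (H, W, in_channels, out_channels);
# PyTorch expects (out_channels, in_channels, H, W), hence the (3, 2, 0, 1) transpose.
flax_conv_kernel = np.zeros((3, 3, 16, 32), dtype=np.float32)
pt_conv_weight = np.transpose(flax_conv_kernel, (3, 2, 0, 1))
assert pt_conv_weight.shape == (32, 16, 3, 3)

# Dense "kernel" weights are plain 2D and are simply transposed:
# (in_features, out_features) -> (out_features, in_features).
flax_dense_kernel = np.zeros((128, 256), dtype=np.float32)
assert flax_dense_kernel.T.shape == (256, 128)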
/diffusers_official/models/dual_transformer_2d.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import Optional
15 |
16 | from torch import nn
17 |
18 | from .transformer_2d import Transformer2DModel, Transformer2DModelOutput
19 |
20 |
21 | class DualTransformer2DModel(nn.Module):
22 | """
23 | Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
24 |
25 | Parameters:
26 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
27 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
28 | in_channels (`int`, *optional*):
29 | Pass if the input is continuous. The number of channels in the input and output.
30 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
31 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
32 | cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
33 | sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
34 | Note that this is fixed at training time as it is used for learning a number of position embeddings. See
35 | `ImagePositionalEmbeddings`.
36 | num_vector_embeds (`int`, *optional*):
37 | Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
38 | Includes the class for the masked latent pixel.
39 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
40 | num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
41 | The number of diffusion steps used during training. Note that this is fixed at training time as it is used
42 | to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
43 | up to, but not more than, `num_embeds_ada_norm` steps.
44 | attention_bias (`bool`, *optional*):
45 | Configure if the TransformerBlocks' attention should contain a bias parameter.
46 | """
47 |
48 | def __init__(
49 | self,
50 | num_attention_heads: int = 16,
51 | attention_head_dim: int = 88,
52 | in_channels: Optional[int] = None,
53 | num_layers: int = 1,
54 | dropout: float = 0.0,
55 | norm_num_groups: int = 32,
56 | cross_attention_dim: Optional[int] = None,
57 | attention_bias: bool = False,
58 | sample_size: Optional[int] = None,
59 | num_vector_embeds: Optional[int] = None,
60 | activation_fn: str = "geglu",
61 | num_embeds_ada_norm: Optional[int] = None,
62 | ):
63 | super().__init__()
64 | self.transformers = nn.ModuleList(
65 | [
66 | Transformer2DModel(
67 | num_attention_heads=num_attention_heads,
68 | attention_head_dim=attention_head_dim,
69 | in_channels=in_channels,
70 | num_layers=num_layers,
71 | dropout=dropout,
72 | norm_num_groups=norm_num_groups,
73 | cross_attention_dim=cross_attention_dim,
74 | attention_bias=attention_bias,
75 | sample_size=sample_size,
76 | num_vector_embeds=num_vector_embeds,
77 | activation_fn=activation_fn,
78 | num_embeds_ada_norm=num_embeds_ada_norm,
79 | )
80 | for _ in range(2)
81 | ]
82 | )
83 |
84 | # Variables that can be set by a pipeline:
85 |
86 | # The ratio of transformer1 to transformer2's output states to be combined during inference
87 | self.mix_ratio = 0.5
88 |
89 | # The shape of `encoder_hidden_states` is expected to be
90 | # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
91 | self.condition_lengths = [77, 257]
92 |
93 | # Which transformer to use to encode which condition.
94 | # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
95 | self.transformer_index_for_condition = [1, 0]
96 |
97 | def forward(
98 | self,
99 | hidden_states,
100 | encoder_hidden_states,
101 | timestep=None,
102 | attention_mask=None,
103 | cross_attention_kwargs=None,
104 | return_dict: bool = True,
105 | ):
106 | """
107 | Args:
108 | hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete,
109 | `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
110 | Input hidden_states.
111 | encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
112 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
113 | self-attention.
114 | timestep ( `torch.long`, *optional*):
115 | Optional timestep to be applied as an embedding in `AdaLayerNorm` layers. Used to indicate the denoising step.
116 | attention_mask (`torch.FloatTensor`, *optional*):
117 | Optional attention mask to be applied in Attention
118 | return_dict (`bool`, *optional*, defaults to `True`):
119 | Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain tuple.
120 |
121 | Returns:
122 | [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
123 | [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
124 | returning a tuple, the first element is the sample tensor.
125 | """
126 | input_states = hidden_states
127 |
128 | encoded_states = []
129 | tokens_start = 0
130 | # attention_mask is not used yet
131 | for i in range(2):
132 | # for each of the two transformers, pass the corresponding condition tokens
133 | condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
134 | transformer_index = self.transformer_index_for_condition[i]
135 | encoded_state = self.transformers[transformer_index](
136 | input_states,
137 | encoder_hidden_states=condition_state,
138 | timestep=timestep,
139 | cross_attention_kwargs=cross_attention_kwargs,
140 | return_dict=False,
141 | )[0]
142 | encoded_states.append(encoded_state - input_states)
143 | tokens_start += self.condition_lengths[i]
144 |
145 | output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
146 | output_states = output_states + input_states
147 |
148 | if not return_dict:
149 | return (output_states,)
150 |
151 | return Transformer2DModelOutput(sample=output_states)
152 |
--------------------------------------------------------------------------------
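A shape-only sketch of the residual mixing in `DualTransformer2DModel.forward`, with random tensors standing in for the two `Transformer2DModel` outputs (shapes are illustrative):

import torch

input_states = torch.randn(2, 4, 8, 8)
encoded_0 = torch.randn(2, 4, 8, 8)  # stands in for transformers[...](input_states)[0]
encoded_1 = torch.randn(2, 4, 8, 8)

# Each branch contributes its residual (output - input); mix_ratio interpolates
# between the two residuals, and the input is added back at the end.
mix_ratio = 0.5
output = (encoded_0 - input_states) * mix_ratio + (encoded_1 - input_states) * (1 - mix_ratio) + input_states
assert output.shape == input_states.shape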
/scripts/models/dual_transformer_2d.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import Optional
15 |
16 | from torch import nn
17 |
18 | from scripts.models.transformer_2d import Transformer2DModel, Transformer2DModelOutput
19 |
20 |
21 | class DualTransformer2DModel(nn.Module):
22 | """
23 | Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
24 |
25 | Parameters:
26 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
27 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
28 | in_channels (`int`, *optional*):
29 | Pass if the input is continuous. The number of channels in the input and output.
30 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
31 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
32 | cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
33 | sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
34 | Note that this is fixed at training time as it is used for learning a number of position embeddings. See
35 | `ImagePositionalEmbeddings`.
36 | num_vector_embeds (`int`, *optional*):
37 | Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
38 | Includes the class for the masked latent pixel.
39 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
40 | num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
41 | The number of diffusion steps used during training. Note that this is fixed at training time as it is used
42 | to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
43 | up to, but not more than, `num_embeds_ada_norm` steps.
44 | attention_bias (`bool`, *optional*):
45 | Configure if the TransformerBlocks' attention should contain a bias parameter.
46 | """
47 |
48 | def __init__(
49 | self,
50 | num_attention_heads: int = 16,
51 | attention_head_dim: int = 88,
52 | in_channels: Optional[int] = None,
53 | num_layers: int = 1,
54 | dropout: float = 0.0,
55 | norm_num_groups: int = 32,
56 | cross_attention_dim: Optional[int] = None,
57 | attention_bias: bool = False,
58 | sample_size: Optional[int] = None,
59 | num_vector_embeds: Optional[int] = None,
60 | activation_fn: str = "geglu",
61 | num_embeds_ada_norm: Optional[int] = None,
62 | ):
63 | super().__init__()
64 | self.transformers = nn.ModuleList(
65 | [
66 | Transformer2DModel(
67 | num_attention_heads=num_attention_heads,
68 | attention_head_dim=attention_head_dim,
69 | in_channels=in_channels,
70 | num_layers=num_layers,
71 | dropout=dropout,
72 | norm_num_groups=norm_num_groups,
73 | cross_attention_dim=cross_attention_dim,
74 | attention_bias=attention_bias,
75 | sample_size=sample_size,
76 | num_vector_embeds=num_vector_embeds,
77 | activation_fn=activation_fn,
78 | num_embeds_ada_norm=num_embeds_ada_norm,
79 | )
80 | for _ in range(2)
81 | ]
82 | )
83 |
84 | # Variables that can be set by a pipeline:
85 |
86 | # The ratio of transformer1 to transformer2's output states to be combined during inference
87 | self.mix_ratio = 0.5
88 |
89 | # The shape of `encoder_hidden_states` is expected to be
90 | # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
91 | self.condition_lengths = [77, 257]
92 |
93 | # Which transformer to use to encode which condition.
94 | # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
95 | self.transformer_index_for_condition = [1, 0]
96 |
97 | def forward(
98 | self,
99 | hidden_states,
100 | encoder_hidden_states,
101 | timestep=None,
102 | attention_mask=None,
103 | cross_attention_kwargs=None,
104 | return_dict: bool = True,
105 | ):
106 | """
107 | Args:
108 | hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete,
109 | `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
110 | Input hidden_states.
111 | encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
112 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
113 | self-attention.
114 | timestep ( `torch.long`, *optional*):
115 | Optional timestep to be applied as an embedding in `AdaLayerNorm` layers. Used to indicate the denoising step.
116 | attention_mask (`torch.FloatTensor`, *optional*):
117 | Optional attention mask to be applied in Attention
118 | return_dict (`bool`, *optional*, defaults to `True`):
119 | Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain tuple.
120 |
121 | Returns:
122 | [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
123 | [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
124 | returning a tuple, the first element is the sample tensor.
125 | """
126 | input_states = hidden_states
127 |
128 | encoded_states = []
129 | tokens_start = 0
130 | # attention_mask is not used yet
131 | for i in range(2):
132 | # for each of the two transformers, pass the corresponding condition tokens
133 | condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
134 | transformer_index = self.transformer_index_for_condition[i]
135 | encoded_state = self.transformers[transformer_index](
136 | input_states,
137 | encoder_hidden_states=condition_state,
138 | timestep=timestep,
139 | cross_attention_kwargs=cross_attention_kwargs,
140 | return_dict=False,
141 | )[0]
142 | encoded_states.append(encoded_state - input_states)
143 | tokens_start += self.condition_lengths[i]
144 |
145 | output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
146 | output_states = output_states + input_states
147 |
148 | if not return_dict:
149 | return (output_states,)
150 |
151 | return Transformer2DModelOutput(sample=output_states)
152 |
--------------------------------------------------------------------------------
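This copy differs from the one above only in its import path. A small sketch of how `forward` slices the concatenated conditions by `condition_lengths` (77 CLIP text tokens followed by 257 image-embedding tokens with the defaults; the hidden size below is illustrative):

import torch

encoder_hidden_states = torch.randn(2, 77 + 257, 768)
condition_lengths = [77, 257]

tokens_start = 0
for length in condition_lengths:
    condition_state = encoder_hidden_states[:, tokens_start:tokens_start + length]
    print(condition_state.shape)  # torch.Size([2, 77, 768]) then torch.Size([2, 257, 768])
    tokens_start += length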
/diffusers_official/models/transformer_temporal.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from dataclasses import dataclass
15 | from typing import Optional
16 |
17 | import torch
18 | from torch import nn
19 |
20 | from ..configuration_utils import ConfigMixin, register_to_config
21 | from ..utils import BaseOutput
22 | from .attention import BasicTransformerBlock
23 | from .modeling_utils import ModelMixin
24 |
25 |
26 | @dataclass
27 | class TransformerTemporalModelOutput(BaseOutput):
28 | """
29 | The output of [`TransformerTemporalModel`].
30 |
31 | Args:
32 | sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
33 | The hidden states output conditioned on `encoder_hidden_states` input.
34 | """
35 |
36 | sample: torch.FloatTensor
37 |
38 |
39 | class TransformerTemporalModel(ModelMixin, ConfigMixin):
40 | """
41 | A Transformer model for video-like data.
42 |
43 | Parameters:
44 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
45 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
46 | in_channels (`int`, *optional*):
47 | The number of channels in the input and output (specify if the input is **continuous**).
48 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
49 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
50 | cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
51 | sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
52 | This is fixed during training since it is used to learn a number of position embeddings.
53 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
54 | attention_bias (`bool`, *optional*):
55 | Configure if the `TransformerBlock` attention should contain a bias parameter.
56 | double_self_attention (`bool`, *optional*):
57 | Configure if each `TransformerBlock` should contain two self-attention layers.
58 | """
59 |
60 | @register_to_config
61 | def __init__(
62 | self,
63 | num_attention_heads: int = 16,
64 | attention_head_dim: int = 88,
65 | in_channels: Optional[int] = None,
66 | out_channels: Optional[int] = None,
67 | num_layers: int = 1,
68 | dropout: float = 0.0,
69 | norm_num_groups: int = 32,
70 | cross_attention_dim: Optional[int] = None,
71 | attention_bias: bool = False,
72 | sample_size: Optional[int] = None,
73 | activation_fn: str = "geglu",
74 | norm_elementwise_affine: bool = True,
75 | double_self_attention: bool = True,
76 | ):
77 | super().__init__()
78 | self.num_attention_heads = num_attention_heads
79 | self.attention_head_dim = attention_head_dim
80 | inner_dim = num_attention_heads * attention_head_dim
81 |
82 | self.in_channels = in_channels
83 |
84 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
85 | self.proj_in = nn.Linear(in_channels, inner_dim)
86 |
87 | # 3. Define transformers blocks
88 | self.transformer_blocks = nn.ModuleList(
89 | [
90 | BasicTransformerBlock(
91 | inner_dim,
92 | num_attention_heads,
93 | attention_head_dim,
94 | dropout=dropout,
95 | cross_attention_dim=cross_attention_dim,
96 | activation_fn=activation_fn,
97 | attention_bias=attention_bias,
98 | double_self_attention=double_self_attention,
99 | norm_elementwise_affine=norm_elementwise_affine,
100 | )
101 | for d in range(num_layers)
102 | ]
103 | )
104 |
105 | self.proj_out = nn.Linear(inner_dim, in_channels)
106 |
107 | def forward(
108 | self,
109 | hidden_states,
110 | encoder_hidden_states=None,
111 | timestep=None,
112 | class_labels=None,
113 | num_frames=1,
114 | cross_attention_kwargs=None,
115 | return_dict: bool = True,
116 | ):
117 | """
118 | The [`TransformerTemporal`] forward method.
119 |
120 | Args:
121 | hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
122 | Input hidden_states.
123 | encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
124 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
125 | self-attention.
126 | timestep ( `torch.long`, *optional*):
127 | Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
128 | class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
129 | Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
130 | `AdaLayerNormZero`.
131 | return_dict (`bool`, *optional*, defaults to `True`):
132 | Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
133 | tuple.
134 |
135 | Returns:
136 | [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
137 | If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
138 | returned, otherwise a `tuple` where the first element is the sample tensor.
139 | """
140 | # 1. Input
141 | batch_frames, channel, height, width = hidden_states.shape
142 | batch_size = batch_frames // num_frames
143 |
144 | residual = hidden_states
145 |
146 | hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
147 | hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
148 |
149 | hidden_states = self.norm(hidden_states)
150 | hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
151 |
152 | hidden_states = self.proj_in(hidden_states)
153 |
154 | # 2. Blocks
155 | for block in self.transformer_blocks:
156 | hidden_states = block(
157 | hidden_states,
158 | encoder_hidden_states=encoder_hidden_states,
159 | timestep=timestep,
160 | cross_attention_kwargs=cross_attention_kwargs,
161 | class_labels=class_labels,
162 | )
163 |
164 | # 3. Output
165 | hidden_states = self.proj_out(hidden_states)
166 | hidden_states = (
167 | hidden_states[None, None, :]
168 | .reshape(batch_size, height, width, channel, num_frames)
169 | .permute(0, 3, 4, 1, 2)
170 | .contiguous()
171 | )
172 | hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
173 |
174 | output = hidden_states + residual
175 |
176 | if not return_dict:
177 | return (output,)
178 |
179 | return TransformerTemporalModelOutput(sample=output)
180 |
--------------------------------------------------------------------------------
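A shape-only sketch of the reshaping in `TransformerTemporalModel.forward`: frames are folded out of the batch so that attention runs along the time axis at every spatial location (all sizes are illustrative):

import torch

batch_size, num_frames, channel, height, width = 2, 8, 32, 16, 16
hidden_states = torch.randn(batch_size * num_frames, channel, height, width)

# (b*f, c, h, w) -> (b, f, c, h, w) -> (b, c, f, h, w)
x = hidden_states.reshape(batch_size, num_frames, channel, height, width).permute(0, 2, 1, 3, 4)
# (b, c, f, h, w) -> (b, h, w, f, c) -> (b*h*w, f, c): each spatial location becomes a
# sequence of num_frames tokens, so self-attention mixes information across time only.
tokens = x.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
assert tokens.shape == (batch_size * height * width, num_frames, channel)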
/share_btn.py:
--------------------------------------------------------------------------------
1 | community_icon_html = """"""
5 |
6 | loading_icon_html = """"""
7 |
8 | share_js = """async () => {
9 | async function uploadFile(file){
10 | const UPLOAD_URL = 'https://huggingface.co/uploads';
11 | const response = await fetch(UPLOAD_URL, {
12 | method: 'POST',
13 | headers: {
14 | 'Content-Type': file.type,
15 | 'X-Requested-With': 'XMLHttpRequest',
16 | },
17 | body: file, /// <- File inherits from Blob
18 | });
19 | const url = await response.text();
20 | return url;
21 | }
22 | async function getInputImageFile(imageEl){
23 | const res = await fetch(imageEl.src);
24 | const blob = await res.blob();
25 | const imageId = Date.now();
26 | const fileName = `rich-text-image-${imageId}.png`;
27 | return new File([blob], fileName, { type: 'image/png'});
28 | }
29 | const gradioEl = document.querySelector("gradio-app").shadowRoot || document.querySelector('body > gradio-app');
30 | const richEl = document.getElementById("rich-text-root");
31 | const data = richEl? richEl.contentDocument.body._data : {};
32 | const text_input = JSON.stringify(data);
33 | const negative_prompt = gradioEl.querySelector('#negative_prompt input').value;
34 | const seed = gradioEl.querySelector('#seed input').value;
35 | const richTextImg = gradioEl.querySelector('#rich-text-image img');
36 | const plainTextImg = gradioEl.querySelector('#plain-text-image img');
37 | const text_input_obj = JSON.parse(text_input);
38 | const plain_prompt = text_input_obj.ops.map(e=> e.insert).join('');
39 | const linkSrc = `https://huggingface.co/spaces/songweig/rich-text-to-image?prompt=${encodeURIComponent(text_input)}`;
40 |
41 | const titleTxt = `RT2I: ${plain_prompt.slice(0, 50)}...`;
42 | const shareBtnEl = gradioEl.querySelector('#share-btn');
43 | const shareIconEl = gradioEl.querySelector('#share-btn-share-icon');
44 | const loadingIconEl = gradioEl.querySelector('#share-btn-loading-icon');
45 | if(!richTextImg){
46 | return;
47 | };
48 | shareBtnEl.style.pointerEvents = 'none';
49 | shareIconEl.style.display = 'none';
50 | loadingIconEl.style.removeProperty('display');
51 |
52 | const richImgFile = await getInputImageFile(richTextImg);
53 | const plainImgFile = await getInputImageFile(plainTextImg);
54 | const richImgURL = await uploadFile(richImgFile);
55 | const plainImgURL = await uploadFile(plainImgFile);
56 |
57 | const descriptionMd = `
58 | ### Plain Prompt
59 | ${plain_prompt}
60 |
61 | 🔗 Shareable Link + Params: [here](${linkSrc})
62 |
63 | ### Rich Text Image
64 |
65 | <img src="${richImgURL}">
66 | ### Plain Text Image
67 |
68 | <img src="${plainImgURL}">
69 | `;
70 | const params = new URLSearchParams({
71 | title: titleTxt,
72 | description: descriptionMd,
73 | });
74 | const paramsStr = params.toString();
75 | window.open(`https://huggingface.co/spaces/songweig/rich-text-to-image/discussions/new?${paramsStr}`, '_blank');
76 | shareBtnEl.style.removeProperty('pointer-events');
77 | shareIconEl.style.removeProperty('display');
78 | loadingIconEl.style.display = 'none';
79 | }"""
80 |
81 | css = """
82 | #share-btn-container {
83 | display: flex;
84 | padding-left: 0.5rem !important;
85 | padding-right: 0.5rem !important;
86 | background-color: #000000;
87 | justify-content: center;
88 | align-items: center;
89 | border-radius: 9999px !important;
90 | width: 13rem;
91 | margin-top: 10px;
92 | margin-left: auto;
93 | flex: unset !important;
94 | }
95 | #share-btn {
96 | all: initial;
97 | color: #ffffff;
98 | font-weight: 600;
99 | cursor: pointer;
100 | font-family: 'IBM Plex Sans', sans-serif;
101 | margin-left: 0.5rem !important;
102 | padding-top: 0.25rem !important;
103 | padding-bottom: 0.25rem !important;
104 | right:0;
105 | }
106 | #share-btn * {
107 | all: unset !important;
108 | }
109 | #share-btn-container div:nth-child(-n+2){
110 | width: auto !important;
111 | min-height: 0px !important;
112 | }
113 | #share-btn-container .wrap {
114 | display: none !important;
115 | }
116 | """
117 |
--------------------------------------------------------------------------------
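For reference, a hedged Python sketch of what `uploadFile` in `share_js` does: it POSTs the raw file bytes to the uploads endpoint and reads the hosted URL back from the response body. This assumes the endpoint behaves exactly as the JavaScript above expects; it is not an official client.

import requests

def upload_file(path: str, content_type: str = "image/png") -> str:
    with open(path, "rb") as f:
        response = requests.post(
            "https://huggingface.co/uploads",
            data=f.read(),
            headers={
                "Content-Type": content_type,
                "X-Requested-With": "XMLHttpRequest",
            },
        )
    response.raise_for_status()
    return response.text  # the response body is the uploaded file's URL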
/diffusers_official/schedulers/scheduling_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import importlib
15 | import os
16 | from dataclasses import dataclass
17 | from enum import Enum
18 | from typing import Any, Dict, Optional, Union
19 |
20 | import torch
21 |
22 | from ..utils import BaseOutput
23 |
24 |
25 | SCHEDULER_CONFIG_NAME = "scheduler_config.json"
26 |
27 |
28 | # NOTE: We make this type an enum because it simplifies usage in docs and prevents
29 | # circular imports when used for `_compatibles` within the schedulers module.
30 | # When it's used as a type in pipelines, it really is a Union because the actual
31 | # scheduler instance is passed in.
32 | class KarrasDiffusionSchedulers(Enum):
33 | DDIMScheduler = 1
34 | DDPMScheduler = 2
35 | PNDMScheduler = 3
36 | LMSDiscreteScheduler = 4
37 | EulerDiscreteScheduler = 5
38 | HeunDiscreteScheduler = 6
39 | EulerAncestralDiscreteScheduler = 7
40 | DPMSolverMultistepScheduler = 8
41 | DPMSolverSinglestepScheduler = 9
42 | KDPM2DiscreteScheduler = 10
43 | KDPM2AncestralDiscreteScheduler = 11
44 | DEISMultistepScheduler = 12
45 | UniPCMultistepScheduler = 13
46 | DPMSolverSDEScheduler = 14
47 |
48 |
49 | @dataclass
50 | class SchedulerOutput(BaseOutput):
51 | """
52 | Base class for the scheduler's step function output.
53 |
54 | Args:
55 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
56 | Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
57 | denoising loop.
58 | """
59 |
60 | prev_sample: torch.FloatTensor
61 |
62 |
63 | class SchedulerMixin:
64 | """
65 | Mixin containing common functions for the schedulers.
66 |
67 | Class attributes:
68 | - **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
69 | `from_config` can be used from a class different than the one used to save the config (should be overridden
70 | by parent class).
71 | """
72 |
73 | config_name = SCHEDULER_CONFIG_NAME
74 | _compatibles = []
75 | has_compatibles = True
76 |
77 | @classmethod
78 | def from_pretrained(
79 | cls,
80 | pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
81 | subfolder: Optional[str] = None,
82 | return_unused_kwargs=False,
83 | **kwargs,
84 | ):
85 | r"""
86 | Instantiate a Scheduler class from a pre-defined JSON configuration file inside a directory or Hub repo.
87 |
88 | Parameters:
89 | pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
90 | Can be either:
91 |
92 | - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
93 | organization name, like `google/ddpm-celebahq-256`.
94 | - A path to a *directory* containing the scheduler configurations saved using
95 | [`~SchedulerMixin.save_pretrained`], e.g., `./my_model_directory/`.
96 | subfolder (`str`, *optional*):
97 | In case the relevant files are located inside a subfolder of the model repo (either remote in
98 | huggingface.co or downloaded locally), you can specify the folder name here.
99 | return_unused_kwargs (`bool`, *optional*, defaults to `False`):
100 | Whether kwargs that are not consumed by the Python class should be returned or not.
101 | cache_dir (`Union[str, os.PathLike]`, *optional*):
102 | Path to a directory in which a downloaded pretrained model configuration should be cached if the
103 | standard cache should not be used.
104 | force_download (`bool`, *optional*, defaults to `False`):
105 | Whether or not to force the (re-)download of the model weights and configuration files, overriding the
106 | cached versions if they exist.
107 | resume_download (`bool`, *optional*, defaults to `False`):
108 | Whether or not to delete incompletely received files. Will attempt to resume the download if such a
109 | file exists.
110 | proxies (`Dict[str, str]`, *optional*):
111 | A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
112 | 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
113 | output_loading_info(`bool`, *optional*, defaults to `False`):
114 | Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
115 | local_files_only(`bool`, *optional*, defaults to `False`):
116 | Whether or not to only look at local files (i.e., do not try to download the model).
117 | use_auth_token (`str` or *bool*, *optional*):
118 | The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
119 | when running `transformers-cli login` (stored in `~/.huggingface`).
120 | revision (`str`, *optional*, defaults to `"main"`):
121 | The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
122 | git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
123 | identifier allowed by git.
124 |
125 |
126 | <Tip>
127 |
128 | It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
129 | models](https://huggingface.co/docs/hub/models-gated#gated-models).
130 |
131 | </Tip>
132 |
133 | <Tip>
134 |
135 | Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
136 | use this method in a firewalled environment.
137 |
138 | </Tip>
139 | """
140 | config, kwargs, commit_hash = cls.load_config(
141 | pretrained_model_name_or_path=pretrained_model_name_or_path,
142 | subfolder=subfolder,
143 | return_unused_kwargs=True,
144 | return_commit_hash=True,
145 | **kwargs,
146 | )
147 | return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)
148 |
149 | def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
150 | """
151 | Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the
152 | [`~SchedulerMixin.from_pretrained`] class method.
153 |
154 | Args:
155 | save_directory (`str` or `os.PathLike`):
156 | Directory where the configuration JSON file will be saved (will be created if it does not exist).
157 | """
158 | self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
159 |
160 | @property
161 | def compatibles(self):
162 | """
163 | Returns all schedulers that are compatible with this scheduler
164 |
165 | Returns:
166 | `List[SchedulerMixin]`: List of compatible schedulers
167 | """
168 | return self._get_compatibles()
169 |
170 | @classmethod
171 | def _get_compatibles(cls):
172 | compatible_classes_str = list(set([cls.__name__] + cls._compatibles))
173 | diffusers_library = importlib.import_module(__name__.split(".")[0])
174 | compatible_classes = [
175 | getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
176 | ]
177 | return compatible_classes
178 |
--------------------------------------------------------------------------------
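A short usage sketch for `SchedulerMixin.from_pretrained` / `save_pretrained` (the repo id is illustrative; any pipeline repo with a `scheduler` subfolder works the same way):

from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="scheduler"
)
# `compatibles` lists the scheduler classes that can be built from the same config.
print([cls.__name__ for cls in scheduler.compatibles])
scheduler.save_pretrained("./my_scheduler")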
/diffusers_official/pipelines/onnx_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The HuggingFace Inc. team.
3 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 |
18 | import os
19 | import shutil
20 | from pathlib import Path
21 | from typing import Optional, Union
22 |
23 | import numpy as np
24 | from huggingface_hub import hf_hub_download
25 |
26 | from ..utils import ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, is_onnx_available, logging
27 |
28 |
29 | if is_onnx_available():
30 | import onnxruntime as ort
31 |
32 |
33 | logger = logging.get_logger(__name__)
34 |
35 | ORT_TO_NP_TYPE = {
36 | "tensor(bool)": np.bool_,
37 | "tensor(int8)": np.int8,
38 | "tensor(uint8)": np.uint8,
39 | "tensor(int16)": np.int16,
40 | "tensor(uint16)": np.uint16,
41 | "tensor(int32)": np.int32,
42 | "tensor(uint32)": np.uint32,
43 | "tensor(int64)": np.int64,
44 | "tensor(uint64)": np.uint64,
45 | "tensor(float16)": np.float16,
46 | "tensor(float)": np.float32,
47 | "tensor(double)": np.float64,
48 | }
49 |
50 |
51 | class OnnxRuntimeModel:
52 | def __init__(self, model=None, **kwargs):
53 | logger.info("`diffusers.OnnxRuntimeModel` is experimental and might change in the future.")
54 | self.model = model
55 | self.model_save_dir = kwargs.get("model_save_dir", None)
56 | self.latest_model_name = kwargs.get("latest_model_name", ONNX_WEIGHTS_NAME)
57 |
58 | def __call__(self, **kwargs):
59 | inputs = {k: np.array(v) for k, v in kwargs.items()}
60 | return self.model.run(None, inputs)
61 |
62 | @staticmethod
63 | def load_model(path: Union[str, Path], provider=None, sess_options=None):
64 | """
65 | Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider`
66 |
67 | Arguments:
68 | path (`str` or `Path`):
69 | Directory from which to load
70 | provider(`str`, *optional*):
71 | Onnxruntime execution provider to use for loading the model, defaults to `CPUExecutionProvider`
72 | """
73 | if provider is None:
74 | logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
75 | provider = "CPUExecutionProvider"
76 |
77 | return ort.InferenceSession(path, providers=[provider], sess_options=sess_options)
78 |
79 | def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs):
80 | """
81 | Save a model and its configuration file to a directory, so that it can be re-loaded using the
82 | [`~optimum.onnxruntime.modeling_ort.ORTModel.from_pretrained`] class method. It will always save the
83 | latest_model_name.
84 |
85 | Arguments:
86 | save_directory (`str` or `Path`):
87 | Directory where to save the model file.
88 | file_name(`str`, *optional*):
89 | Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to save the
90 | model with a different name.
91 | """
92 | model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
93 |
94 | src_path = self.model_save_dir.joinpath(self.latest_model_name)
95 | dst_path = Path(save_directory).joinpath(model_file_name)
96 | try:
97 | shutil.copyfile(src_path, dst_path)
98 | except shutil.SameFileError:
99 | pass
100 |
101 | # copy external weights (for models >2GB)
102 | src_path = self.model_save_dir.joinpath(ONNX_EXTERNAL_WEIGHTS_NAME)
103 | if src_path.exists():
104 | dst_path = Path(save_directory).joinpath(ONNX_EXTERNAL_WEIGHTS_NAME)
105 | try:
106 | shutil.copyfile(src_path, dst_path)
107 | except shutil.SameFileError:
108 | pass
109 |
110 | def save_pretrained(
111 | self,
112 | save_directory: Union[str, os.PathLike],
113 | **kwargs,
114 | ):
115 | """
116 | Save a model to a directory, so that it can be re-loaded using the [`~OnnxModel.from_pretrained`] class
117 | method.
118 |
119 | Arguments:
120 | save_directory (`str` or `os.PathLike`):
121 | Directory to which to save. Will be created if it doesn't exist.
122 | """
123 | if os.path.isfile(save_directory):
124 | logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
125 | return
126 |
127 | os.makedirs(save_directory, exist_ok=True)
128 |
129 | # saving model weights/files
130 | self._save_pretrained(save_directory, **kwargs)
131 |
132 | @classmethod
133 | def _from_pretrained(
134 | cls,
135 | model_id: Union[str, Path],
136 | use_auth_token: Optional[Union[bool, str, None]] = None,
137 | revision: Optional[Union[str, None]] = None,
138 | force_download: bool = False,
139 | cache_dir: Optional[str] = None,
140 | file_name: Optional[str] = None,
141 | provider: Optional[str] = None,
142 | sess_options: Optional["ort.SessionOptions"] = None,
143 | **kwargs,
144 | ):
145 | """
146 | Load a model from a directory or the HF Hub.
147 |
148 | Arguments:
149 | model_id (`str` or `Path`):
150 | Directory from which to load
151 | use_auth_token (`str` or `bool`):
152 | Required to load models from a private or gated repository.
153 | revision (`str`):
154 | Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id
155 | cache_dir (`Union[str, Path]`, *optional*):
156 | Path to a directory in which a downloaded pretrained model configuration should be cached if the
157 | standard cache should not be used.
158 | force_download (`bool`, *optional*, defaults to `False`):
159 | Whether or not to force the (re-)download of the model weights and configuration files, overriding the
160 | cached versions if they exist.
161 | file_name(`str`):
162 | Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to load
163 | different model files from the same repository or directory.
164 | provider(`str`):
165 | The ONNX runtime provider, e.g. `CPUExecutionProvider` or `CUDAExecutionProvider`.
166 | kwargs (`Dict`, *optional*):
167 | kwargs will be passed to the model during initialization
168 | """
169 | model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
170 | # load model from local directory
171 | if os.path.isdir(model_id):
172 | model = OnnxRuntimeModel.load_model(
173 | os.path.join(model_id, model_file_name), provider=provider, sess_options=sess_options
174 | )
175 | kwargs["model_save_dir"] = Path(model_id)
176 | # load model from hub
177 | else:
178 | # download model
179 | model_cache_path = hf_hub_download(
180 | repo_id=model_id,
181 | filename=model_file_name,
182 | use_auth_token=use_auth_token,
183 | revision=revision,
184 | cache_dir=cache_dir,
185 | force_download=force_download,
186 | )
187 | kwargs["model_save_dir"] = Path(model_cache_path).parent
188 | kwargs["latest_model_name"] = Path(model_cache_path).name
189 | model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider, sess_options=sess_options)
190 | return cls(model=model, **kwargs)
191 |
192 | @classmethod
193 | def from_pretrained(
194 | cls,
195 | model_id: Union[str, Path],
196 | force_download: bool = True,
197 | use_auth_token: Optional[str] = None,
198 | cache_dir: Optional[str] = None,
199 | **model_kwargs,
200 | ):
201 | revision = None
202 | if len(str(model_id).split("@")) == 2:
203 | model_id, revision = model_id.split("@")
204 |
205 | return cls._from_pretrained(
206 | model_id=model_id,
207 | revision=revision,
208 | cache_dir=cache_dir,
209 | force_download=force_download,
210 | use_auth_token=use_auth_token,
211 | **model_kwargs,
212 | )
213 |
--------------------------------------------------------------------------------
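A minimal sketch of loading and calling `OnnxRuntimeModel` from a local export directory (the path and the `sample` input name are illustrative and must match the actual ONNX graph):

import numpy as np

model = OnnxRuntimeModel.from_pretrained("./onnx-export/unet", force_download=False)
# __call__ converts keyword arguments to numpy arrays and runs the ONNX session.
sample = np.zeros((1, 4, 64, 64), dtype=np.float32)
outputs = model(sample=sample)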
/scripts/models/utils/richtext_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import random
5 | import numpy as np
6 |
7 | COLORS = {
8 | 'brown': [165, 42, 42],
9 | 'red': [255, 0, 0],
10 | 'pink': [253, 108, 158],
11 | 'orange': [255, 165, 0],
12 | 'yellow': [255, 255, 0],
13 | 'purple': [128, 0, 128],
14 | 'green': [0, 128, 0],
15 | 'blue': [0, 0, 255],
16 | 'white': [255, 255, 255],
17 | 'gray': [128, 128, 128],
18 | 'black': [0, 0, 0],
19 | }
20 |
21 |
22 | def seed_everything(seed):
23 | random.seed(seed)
24 | os.environ['PYTHONHASHSEED'] = str(seed)
25 | np.random.seed(seed)
26 | torch.manual_seed(seed)
27 | torch.cuda.manual_seed(seed)
28 |
29 |
30 | def hex_to_rgb(hex_string, return_nearest_color=False):
31 | r"""
32 | Convert Hex triplet to RGB triplet.
33 | """
34 | # Remove '#' symbol if present
35 | hex_string = hex_string.lstrip('#')
36 | # Convert hex values to integers
37 | red = int(hex_string[0:2], 16)
38 | green = int(hex_string[2:4], 16)
39 | blue = int(hex_string[4:6], 16)
40 | rgb = torch.FloatTensor((red, green, blue))[None, :, None, None]/255.
41 | if return_nearest_color:
42 | nearest_color = find_nearest_color(rgb)
43 | return rgb.cuda(), nearest_color
44 | return rgb.cuda()
45 |
46 |
47 | def find_nearest_color(rgb):
48 | r"""
49 | Find the nearest neighbor color given the RGB value.
50 | """
51 | if isinstance(rgb, list) or isinstance(rgb, tuple):
52 | rgb = torch.FloatTensor(rgb)[None, :, None, None]/255.
53 | color_distance = torch.FloatTensor([np.linalg.norm(
54 | rgb - torch.FloatTensor(COLORS[color])[None, :, None, None]/255.) for color in COLORS.keys()])
55 | nearest_color = list(COLORS.keys())[torch.argmin(color_distance).item()]
56 | return nearest_color
57 |
58 |
59 | def font2style(font):
60 | r"""
61 | Convert the font name to the style name.
62 | """
63 | return {'mirza': 'Claude Monet, impressionism, oil on canvas',
64 | 'roboto': 'Ukiyoe',
65 | 'cursive': 'Cyber Punk, futuristic, blade runner, william gibson, trending on artstation hq',
66 | 'sofia': 'Pop Art, masterpiece, andy warhol',
67 | 'slabo': 'Vincent Van Gogh',
68 | 'inconsolata': 'Pixel Art, 8 bits, 16 bits',
69 | 'ubuntu': 'Rembrandt',
70 | 'Monoton': 'neon art, colorful light, highly detailed, octane render',
71 | 'Akronim': 'Abstract Cubism, Pablo Picasso', }[font]
72 |
73 |
74 | def parse_json(json_str):
75 | r"""
76 | Convert the JSON string to attributes.
77 | """
78 | # initialize region-based attributes.
79 | base_text_prompt = ''
80 | style_text_prompts = []
81 | footnote_text_prompts = []
82 | footnote_target_tokens = []
83 | color_text_prompts = []
84 | color_rgbs = []
85 | color_names = []
86 | size_text_prompts_and_sizes = []
87 |
88 | # parse the attributes from JSON.
89 | prev_style = None
90 | prev_color_rgb = None
91 | use_grad_guidance = False
92 | for span in json_str['ops']:
93 | text_prompt = span['insert'].rstrip('\n')
94 | base_text_prompt += span['insert'].rstrip('\n')
95 | if text_prompt == ' ':
96 | continue
97 | if 'attributes' in span:
98 | if 'font' in span['attributes']:
99 | style = font2style(span['attributes']['font'])
100 | if prev_style == style:
101 | prev_text_prompt = style_text_prompts[-1].split('in the style of')[
102 | 0]
103 | style_text_prompts[-1] = prev_text_prompt + \
104 | ' ' + text_prompt + f' in the style of {style}'
105 | else:
106 | style_text_prompts.append(
107 | text_prompt + f' in the style of {style}')
108 | prev_style = style
109 | else:
110 | prev_style = None
111 | if 'link' in span['attributes']:
112 | footnote_text_prompts.append(span['attributes']['link'])
113 | footnote_target_tokens.append(text_prompt)
114 | font_size = 1
115 | if 'size' in span['attributes'] and 'strike' not in span['attributes']:
116 | font_size = float(span['attributes']['size'][:-2])/3.
117 | elif 'size' in span['attributes'] and 'strike' in span['attributes']:
118 | font_size = -float(span['attributes']['size'][:-2])/3.
119 | elif 'size' not in span['attributes'] and 'strike' not in span['attributes']:
120 | font_size = 1
121 | if 'color' in span['attributes']:
122 | use_grad_guidance = True
123 | color_rgb, nearest_color = hex_to_rgb(
124 | span['attributes']['color'], True)
125 | if prev_color_rgb is not None and torch.equal(prev_color_rgb, color_rgb):
126 | # merge consecutive spans that share the same color into one prompt
127 | color_text_prompts[-1] = color_text_prompts[-1] + ' ' + text_prompt
128 | else:
129 | color_rgbs.append(color_rgb)
130 | color_names.append(nearest_color)
131 | color_text_prompts.append(text_prompt)
132 | prev_color_rgb = color_rgb
133 | if font_size != 1:
134 | size_text_prompts_and_sizes.append([text_prompt, font_size])
135 | return base_text_prompt, style_text_prompts, footnote_text_prompts, footnote_target_tokens,\
136 | color_text_prompts, color_names, color_rgbs, size_text_prompts_and_sizes, use_grad_guidance
137 |
138 |
139 | def get_region_diffusion_input(model, base_text_prompt, style_text_prompts, footnote_text_prompts,
140 | footnote_target_tokens, color_text_prompts, color_names):
141 | r"""
142 | Algorithm 1 in the paper.
143 | """
144 | region_text_prompts = []
145 | region_target_token_ids = []
146 | base_tokens = model.tokenizer._tokenize(base_text_prompt)
147 | # process the style text prompt
148 | for text_prompt in style_text_prompts:
149 | region_text_prompts.append(text_prompt)
150 | region_target_token_ids.append([])
151 | style_tokens = model.tokenizer._tokenize(
152 | text_prompt.split('in the style of')[0])
153 | for style_token in style_tokens:
154 | region_target_token_ids[-1].append(
155 | base_tokens.index(style_token)+1)
156 |
157 | # process the complementary text prompt
158 | for footnote_text_prompt, text_prompt in zip(footnote_text_prompts, footnote_target_tokens):
159 | region_target_token_ids.append([])
160 | region_text_prompts.append(footnote_text_prompt)
161 | style_tokens = model.tokenizer._tokenize(text_prompt)
162 | for style_token in style_tokens:
163 | region_target_token_ids[-1].append(
164 | base_tokens.index(style_token)+1)
165 |
166 | # process the color text prompt
167 | for color_text_prompt, color_name in zip(color_text_prompts, color_names):
168 | region_target_token_ids.append([])
169 | region_text_prompts.append(color_name+' '+color_text_prompt)
170 | style_tokens = model.tokenizer._tokenize(color_text_prompt)
171 | for style_token in style_tokens:
172 | region_target_token_ids[-1].append(
173 | base_tokens.index(style_token)+1)
174 |
175 | # process the remaining tokens without any attributes
176 | region_text_prompts.append(base_text_prompt)
177 | region_target_token_ids_all = [
178 | id for ids in region_target_token_ids for id in ids]
179 | target_token_ids_rest = [id for id in range(
180 | 1, len(base_tokens)+1) if id not in region_target_token_ids_all]
181 | region_target_token_ids.append(target_token_ids_rest)
182 |
183 | region_target_token_ids = [torch.LongTensor(
184 | obj_token_id) for obj_token_id in region_target_token_ids]
185 | return region_text_prompts, region_target_token_ids, base_tokens
186 |
187 |
188 | def get_attention_control_input(model, base_tokens, size_text_prompts_and_sizes):
189 | r"""
190 | Control the token impact using font sizes.
191 | """
192 | word_pos = []
193 | font_sizes = []
194 | for text_prompt, font_size in size_text_prompts_and_sizes:
195 | size_tokens = model.tokenizer._tokenize(text_prompt)
196 | for size_token in size_tokens:
197 | word_pos.append(base_tokens.index(size_token)+1)
198 | font_sizes.append(font_size)
199 | if len(word_pos) > 0:
200 | word_pos = torch.LongTensor(word_pos).cuda()
201 | font_sizes = torch.FloatTensor(font_sizes).cuda()
202 | else:
203 | word_pos = None
204 | font_sizes = None
205 | text_format_dict = {
206 | 'word_pos': word_pos,
207 | 'font_size': font_sizes,
208 | }
209 | return text_format_dict
210 |
211 |
212 | def get_gradient_guidance_input(model, base_tokens, color_text_prompts, color_rgbs, text_format_dict,
213 | guidance_start_step=999, color_guidance_weight=1):
214 | r"""
215 | Control the token colors using gradient guidance.
216 | """
217 | color_target_token_ids = []
218 | for text_prompt in color_text_prompts:
219 | color_target_token_ids.append([])
220 | color_tokens = model.tokenizer._tokenize(text_prompt)
221 | for color_token in color_tokens:
222 | color_target_token_ids[-1].append(base_tokens.index(color_token)+1)
223 | color_target_token_ids_all = [
224 | id for ids in color_target_token_ids for id in ids]
225 | color_target_token_ids_rest = [id for id in range(
226 | 1, len(base_tokens)+1) if id not in color_target_token_ids_all]
227 | color_target_token_ids.append(color_target_token_ids_rest)
228 | color_target_token_ids = [torch.LongTensor(
229 | obj_token_id) for obj_token_id in color_target_token_ids]
230 |
231 | text_format_dict['target_RGB'] = color_rgbs
232 | text_format_dict['guidance_start_step'] = guidance_start_step
233 | text_format_dict['color_guidance_weight'] = color_guidance_weight
234 | return text_format_dict, color_target_token_ids
235 |
--------------------------------------------------------------------------------
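An illustrative rich-text JSON payload (Quill-style "ops") of the shape `parse_json` expects; note that `hex_to_rgb` moves tensors to CUDA, so a GPU is assumed whenever color attributes are present:

rich_text = {
    "ops": [
        {"insert": "a garden with "},
        {"insert": "roses", "attributes": {"color": "#ff0000"}},
        {"insert": " at sunset", "attributes": {"font": "slabo"}},
    ]
}

(base_prompt, style_prompts, footnote_prompts, footnote_tokens,
 color_prompts, color_names, color_rgbs, size_prompts, use_grad) = parse_json(rich_text)
# base_prompt   == "a garden with roses at sunset"
# color_prompts == ["roses"], color_names == ["red"]
# style_prompts == [" at sunset in the style of Vincent Van Gogh"]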
/diffusers_official/schedulers/scheduling_karras_ve_flax.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 NVIDIA and The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from dataclasses import dataclass
17 | from typing import Optional, Tuple, Union
18 |
19 | import flax
20 | import jax.numpy as jnp
21 | from jax import random
22 |
23 | from ..configuration_utils import ConfigMixin, register_to_config
24 | from ..utils import BaseOutput
25 | from .scheduling_utils_flax import FlaxSchedulerMixin
26 |
27 |
28 | @flax.struct.dataclass
29 | class KarrasVeSchedulerState:
30 | # setable values
31 | num_inference_steps: Optional[int] = None
32 | timesteps: Optional[jnp.ndarray] = None
33 | schedule: Optional[jnp.ndarray] = None # sigma(t_i)
34 |
35 | @classmethod
36 | def create(cls):
37 | return cls()
38 |
39 |
40 | @dataclass
41 | class FlaxKarrasVeOutput(BaseOutput):
42 | """
43 | Output class for the scheduler's step function output.
44 |
45 | Args:
46 | prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
47 | Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
48 | denoising loop.
49 | derivative (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
50 | Derivative of predicted original image sample (x_0).
51 | state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class.
52 | """
53 |
54 | prev_sample: jnp.ndarray
55 | derivative: jnp.ndarray
56 | state: KarrasVeSchedulerState
57 |
58 |
59 | class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
60 | """
61 | Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
62 | the VE column of Table 1 from [1] for reference.
63 |
64 | [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
65 | https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
66 | differential equations." https://arxiv.org/abs/2011.13456
67 |
68 | [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
69 | function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
70 | [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
71 | [`~SchedulerMixin.from_pretrained`] functions.
72 |
73 | For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
74 | Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
75 | optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper.
76 |
77 | Args:
78 | sigma_min (`float`): minimum noise magnitude
79 | sigma_max (`float`): maximum noise magnitude
80 | s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling.
81 | A reasonable range is [1.000, 1.011].
82 | s_churn (`float`): the parameter controlling the overall amount of stochasticity.
83 | A reasonable range is [0, 100].
84 | s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity).
85 | A reasonable range is [0, 10].
86 | s_max (`float`): the end value of the sigma range where we add noise.
87 | A reasonable range is [0.2, 80].
88 | """
89 |
90 | @property
91 | def has_state(self):
92 | return True
93 |
94 | @register_to_config
95 | def __init__(
96 | self,
97 | sigma_min: float = 0.02,
98 | sigma_max: float = 100,
99 | s_noise: float = 1.007,
100 | s_churn: float = 80,
101 | s_min: float = 0.05,
102 | s_max: float = 50,
103 | ):
104 | pass
105 |
106 | def create_state(self):
107 | return KarrasVeSchedulerState.create()
108 |
109 | def set_timesteps(
110 | self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple = ()
111 | ) -> KarrasVeSchedulerState:
112 | """
113 | Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
114 |
115 | Args:
116 | state (`KarrasVeSchedulerState`):
117 | the `FlaxKarrasVeScheduler` state data class.
118 | num_inference_steps (`int`):
119 | the number of diffusion steps used when generating samples with a pre-trained model.
120 |
121 | """
122 | timesteps = jnp.arange(0, num_inference_steps)[::-1].copy()
123 | schedule = [
124 | (
125 | self.config.sigma_max**2
126 | * (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1))
127 | )
128 | for i in timesteps
129 | ]
130 |
131 | return state.replace(
132 | num_inference_steps=num_inference_steps,
133 | schedule=jnp.array(schedule, dtype=jnp.float32),
134 | timesteps=timesteps,
135 | )
136 |
137 | def add_noise_to_input(
138 | self,
139 | state: KarrasVeSchedulerState,
140 | sample: jnp.ndarray,
141 | sigma: float,
142 | key: random.KeyArray,
143 | ) -> Tuple[jnp.ndarray, float]:
144 | """
145 | Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
146 | higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
147 |         Args: state (`KarrasVeSchedulerState`): the scheduler state. sample (`jnp.ndarray`): the current sample.
148 |             sigma (`float`): the current noise level. key (`random.KeyArray`): a PRNG key for the churn noise.
149 |         """
150 | if self.config.s_min <= sigma <= self.config.s_max:
151 | gamma = min(self.config.s_churn / state.num_inference_steps, 2**0.5 - 1)
152 | else:
153 | gamma = 0
154 |
155 | # sample eps ~ N(0, S_noise^2 * I)
156 |         key = random.split(key, num=1)[0]  # random.normal expects a single PRNG key, not a batch of keys
157 | eps = self.config.s_noise * random.normal(key=key, shape=sample.shape)
158 | sigma_hat = sigma + gamma * sigma
159 | sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)
160 |
161 | return sample_hat, sigma_hat
162 |
163 | def step(
164 | self,
165 | state: KarrasVeSchedulerState,
166 | model_output: jnp.ndarray,
167 | sigma_hat: float,
168 | sigma_prev: float,
169 | sample_hat: jnp.ndarray,
170 | return_dict: bool = True,
171 | ) -> Union[FlaxKarrasVeOutput, Tuple]:
172 | """
173 | Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
174 | process from the learned model outputs (most often the predicted noise).
175 |
176 | Args:
177 | state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class.
178 |             model_output (`jnp.ndarray`): direct output from learned diffusion model.
179 |             sigma_hat (`float`): the increased noise level after the churn step, sigma_hat = sigma_i + gamma_i * sigma_i.
180 |             sigma_prev (`float`): the noise level of the next, lower-noise timestep.
181 |             sample_hat (`jnp.ndarray`): the sample at noise level sigma_hat, as returned by `add_noise_to_input`.
182 | return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class
183 |
184 | Returns:
185 | [`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] or `tuple`: Updated sample in the diffusion
186 | chain and derivative. [`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] if `return_dict` is
187 | True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
188 | """
189 |
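        # Algorithm 2 of Karras et al. (2022): form the denoised estimate, take the local
        # slope d_i, and apply a first-order (Euler) step from sigma_hat down to sigma_prev.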
190 | pred_original_sample = sample_hat + sigma_hat * model_output
191 | derivative = (sample_hat - pred_original_sample) / sigma_hat
192 | sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative
193 |
194 | if not return_dict:
195 | return (sample_prev, derivative, state)
196 |
197 | return FlaxKarrasVeOutput(prev_sample=sample_prev, derivative=derivative, state=state)
198 |
199 | def step_correct(
200 | self,
201 | state: KarrasVeSchedulerState,
202 | model_output: jnp.ndarray,
203 | sigma_hat: float,
204 | sigma_prev: float,
205 | sample_hat: jnp.ndarray,
206 | sample_prev: jnp.ndarray,
207 | derivative: jnp.ndarray,
208 | return_dict: bool = True,
209 | ) -> Union[FlaxKarrasVeOutput, Tuple]:
210 | """
211 |         Correct the predicted sample with a second-order (Heun-like) update that averages the first-order derivative with a derivative re-evaluated at the predicted sample.
212 |
213 | Args:
214 | state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class.
215 |             model_output (`jnp.ndarray`): direct output from learned diffusion model, evaluated at `sample_prev`.
216 |             sigma_hat (`float`): the increased noise level after the churn step, sigma_hat = sigma_i + gamma_i * sigma_i.
217 |             sigma_prev (`float`): the noise level of the next, lower-noise timestep.
218 |             sample_hat (`jnp.ndarray`): the sample at noise level sigma_hat, as returned by `add_noise_to_input`.
219 |             sample_prev (`jnp.ndarray`): the sample produced by the preceding first-order `step` call.
220 |             derivative (`jnp.ndarray`): the derivative returned by the preceding first-order `step` call.
221 | return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class
222 |
223 |         Returns:
224 |             [`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] or `tuple`: the corrected sample and the first-order derivative; a plain `tuple` when `return_dict` is False.
225 |
226 | """
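        # Heun-style second-order correction (Algorithm 2 of Karras et al., 2022):
        # re-evaluate the slope at the predicted sample and average the two slopes.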
227 | pred_original_sample = sample_prev + sigma_prev * model_output
228 | derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
229 | sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
230 |
231 | if not return_dict:
232 | return (sample_prev, derivative, state)
233 |
234 | return FlaxKarrasVeOutput(prev_sample=sample_prev, derivative=derivative, state=state)
235 |
236 | def add_noise(self, state: KarrasVeSchedulerState, original_samples, noise, timesteps):
237 | raise NotImplementedError()
238 |
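# A minimal end-to-end sketch of the functional Flax sampling loop above. The
# denoiser is a dummy stand-in (an assumption for illustration, not part of
# diffusers); a real model would be a Flax UNet apply function. Note how the
# immutable state is threaded through every scheduler call.
if __name__ == "__main__":
    shape = (1, 3, 8, 8)

    def dummy_denoiser(x, sigma):
        # placeholder for a trained VE score model
        return -x / jnp.sqrt(sigma**2 + 1.0)

    scheduler = FlaxKarrasVeScheduler()
    state = scheduler.create_state()
    state = scheduler.set_timesteps(state, num_inference_steps=10, shape=shape)

    key = random.PRNGKey(0)
    # start from pure noise at the maximum noise level
    sample = random.normal(key, shape) * scheduler.config.sigma_max
    for t in state.timesteps:
        sigma = state.schedule[t]
        sigma_prev = state.schedule[t - 1] if t > 0 else 0.0
        key, churn_key = random.split(key)
        sample_hat, sigma_hat = scheduler.add_noise_to_input(state, sample, sigma, churn_key)
        output = scheduler.step(state, dummy_denoiser(sample_hat, sigma_hat), sigma_hat, sigma_prev, sample_hat)
        if sigma_prev > 0:
            output = scheduler.step_correct(
                state, dummy_denoiser(output.prev_sample, sigma_prev), sigma_hat, sigma_prev,
                sample_hat, output.prev_sample, output.derivative,
            )
        sample, state = output.prev_sample, output.state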
--------------------------------------------------------------------------------
/diffusers_official/schedulers/scheduling_karras_ve.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 NVIDIA and The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from dataclasses import dataclass
17 | from typing import Optional, Tuple, Union
18 |
19 | import numpy as np
20 | import torch
21 |
22 | from ..configuration_utils import ConfigMixin, register_to_config
23 | from ..utils import BaseOutput, randn_tensor
24 | from .scheduling_utils import SchedulerMixin
25 |
26 |
27 | @dataclass
28 | class KarrasVeOutput(BaseOutput):
29 | """
30 | Output class for the scheduler's step function output.
31 |
32 | Args:
33 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
34 | Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
35 | denoising loop.
36 | derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
37 | Derivative of predicted original image sample (x_0).
38 | pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
39 | The predicted denoised sample (x_{0}) based on the model output from the current timestep.
40 | `pred_original_sample` can be used to preview progress or for guidance.
41 | """
42 |
43 | prev_sample: torch.FloatTensor
44 | derivative: torch.FloatTensor
45 | pred_original_sample: Optional[torch.FloatTensor] = None
46 |
47 |
48 | class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
49 | """
50 |     Stochastic sampling from Karras et al. [1] tailored to Variance Exploding (VE) models [2]. Use Algorithm 2 and
51 |     the VE column of Table 1 from [1] for reference.
52 | 
53 |     [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364
54 |     [2] Song, Yang, et al. "Score-based generative modeling through stochastic differential equations."
55 |     https://arxiv.org/abs/2011.13456
56 |
57 | [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
58 | function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
59 |     [`SchedulerMixin`] provides general loading and saving functionality via the [`~SchedulerMixin.save_pretrained`] and
60 | [`~SchedulerMixin.from_pretrained`] functions.
61 |
62 |     For more details on the parameters, see Appendix E of the original paper: "Elucidating the Design Space of
63 | Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
64 | optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper.
65 |
66 | Args:
67 | sigma_min (`float`): minimum noise magnitude
68 | sigma_max (`float`): maximum noise magnitude
69 | s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling.
70 | A reasonable range is [1.000, 1.011].
71 | s_churn (`float`): the parameter controlling the overall amount of stochasticity.
72 | A reasonable range is [0, 100].
73 | s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity).
74 | A reasonable range is [0, 10].
75 | s_max (`float`): the end value of the sigma range where we add noise.
76 | A reasonable range is [0.2, 80].
77 |
78 | """
79 |
80 | order = 2
81 |
82 | @register_to_config
83 | def __init__(
84 | self,
85 | sigma_min: float = 0.02,
86 | sigma_max: float = 100,
87 | s_noise: float = 1.007,
88 | s_churn: float = 80,
89 | s_min: float = 0.05,
90 | s_max: float = 50,
91 | ):
92 | # standard deviation of the initial noise distribution
93 | self.init_noise_sigma = sigma_max
94 |
95 | # setable values
96 |         self.num_inference_steps: Optional[int] = None
97 |         self.timesteps: Optional[torch.Tensor] = None  # set by set_timesteps()
98 |         self.schedule: Optional[torch.FloatTensor] = None  # sigma(t_i), set by set_timesteps()
99 |
100 | def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
101 | """
102 | Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
103 | current timestep.
104 |
105 | Args:
106 | sample (`torch.FloatTensor`): input sample
107 | timestep (`int`, optional): current timestep
108 |
109 | Returns:
110 | `torch.FloatTensor`: scaled input sample
111 | """
112 | return sample
113 |
114 | def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
115 | """
116 | Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
117 |
118 | Args:
119 | num_inference_steps (`int`):
120 | the number of diffusion steps used when generating samples with a pre-trained model.
121 |
122 | """
123 | self.num_inference_steps = num_inference_steps
124 | timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
125 | self.timesteps = torch.from_numpy(timesteps).to(device)
126 | schedule = [
127 | (
128 | self.config.sigma_max**2
129 | * (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1))
130 | )
131 | for i in self.timesteps
132 | ]
133 | self.schedule = torch.tensor(schedule, dtype=torch.float32, device=device)
134 |
135 | def add_noise_to_input(
136 | self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None
137 | ) -> Tuple[torch.FloatTensor, float]:
138 | """
139 | Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
140 | higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
141 |         Args: sample (`torch.FloatTensor`): the current sample. sigma (`float`): the current noise level.
142 |             generator (`torch.Generator`, *optional*): a random number generator for the churn noise.
143 |         """
144 | if self.config.s_min <= sigma <= self.config.s_max:
145 | gamma = min(self.config.s_churn / self.num_inference_steps, 2**0.5 - 1)
146 | else:
147 | gamma = 0
148 |
149 | # sample eps ~ N(0, S_noise^2 * I)
150 | eps = self.config.s_noise * randn_tensor(sample.shape, generator=generator).to(sample.device)
151 | sigma_hat = sigma + gamma * sigma
152 | sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)
153 |
154 | return sample_hat, sigma_hat
155 |
156 | def step(
157 | self,
158 | model_output: torch.FloatTensor,
159 | sigma_hat: float,
160 | sigma_prev: float,
161 | sample_hat: torch.FloatTensor,
162 | return_dict: bool = True,
163 | ) -> Union[KarrasVeOutput, Tuple]:
164 | """
165 | Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
166 | process from the learned model outputs (most often the predicted noise).
167 |
168 | Args:
169 | model_output (`torch.FloatTensor`): direct output from learned diffusion model.
170 |             sigma_hat (`float`): the increased noise level after the churn step, sigma_hat = sigma_i + gamma_i * sigma_i.
171 |             sigma_prev (`float`): the noise level of the next, lower-noise timestep.
172 |             sample_hat (`torch.FloatTensor`): the sample at noise level sigma_hat, as returned by `add_noise_to_input`.
173 | return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class
174 |
175 | 
176 |         Returns:
177 | [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] or `tuple`:
178 | [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] if `return_dict` is True, otherwise a `tuple`. When
179 | returning a tuple, the first element is the sample tensor.
180 |
181 | """
182 |
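        # Algorithm 2 of Karras et al. (2022): denoised estimate, local slope d_i,
        # then a first-order (Euler) step from sigma_hat down to sigma_prev.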
183 | pred_original_sample = sample_hat + sigma_hat * model_output
184 | derivative = (sample_hat - pred_original_sample) / sigma_hat
185 | sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative
186 |
187 | if not return_dict:
188 | return (sample_prev, derivative)
189 |
190 | return KarrasVeOutput(
191 | prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample
192 | )
193 |
194 | def step_correct(
195 | self,
196 | model_output: torch.FloatTensor,
197 | sigma_hat: float,
198 | sigma_prev: float,
199 | sample_hat: torch.FloatTensor,
200 | sample_prev: torch.FloatTensor,
201 | derivative: torch.FloatTensor,
202 | return_dict: bool = True,
203 | ) -> Union[KarrasVeOutput, Tuple]:
204 | """
205 |         Correct the predicted sample with a second-order (Heun-like) update that averages the first-order derivative with a derivative re-evaluated at the predicted sample.
206 |
207 | Args:
208 | model_output (`torch.FloatTensor`): direct output from learned diffusion model.
209 |             sigma_hat (`float`): the increased noise level after the churn step, sigma_hat = sigma_i + gamma_i * sigma_i.
210 |             sigma_prev (`float`): the noise level of the next, lower-noise timestep.
211 |             sample_hat (`torch.FloatTensor`): the sample at noise level sigma_hat, as returned by `add_noise_to_input`.
212 |             sample_prev (`torch.FloatTensor`): the sample produced by the preceding first-order `step` call.
213 |             derivative (`torch.FloatTensor`): the derivative returned by the preceding first-order `step` call.
214 | return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class
215 |
216 |         Returns:
217 |             [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] or `tuple`: the corrected sample and the first-order derivative; a plain `tuple` when `return_dict` is False.
218 |
219 | """
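        # Second-order (Heun) correction from Algorithm 2 of Karras et al. (2022): the slope
        # is re-evaluated at the predicted sample and averaged with the first-order slope.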
220 | pred_original_sample = sample_prev + sigma_prev * model_output
221 | derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
222 | sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
223 |
224 | if not return_dict:
225 | return (sample_prev, derivative)
226 |
227 | return KarrasVeOutput(
228 | prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample
229 | )
230 |
231 | def add_noise(self, original_samples, noise, timesteps):
232 | raise NotImplementedError()
233 |
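# A compact sketch of how the KarrasVePipeline drives this scheduler. The UNet is
# replaced by a dummy callable (an assumption for illustration); the real pipeline
# wraps a UNet2DModel and rescales its inputs and outputs.
if __name__ == "__main__":
    scheduler = KarrasVeScheduler()
    scheduler.set_timesteps(25)

    def dummy_unet(x, sigma):
        # placeholder for `(sigma / 2) * unet((x + 1) / 2, sigma / 2).sample`
        return -x / (sigma**2 + 1.0) ** 0.5

    sample = torch.randn(1, 3, 8, 8) * scheduler.init_noise_sigma
    for t in scheduler.timesteps:
        sigma = scheduler.schedule[t]
        sigma_prev = scheduler.schedule[t - 1] if t > 0 else 0
        # 1. churn: raise the noise level from sigma to sigma_hat
        sample_hat, sigma_hat = scheduler.add_noise_to_input(sample, sigma)
        # 2. first-order (Euler) step from sigma_hat down to sigma_prev
        output = scheduler.step(dummy_unet(sample_hat, sigma_hat), sigma_hat, sigma_prev, sample_hat)
        # 3. second-order correction, skipped on the final step (sigma_prev == 0)
        if sigma_prev != 0:
            output = scheduler.step_correct(
                dummy_unet(output.prev_sample, sigma_prev), sigma_hat, sigma_prev,
                sample_hat, output.prev_sample, output.derivative,
            )
        sample = output.prev_sample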
--------------------------------------------------------------------------------
/diffusers_official/utils/logging.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 Optuna, Hugging Face
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ Logging utilities."""
16 |
17 | import logging
18 | import os
19 | import sys
20 | import threading
21 | from logging import (
22 | CRITICAL, # NOQA
23 | DEBUG, # NOQA
24 | ERROR, # NOQA
25 | FATAL, # NOQA
26 | INFO, # NOQA
27 | NOTSET, # NOQA
28 | WARN, # NOQA
29 | WARNING, # NOQA
30 | )
31 | from typing import Optional
32 |
33 | from tqdm import auto as tqdm_lib
34 |
35 |
36 | _lock = threading.Lock()
37 | _default_handler: Optional[logging.Handler] = None
38 |
39 | log_levels = {
40 | "debug": logging.DEBUG,
41 | "info": logging.INFO,
42 | "warning": logging.WARNING,
43 | "error": logging.ERROR,
44 | "critical": logging.CRITICAL,
45 | }
46 |
47 | _default_log_level = logging.WARNING
48 |
49 | _tqdm_active = True
50 |
51 |
52 | def _get_default_logging_level():
53 | """
54 | If DIFFUSERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
55 | not - fall back to `_default_log_level`
56 | """
57 | env_level_str = os.getenv("DIFFUSERS_VERBOSITY", None)
58 | if env_level_str:
59 | if env_level_str in log_levels:
60 | return log_levels[env_level_str]
61 | else:
62 | logging.getLogger().warning(
63 | f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, "
64 | f"has to be one of: { ', '.join(log_levels.keys()) }"
65 | )
66 | return _default_log_level
67 |
68 |
69 | def _get_library_name() -> str:
70 | return __name__.split(".")[0]
71 |
72 |
73 | def _get_library_root_logger() -> logging.Logger:
74 | return logging.getLogger(_get_library_name())
75 |
76 |
77 | def _configure_library_root_logger() -> None:
78 | global _default_handler
79 |
80 | with _lock:
81 | if _default_handler:
82 | # This library has already configured the library root logger.
83 | return
84 | _default_handler = logging.StreamHandler() # Set sys.stderr as stream.
85 | _default_handler.flush = sys.stderr.flush
86 |
87 | # Apply our default configuration to the library root logger.
88 | library_root_logger = _get_library_root_logger()
89 | library_root_logger.addHandler(_default_handler)
90 | library_root_logger.setLevel(_get_default_logging_level())
91 | library_root_logger.propagate = False
92 |
93 |
94 | def _reset_library_root_logger() -> None:
95 | global _default_handler
96 |
97 | with _lock:
98 | if not _default_handler:
99 | return
100 |
101 | library_root_logger = _get_library_root_logger()
102 | library_root_logger.removeHandler(_default_handler)
103 | library_root_logger.setLevel(logging.NOTSET)
104 | _default_handler = None
105 |
106 |
107 | def get_log_levels_dict():
108 | return log_levels
109 |
110 |
111 | def get_logger(name: Optional[str] = None) -> logging.Logger:
112 | """
113 | Return a logger with the specified name.
114 |
115 | This function is not supposed to be directly accessed unless you are writing a custom diffusers module.
116 | """
117 |
118 | if name is None:
119 | name = _get_library_name()
120 |
121 | _configure_library_root_logger()
122 | return logging.getLogger(name)
123 |
124 |
125 | def get_verbosity() -> int:
126 | """
127 | Return the current level for the 🤗 Diffusers' root logger as an `int`.
128 |
129 | Returns:
130 | `int`:
131 | Logging level integers which can be one of:
132 |
133 | - `50`: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
134 | - `40`: `diffusers.logging.ERROR`
135 | - `30`: `diffusers.logging.WARNING` or `diffusers.logging.WARN`
136 | - `20`: `diffusers.logging.INFO`
137 | - `10`: `diffusers.logging.DEBUG`
138 |
139 | """
140 |
141 | _configure_library_root_logger()
142 | return _get_library_root_logger().getEffectiveLevel()
143 |
144 |
145 | def set_verbosity(verbosity: int) -> None:
146 | """
147 | Set the verbosity level for the 🤗 Diffusers' root logger.
148 |
149 | Args:
150 | verbosity (`int`):
151 | Logging level which can be one of:
152 |
153 | - `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
154 | - `diffusers.logging.ERROR`
155 | - `diffusers.logging.WARNING` or `diffusers.logging.WARN`
156 | - `diffusers.logging.INFO`
157 | - `diffusers.logging.DEBUG`
158 | """
159 |
160 | _configure_library_root_logger()
161 | _get_library_root_logger().setLevel(verbosity)
162 |
163 |
164 | def set_verbosity_info():
165 | """Set the verbosity to the `INFO` level."""
166 | return set_verbosity(INFO)
167 |
168 |
169 | def set_verbosity_warning():
170 | """Set the verbosity to the `WARNING` level."""
171 | return set_verbosity(WARNING)
172 |
173 |
174 | def set_verbosity_debug():
175 | """Set the verbosity to the `DEBUG` level."""
176 | return set_verbosity(DEBUG)
177 |
178 |
179 | def set_verbosity_error():
180 | """Set the verbosity to the `ERROR` level."""
181 | return set_verbosity(ERROR)
182 |
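# Typical usage from application code (a sketch; these helpers are re-exported
# through `diffusers.utils.logging`):
#
#     from diffusers.utils import logging
#
#     logging.set_verbosity_info()
#     logger = logging.get_logger("diffusers")
#     logger.info("INFO")
#     logger.warning_advice("WARNING")  # suppressed when DIFFUSERS_NO_ADVISORY_WARNINGS=1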
183 |
184 | def disable_default_handler() -> None:
185 | """Disable the default handler of the 🤗 Diffusers' root logger."""
186 |
187 | _configure_library_root_logger()
188 |
189 | assert _default_handler is not None
190 | _get_library_root_logger().removeHandler(_default_handler)
191 |
192 |
193 | def enable_default_handler() -> None:
194 | """Enable the default handler of the 🤗 Diffusers' root logger."""
195 |
196 | _configure_library_root_logger()
197 |
198 | assert _default_handler is not None
199 | _get_library_root_logger().addHandler(_default_handler)
200 |
201 |
202 | def add_handler(handler: logging.Handler) -> None:
203 | """adds a handler to the HuggingFace Diffusers' root logger."""
204 |
205 | _configure_library_root_logger()
206 |
207 | assert handler is not None
208 | _get_library_root_logger().addHandler(handler)
209 |
210 |
211 | def remove_handler(handler: logging.Handler) -> None:
212 | """removes given handler from the HuggingFace Diffusers' root logger."""
213 |
214 | _configure_library_root_logger()
215 |
216 |     assert handler is not None and handler in _get_library_root_logger().handlers
217 | _get_library_root_logger().removeHandler(handler)
218 |
219 |
220 | def disable_propagation() -> None:
221 | """
222 | Disable propagation of the library log outputs. Note that log propagation is disabled by default.
223 | """
224 |
225 | _configure_library_root_logger()
226 | _get_library_root_logger().propagate = False
227 |
228 |
229 | def enable_propagation() -> None:
230 | """
231 | Enable propagation of the library log outputs. Please disable the HuggingFace Diffusers' default handler to prevent
232 | double logging if the root logger has been configured.
233 | """
234 |
235 | _configure_library_root_logger()
236 | _get_library_root_logger().propagate = True
237 |
238 |
239 | def enable_explicit_format() -> None:
240 | """
241 | Enable explicit formatting for every 🤗 Diffusers' logger. The explicit formatter is as follows:
242 | ```
243 | [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
244 | ```
245 | All handlers currently bound to the root logger are affected by this method.
246 | """
247 | handlers = _get_library_root_logger().handlers
248 |
249 | for handler in handlers:
250 | formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
251 | handler.setFormatter(formatter)
252 |
253 |
254 | def reset_format() -> None:
255 | """
256 | Resets the formatting for 🤗 Diffusers' loggers.
257 |
258 | All handlers currently bound to the root logger are affected by this method.
259 | """
260 | handlers = _get_library_root_logger().handlers
261 |
262 | for handler in handlers:
263 | handler.setFormatter(None)
264 |
265 |
266 | def warning_advice(self, *args, **kwargs):
267 | """
268 | This method is identical to `logger.warning()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this
269 | warning will not be printed
270 | """
271 | no_advisory_warnings = os.getenv("DIFFUSERS_NO_ADVISORY_WARNINGS", False)
272 | if no_advisory_warnings:
273 | return
274 | self.warning(*args, **kwargs)
275 |
276 |
277 | logging.Logger.warning_advice = warning_advice
278 |
279 |
280 | class EmptyTqdm:
281 | """Dummy tqdm which doesn't do anything."""
282 |
283 | def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
284 | self._iterator = args[0] if args else None
285 |
286 | def __iter__(self):
287 | return iter(self._iterator)
288 |
289 | def __getattr__(self, _):
290 | """Return empty function."""
291 |
292 | def empty_fn(*args, **kwargs): # pylint: disable=unused-argument
293 | return
294 |
295 | return empty_fn
296 |
297 | def __enter__(self):
298 | return self
299 |
300 | def __exit__(self, type_, value, traceback):
301 | return
302 |
303 |
304 | class _tqdm_cls:
305 | def __call__(self, *args, **kwargs):
306 | if _tqdm_active:
307 | return tqdm_lib.tqdm(*args, **kwargs)
308 | else:
309 | return EmptyTqdm(*args, **kwargs)
310 |
311 | def set_lock(self, *args, **kwargs):
312 | self._lock = None
313 | if _tqdm_active:
314 | return tqdm_lib.tqdm.set_lock(*args, **kwargs)
315 |
316 | def get_lock(self):
317 | if _tqdm_active:
318 | return tqdm_lib.tqdm.get_lock()
319 |
320 |
321 | tqdm = _tqdm_cls()
322 |
323 |
324 | def is_progress_bar_enabled() -> bool:
325 | """Return a boolean indicating whether tqdm progress bars are enabled."""
326 | global _tqdm_active
327 | return bool(_tqdm_active)
328 |
329 |
330 | def enable_progress_bar():
331 | """Enable tqdm progress bar."""
332 | global _tqdm_active
333 | _tqdm_active = True
334 |
335 |
336 | def disable_progress_bar():
337 | """Disable tqdm progress bar."""
338 | global _tqdm_active
339 | _tqdm_active = False
340 |
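# Progress bars created through the module-level `tqdm` wrapper above can be
# toggled globally; a short usage sketch:
#
#     from diffusers.utils import logging
#
#     logging.disable_progress_bar()
#     assert not logging.is_progress_bar_enabled()
#     logging.enable_progress_bar()
#
# Verbosity can likewise be preset before import through the DIFFUSERS_VERBOSITY
# environment variable, e.g. `DIFFUSERS_VERBOSITY=error python main.py`.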
--------------------------------------------------------------------------------
/diffusers_official/models/unet_1d.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from dataclasses import dataclass
16 | from typing import Optional, Tuple, Union
17 |
18 | import torch
19 | import torch.nn as nn
20 |
21 | from ..configuration_utils import ConfigMixin, register_to_config
22 | from ..utils import BaseOutput
23 | from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
24 | from .modeling_utils import ModelMixin
25 | from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
26 |
27 |
28 | @dataclass
29 | class UNet1DOutput(BaseOutput):
30 | """
31 | The output of [`UNet1DModel`].
32 |
33 | Args:
34 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
35 | The hidden states output from the last layer of the model.
36 | """
37 |
38 | sample: torch.FloatTensor
39 |
40 |
41 | class UNet1DModel(ModelMixin, ConfigMixin):
42 | r"""
43 | A 1D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
44 |
45 |     This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
46 |     for all models (such as downloading or saving).
47 |
48 | Parameters:
49 | sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime.
50 | in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
51 | out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
52 | extra_in_channels (`int`, *optional*, defaults to 0):
53 | Number of additional channels to be added to the input of the first down block. Useful for cases where the
54 | input data has more channels than what the model was initially designed for.
55 | time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use.
56 | freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for Fourier time embedding.
57 | flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
58 | Whether to flip sin to cos for Fourier time embedding.
59 |         down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D")`):
60 |             Tuple of downsample block types.
61 |         up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip")`):
62 |             Tuple of upsample block types.
63 | block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`):
64 | Tuple of block output channels.
65 | mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for middle of UNet.
66 | out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of UNet.
67 | act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks.
68 | norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization.
69 | layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block.
70 |         downsample_each_block (`bool`, *optional*, defaults to `False`):
71 | Experimental feature for using a UNet without upsampling.
72 | """
73 |
74 | @register_to_config
75 | def __init__(
76 | self,
77 | sample_size: int = 65536,
78 | sample_rate: Optional[int] = None,
79 | in_channels: int = 2,
80 | out_channels: int = 2,
81 | extra_in_channels: int = 0,
82 | time_embedding_type: str = "fourier",
83 | flip_sin_to_cos: bool = True,
84 | use_timestep_embedding: bool = False,
85 | freq_shift: float = 0.0,
86 | down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
87 | up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
88 |         mid_block_type: str = "UNetMidBlock1D",
89 |         out_block_type: Optional[str] = None,
90 |         block_out_channels: Tuple[int] = (32, 32, 64),
91 |         act_fn: Optional[str] = None,
92 | norm_num_groups: int = 8,
93 | layers_per_block: int = 1,
94 | downsample_each_block: bool = False,
95 | ):
96 | super().__init__()
97 | self.sample_size = sample_size
98 |
99 | # time
100 | if time_embedding_type == "fourier":
101 | self.time_proj = GaussianFourierProjection(
102 | embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
103 | )
104 | timestep_input_dim = 2 * block_out_channels[0]
105 | elif time_embedding_type == "positional":
106 | self.time_proj = Timesteps(
107 | block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
108 | )
109 | timestep_input_dim = block_out_channels[0]
110 |
111 | if use_timestep_embedding:
112 | time_embed_dim = block_out_channels[0] * 4
113 | self.time_mlp = TimestepEmbedding(
114 | in_channels=timestep_input_dim,
115 | time_embed_dim=time_embed_dim,
116 | act_fn=act_fn,
117 | out_dim=block_out_channels[0],
118 | )
119 |
120 | self.down_blocks = nn.ModuleList([])
121 | self.mid_block = None
122 | self.up_blocks = nn.ModuleList([])
123 | self.out_block = None
124 |
125 | # down
126 | output_channel = in_channels
127 | for i, down_block_type in enumerate(down_block_types):
128 | input_channel = output_channel
129 | output_channel = block_out_channels[i]
130 |
131 | if i == 0:
132 | input_channel += extra_in_channels
133 |
134 | is_final_block = i == len(block_out_channels) - 1
135 |
136 | down_block = get_down_block(
137 | down_block_type,
138 | num_layers=layers_per_block,
139 | in_channels=input_channel,
140 | out_channels=output_channel,
141 | temb_channels=block_out_channels[0],
142 | add_downsample=not is_final_block or downsample_each_block,
143 | )
144 | self.down_blocks.append(down_block)
145 |
146 | # mid
147 | self.mid_block = get_mid_block(
148 | mid_block_type,
149 | in_channels=block_out_channels[-1],
150 | mid_channels=block_out_channels[-1],
151 | out_channels=block_out_channels[-1],
152 | embed_dim=block_out_channels[0],
153 | num_layers=layers_per_block,
154 | add_downsample=downsample_each_block,
155 | )
156 |
157 | # up
158 | reversed_block_out_channels = list(reversed(block_out_channels))
159 | output_channel = reversed_block_out_channels[0]
160 | if out_block_type is None:
161 | final_upsample_channels = out_channels
162 | else:
163 | final_upsample_channels = block_out_channels[0]
164 |
165 | for i, up_block_type in enumerate(up_block_types):
166 | prev_output_channel = output_channel
167 | output_channel = (
168 | reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels
169 | )
170 |
171 | is_final_block = i == len(block_out_channels) - 1
172 |
173 | up_block = get_up_block(
174 | up_block_type,
175 | num_layers=layers_per_block,
176 | in_channels=prev_output_channel,
177 | out_channels=output_channel,
178 | temb_channels=block_out_channels[0],
179 | add_upsample=not is_final_block,
180 | )
181 | self.up_blocks.append(up_block)
182 | prev_output_channel = output_channel
183 |
184 | # out
185 | num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
186 | self.out_block = get_out_block(
187 | out_block_type=out_block_type,
188 | num_groups_out=num_groups_out,
189 | embed_dim=block_out_channels[0],
190 | out_channels=out_channels,
191 | act_fn=act_fn,
192 | fc_dim=block_out_channels[-1] // 4,
193 | )
194 |
195 | def forward(
196 | self,
197 | sample: torch.FloatTensor,
198 | timestep: Union[torch.Tensor, float, int],
199 | return_dict: bool = True,
200 | ) -> Union[UNet1DOutput, Tuple]:
201 | r"""
202 | The [`UNet1DModel`] forward method.
203 |
204 | Args:
205 | sample (`torch.FloatTensor`):
206 | The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`.
207 |             timestep (`torch.FloatTensor` or `float` or `int`): The current denoising timestep.
208 | return_dict (`bool`, *optional*, defaults to `True`):
209 | Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.
210 |
211 | Returns:
212 | [`~models.unet_1d.UNet1DOutput`] or `tuple`:
213 | If `return_dict` is True, an [`~models.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is
214 | returned where the first element is the sample tensor.
215 | """
216 |
217 | # 1. time
218 | timesteps = timestep
219 | if not torch.is_tensor(timesteps):
220 | timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
221 | elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
222 | timesteps = timesteps[None].to(sample.device)
223 |
224 | timestep_embed = self.time_proj(timesteps)
225 | if self.config.use_timestep_embedding:
226 | timestep_embed = self.time_mlp(timestep_embed)
227 | else:
228 | timestep_embed = timestep_embed[..., None]
229 | timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
230 | timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:]))
231 |
232 | # 2. down
233 | down_block_res_samples = ()
234 | for downsample_block in self.down_blocks:
235 | sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed)
236 | down_block_res_samples += res_samples
237 |
238 | # 3. mid
239 | if self.mid_block:
240 | sample = self.mid_block(sample, timestep_embed)
241 |
242 | # 4. up
243 | for i, upsample_block in enumerate(self.up_blocks):
244 | res_samples = down_block_res_samples[-1:]
245 | down_block_res_samples = down_block_res_samples[:-1]
246 | sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed)
247 |
248 | # 5. post-process
249 | if self.out_block:
250 | sample = self.out_block(sample, timestep_embed)
251 |
252 | if not return_dict:
253 | return (sample,)
254 |
255 | return UNet1DOutput(sample=sample)
256 |
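# A small smoke-test sketch. Assumption flagged: with the default "fourier" time
# embedding, GaussianFourierProjection(embedding_size=8) contributes 16 channels
# that the first down block concatenates onto the input, so `extra_in_channels=16`
# is required (this mirrors the dance-diffusion config; it is not the class default).
if __name__ == "__main__":
    model = UNet1DModel(extra_in_channels=16)
    sample = torch.randn(1, 2, 256)  # (batch, in_channels, sample_size); length divisible by 4 for the two downsamples
    out = model(sample, timestep=10).sample
    assert out.shape == (1, 2, 256)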
--------------------------------------------------------------------------------