├── FramePainter_node.py ├── LICENSE ├── README.md ├── __init__.py ├── example.png ├── examples ├── 635499.jpg ├── 752538.jpg ├── __init__.py └── anime1.png ├── intro_teaser.png ├── modules ├── attention_processors.py ├── pipelines │ └── pipeline_framepainter.py ├── sparse_control_encoder.py ├── transformer_temporal.py ├── unet_spatio_temporal_condition_edit.py └── utils │ ├── __init__.py │ ├── attention_utils.py │ └── scheduling_euler_discrete_karras_fix.py ├── node_utils.py ├── pyproject.toml ├── requirements.txt └── svd_repo ├── feature_extractor └── preprocessor_config.json ├── model_index.json ├── scheduler └── scheduler_config.json ├── unet └── config.json └── vae └── config.json /FramePainter_node.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | import os 4 | import torch 5 | import gc 6 | import numpy as np 7 | from diffusers import AutoencoderKLTemporalDecoder 8 | from diffusers.schedulers import EulerDiscreteScheduler 9 | from omegaconf import OmegaConf 10 | from PIL import Image 11 | from diffusers import AutoencoderKLTemporalDecoder, EulerDiscreteScheduler 12 | from safetensors.torch import load_file 13 | from .modules.pipelines.pipeline_framepainter import FramePainterPipeline 14 | from .modules.sparse_control_encoder import SparseControlEncoder 15 | from .modules.unet_spatio_temporal_condition_edit import UNetSpatioTemporalConditionEdit 16 | from .modules.attention_processors import MatchingAttnProcessor2_0 17 | from .modules.utils.attention_utils import set_matching_attention, set_matching_attention_processor 18 | from .node_utils import process_image_with_mask,timer,pil2narry,tensor_upscale,convert_cf2diffuser 19 | import folder_paths 20 | 21 | 22 | MAX_SEED = np.iinfo(np.int32).max 23 | current_node_path = os.path.dirname(os.path.abspath(__file__)) 24 | device = torch.device( 25 | "cuda") if torch.cuda.is_available() else torch.device("cpu") 26 | 27 | # add checkpoints dir 28 | FramePainter_weigths_path = os.path.join(folder_paths.models_dir, "FramePainter") 29 | if not os.path.exists(FramePainter_weigths_path): 30 | os.makedirs(FramePainter_weigths_path) 31 | folder_paths.add_model_folder_path("FramePainter", FramePainter_weigths_path) 32 | 33 | 34 | 35 | class FramePainter_Loader: 36 | def __init__(self): 37 | pass 38 | 39 | @classmethod 40 | def INPUT_TYPES(s): 41 | FramePainter_unet_list = [i for i in folder_paths.get_filename_list("FramePainter") if 42 | "unet" in i] 43 | sparse_control_ckpt_list = [i for i in folder_paths.get_filename_list("FramePainter") if 44 | "encoder" in i] 45 | return { 46 | "required": { 47 | "model": ("MODEL",), 48 | "FramePainter_unet": (["none"] + FramePainter_unet_list,), 49 | "sparse_control_ckpt": (["none"] + sparse_control_ckpt_list,), 50 | }, 51 | } 52 | 53 | RETURN_TYPES = ("MODEL_FramePainter",) 54 | RETURN_NAMES = ("model",) 55 | FUNCTION = "loader_main" 56 | CATEGORY = "FramePainter" 57 | 58 | def loader_main(self,model,FramePainter_unet,sparse_control_ckpt): 59 | 60 | if FramePainter_unet!="none": 61 | FramePainter_unet=folder_paths.get_full_path("FramePainter", FramePainter_unet) 62 | else: 63 | raise ValueError("Please select a valid FramePainter_unet.") 64 | if sparse_control_ckpt!="none": 65 | sparse_control_ckpt=folder_paths.get_full_path("FramePainter", sparse_control_ckpt) 66 | else: 67 | raise ValueError("Please select a valid sparse_control_ckpt.") 68 | 69 | svd_repo = os.path.join(current_node_path, "svd_repo") 70 | # load 
model 71 | print("***********Load model ***********") 72 | unet_config_file=os.path.join(svd_repo, "unet") 73 | unet=convert_cf2diffuser(model.model,unet_config_file) 74 | 75 | # unet = UNetSpatioTemporalConditionEdit.from_pretrained( 76 | # svd_repo, 77 | # subfolder="unet", 78 | # low_cpu_mem_usage=True, 79 | # variant="fp16" 80 | # ) 81 | 82 | sparse_control_encoder = SparseControlEncoder() 83 | 84 | # vae = AutoencoderKLTemporalDecoder.from_pretrained( 85 | # svd_repo, subfolder="vae",variant="fp16") 86 | vae_config=os.path.join(svd_repo, "vae/config.json") 87 | vae_config=OmegaConf.load(vae_config) 88 | noise_scheduler = EulerDiscreteScheduler.from_pretrained( 89 | svd_repo, subfolder="scheduler") 90 | pipeline = FramePainterPipeline.from_pretrained( 91 | svd_repo, 92 | sparse_control_encoder=sparse_control_encoder, 93 | unet=unet, 94 | vae_config=vae_config, 95 | revision=None, 96 | noise_scheduler=noise_scheduler, 97 | #vae=vae, 98 | ) 99 | 100 | set_matching_attention(pipeline.unet) 101 | set_matching_attention_processor(pipeline.unet, MatchingAttnProcessor2_0(batch_size=2)) 102 | 103 | pipeline.set_progress_bar_config(disable=False) 104 | 105 | sparse_control_dict=load_file(sparse_control_ckpt) 106 | pipeline.sparse_control_encoder.load_state_dict(sparse_control_dict, strict=True) 107 | FramePainter_unet_dict=load_file(FramePainter_unet) 108 | pipeline.unet.load_state_dict(FramePainter_unet_dict, strict=True) 109 | 110 | print("***********Load model done ***********") 111 | del sparse_control_dict,FramePainter_unet_dict 112 | gc.collect() 113 | torch.cuda.empty_cache() 114 | return (pipeline,) 115 | 116 | 117 | 118 | class FramePainter_Sampler: 119 | def __init__(self): 120 | pass 121 | 122 | @classmethod 123 | def INPUT_TYPES(s): 124 | return { 125 | "required": { 126 | "model": ("MODEL_FramePainter",), 127 | "clip_vision": ("CLIP_VISION",), 128 | "vae": ("VAE",), 129 | "image": ("IMAGE",), # B H W C C=3 130 | "mask": ("MASK",), # B H W 131 | "seed": ("INT", {"default": 0, "min": 0, "max": MAX_SEED}), 132 | "steps": ("INT", {"default": 25, "min": 15, "max": 100, "step": 1, "display": "number"}), 133 | "guidance_scale":("FLOAT", {"default": 3.0, "min": 0.0, "max": 30.0,"step": 0.5}), 134 | "width": ("INT", {"default": 512, "min": 256, "max": 1920, "step": 64, "display": "number"}), 135 | "height": ("INT", {"default": 512, "min": 256, "max": 1920, "step": 64, "display": "number"}), 136 | "control_scale": ("FLOAT", {"default": 0.8, "min": 0.0, "max": 1.0,"step": 0.05}), 137 | }, 138 | } 139 | 140 | 141 | RETURN_TYPES = ("IMAGE",) 142 | RETURN_NAMES = ("image",) 143 | FUNCTION = "sampler_main" 144 | CATEGORY = "FramePainter" 145 | 146 | 147 | def sampler_main(self, model,clip_vision,vae,image,mask,seed,steps,guidance_scale,width,height,control_scale): 148 | 149 | model.to("cuda") 150 | #vae=kwargs.get("vae", None) 151 | image_embeds = clip_vision.encode_image(image)["image_embeds"] #torch.Size([1, 1024]) 152 | image_embeds=image_embeds.clone().detach().to(device, dtype=torch.float16) # dtype需要改成可选 153 | 154 | print("***********Start infer ***********") 155 | with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16), timer("inference"): 156 | 157 | input_image_resized,merged_image=process_image_with_mask(image,mask, width, height) 158 | 159 | validation_control_images = [ 160 | Image.new("RGB", (width, height), color=(0, 0, 0)), 161 | merged_image 162 | ] 163 | if vae is not None: 164 | validation_control_ = [np.array(pil2narry(image)) for image in 
validation_control_images] 165 | validation_control_ = [torch.from_numpy(img).to(device, dtype=torch.float16) for img in validation_control_] 166 | validation_control_ = torch.cat(validation_control_, dim=0) 167 | input_latents=vae.encode(tensor_upscale(image,width,height)).to(device, dtype=torch.float16) 168 | validation_control_images=validation_control_.unsqueeze(0).permute(0,1,4,2,3).to(device, dtype=torch.float16) 169 | negative_image_latents = torch.zeros_like(input_latents) 170 | input_latents=torch.cat([negative_image_latents, input_latents]) 171 | else: 172 | raise "vae is None" 173 | result = model( 174 | input_image_resized, 175 | validation_control_images, 176 | height=height, 177 | width=width, 178 | edit_cond_scale=control_scale, 179 | guidance_scale=guidance_scale, 180 | num_inference_steps=steps, 181 | generator=torch.Generator().manual_seed(seed), 182 | output_type="latent",#out lantents 183 | image_embs=image_embeds, 184 | input_latents=input_latents, 185 | 186 | ).frames[0], 187 | 188 | b,_,_,_=result[0].shape 189 | last= torch.chunk(result[0], chunks=b) 190 | image=vae.decode(last[-1]) 191 | gc.collect() 192 | torch.cuda.empty_cache() 193 | return (image,) 194 | #image = result[0][1] 195 | #return (pil2narry(image),out,) 196 | 197 | 198 | 199 | NODE_CLASS_MAPPINGS = { 200 | "FramePainter_Loader":FramePainter_Loader, 201 | "FramePainter_Sampler":FramePainter_Sampler, 202 | } 203 | 204 | NODE_DISPLAY_NAME_MAPPINGS = { 205 | "FramePainter_Loader":"FramePainter_Loader", 206 | "FramePainter_Sampler":"FramePainter_Sampler", 207 | } 208 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 smthemex 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI_FramePainter 2 | Official pytorch implementation of "[FramePainter](https://github.com/YBYBZhang/FramePainter): Endowing Interactive Image Editing with Video Diffusion Priors",you can use it in comfyUI 3 | 4 | # Update 5 | * use single checkpoint now 改成SVD单体模型加载方式 6 | * now 8G VRAM can run 512*512 峰值显存7G多,按理8G也能跑512了 7 | 8 | 9 | # 1. 
Installation
 10 | 
 11 | In the ./ComfyUI/custom_nodes directory, run the following:
 12 | ```
 13 | git clone https://github.com/smthemex/ComfyUI_FramePainter.git
 14 | ```
 15 | ---
 16 | 
 17 | # 2. Requirements
 18 | * Normally nothing extra is required, since the dependencies are standard for ComfyUI. If a library turns out to be missing, install the requirements:
 19 | ```
 20 | pip install -r requirements.txt
 21 | ```
 22 | 
 23 | # 3. Model
 24 | * 3.1 Download the required checkpoints from [here](https://huggingface.co/Yabo/FramePainter/tree/main) and place them as shown below:
 25 | ```
 26 | -- ComfyUI/models/FramePainter/
 27 | |-- unet_diffusion_pytorch_model.safetensors
 28 | |-- encoder_diffusion_pytorch_model.safetensors
 29 | ```
 30 | * 3.2 SVD checkpoint: [svd_xt.safetensors](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt) or [svd_xt_1_1.safetensors](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt-1-1), placed in the ComfyUI checkpoints folder:
 31 | 
 32 | ```
 33 | -- ComfyUI/models/checkpoints
 34 | ├── svd_xt.safetensors or svd_xt_1_1.safetensors
 35 | ```
 36 | 
 37 | 
 38 | # 4. Example
 39 | ![](https://github.com/smthemex/ComfyUI_FramePainter/blob/main/example.png)
 40 | 
 41 | # 5. Citation
 42 | [FramePainter](https://github.com/YBYBZhang/FramePainter)
 43 | 
 44 | * diffusers
 45 | ```
 46 | @misc{von-platen-etal-2022-diffusers,
 47 |   author = {Patrick von Platen and Suraj Patil and Anton Lozhkov and Pedro Cuenca and Nathan Lambert and Kashif Rasul and Mishig Davaadorj and Dhruv Nair and Sayak Paul and William Berman and Yiyi Xu and Steven Liu and Thomas Wolf},
 48 |   title = {Diffusers: State-of-the-art diffusion models},
 49 |   year = {2022},
 50 |   publisher = {GitHub},
 51 |   journal = {GitHub repository},
 52 |   howpublished = {\url{https://github.com/huggingface/diffusers}}
 53 | }
 54 | ```
 55 | * controlnext
 56 | ```
 57 | @article{peng2024controlnext,
 58 |   title={ControlNeXt: Powerful and Efficient Control for Image and Video Generation},
 59 |   author={Peng, Bohao and Wang, Jian and Zhang, Yuechen and Li, Wenbo and Yang, Ming-Chang and Jia, Jiaya},
 60 |   journal={arXiv preprint arXiv:2408.06070},
 61 |   year={2024}
 62 | }
 63 | ```
 64 | 
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | from .FramePainter_node import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
3 | 
4 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
5 | 
--------------------------------------------------------------------------------
/example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smthemex/ComfyUI_FramePainter/ad6193bb88b317f9e49d0d711fae5102a53d6c47/example.png
--------------------------------------------------------------------------------
/examples/635499.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smthemex/ComfyUI_FramePainter/ad6193bb88b317f9e49d0d711fae5102a53d6c47/examples/635499.jpg
--------------------------------------------------------------------------------
/examples/752538.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smthemex/ComfyUI_FramePainter/ad6193bb88b317f9e49d0d711fae5102a53d6c47/examples/752538.jpg
--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
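
A quick way to verify the checkpoint layout described in the README above is to list the `ComfyUI/models/FramePainter` folder the same way `FramePainter_Loader` does, i.e. by filtering filenames for the substrings "unet" and "encoder". This is only an illustrative sketch; the models root path below is an assumption and should be adjusted to your installation.

```python
import os

# Assumed ComfyUI models root; adjust to your installation.
models_dir = os.path.expanduser("~/ComfyUI/models")
framepainter_dir = os.path.join(models_dir, "FramePainter")

files = os.listdir(framepainter_dir) if os.path.isdir(framepainter_dir) else []
# FramePainter_Loader selects checkpoints by these filename substrings.
unet_ckpts = [f for f in files if "unet" in f]        # expect unet_diffusion_pytorch_model.safetensors
encoder_ckpts = [f for f in files if "encoder" in f]  # expect encoder_diffusion_pytorch_model.safetensors

print("FramePainter unet checkpoints:", unet_ckpts)
print("Sparse control encoder checkpoints:", encoder_ckpts)
```
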
-------------------------------------------------------------------------------- /examples/anime1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_FramePainter/ad6193bb88b317f9e49d0d711fae5102a53d6c47/examples/anime1.png -------------------------------------------------------------------------------- /intro_teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_FramePainter/ad6193bb88b317f9e49d0d711fae5102a53d6c47/intro_teaser.png -------------------------------------------------------------------------------- /modules/attention_processors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def rearrange_3(tensor, f): 6 | F, D, C = tensor.size() 7 | return torch.reshape(tensor, (F // f, f, D, C)) 8 | 9 | 10 | def rearrange_4(tensor): 11 | B, F, D, C = tensor.size() 12 | return torch.reshape(tensor, (B * F, D, C)) 13 | 14 | class MatchingAttnProcessor2_0: 15 | """ 16 | Matching attention processor with scaled_dot_product attention of Pytorch 2.0. 17 | 18 | Args: 19 | batch_size: The number that represents actual batch size, other than the frames. 20 | For example, calling unet with a single prompt and num_images_per_prompt=1, batch_size should be equal to 21 | 2, due to classifier-free guidance. 22 | """ 23 | 24 | def __init__(self, batch_size=2): 25 | if not hasattr(F, "scaled_dot_product_attention"): 26 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") 27 | self.batch_size = batch_size 28 | self.attention_weights = None 29 | 30 | def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None): 31 | batch_size, sequence_length, _ = ( 32 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 33 | ) 34 | inner_dim = hidden_states.shape[-1] 35 | 36 | if attention_mask is not None: 37 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 38 | # scaled_dot_product_attention expects attention_mask shape to be 39 | # (batch, heads, source_length, target_length) 40 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) 41 | 42 | query = attn.to_q(hidden_states) 43 | 44 | is_cross_attention = encoder_hidden_states is not None 45 | if encoder_hidden_states is None: 46 | encoder_hidden_states = hidden_states 47 | elif attn.norm_cross: 48 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 49 | 50 | key = attn.to_k(encoder_hidden_states) 51 | value = attn.to_v(encoder_hidden_states) 52 | 53 | # Cross Frame Attention 54 | if not is_cross_attention: 55 | video_length = max(1, key.size()[0] // self.batch_size) 56 | first_frame_index = [0] * video_length 57 | 58 | # rearrange keys to have batch and frames in the 1st and 2nd dims respectively 59 | key = rearrange_3(key, video_length) 60 | key = key[:, first_frame_index] 61 | # rearrange values to have batch and frames in the 1st and 2nd dims respectively 62 | value = rearrange_3(value, video_length) 63 | value = value[:, first_frame_index] 64 | 65 | # rearrange back to original shape 66 | key = rearrange_4(key) 67 | value = rearrange_4(value) 68 | 69 | head_dim = inner_dim // attn.heads 70 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 71 | key = 
key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 72 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 73 | 74 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 75 | # TODO: add support for attn.scale when we move to Torch 2.1 76 | hidden_states = F.scaled_dot_product_attention( 77 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 78 | ) 79 | 80 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 81 | hidden_states = hidden_states.to(query.dtype) 82 | 83 | # linear proj 84 | hidden_states = attn.to_out[0](hidden_states) 85 | # dropout 86 | hidden_states = attn.to_out[1](hidden_states) 87 | 88 | return hidden_states 89 | -------------------------------------------------------------------------------- /modules/pipelines/pipeline_framepainter.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataclasses import dataclass 16 | from typing import Callable, Dict, List, Optional, Union 17 | 18 | import numpy as np 19 | import PIL.Image 20 | import torch 21 | from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection 22 | 23 | from diffusers.image_processor import VaeImageProcessor 24 | from diffusers.models import AutoencoderKLTemporalDecoder 25 | from diffusers.utils import BaseOutput, logging 26 | from diffusers.utils.torch_utils import randn_tensor 27 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 28 | 29 | from ..sparse_control_encoder import SparseControlEncoder 30 | from ..unet_spatio_temporal_condition_edit import UNetSpatioTemporalConditionEdit 31 | from ..utils.scheduling_euler_discrete_karras_fix import EulerDiscreteScheduler 32 | 33 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 34 | 35 | 36 | def _get_add_time_ids( 37 | noise_aug_strength, 38 | dtype, 39 | batch_size, 40 | fps=4, 41 | motion_bucket_id=128, 42 | unet=None, 43 | ): 44 | add_time_ids = [fps, motion_bucket_id, noise_aug_strength] 45 | 46 | passed_add_embed_dim = unet.config.addition_time_embed_dim * len(add_time_ids) 47 | expected_add_embed_dim = unet.add_embedding.linear_1.in_features 48 | 49 | if expected_add_embed_dim != passed_add_embed_dim: 50 | raise ValueError( 51 | f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
52 | ) 53 | 54 | add_time_ids = torch.tensor([add_time_ids], dtype=dtype) 55 | # add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1) 56 | 57 | 58 | return add_time_ids 59 | 60 | 61 | def tensor2vid(video: torch.Tensor, processor, output_type="np"): 62 | # Based on: 63 | # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 64 | 65 | batch_size, channels, num_frames, height, width = video.shape 66 | outputs = [] 67 | for batch_idx in range(batch_size): 68 | batch_vid = video[batch_idx].permute(1, 0, 2, 3) 69 | batch_output = processor.postprocess(batch_vid, output_type) 70 | 71 | outputs.append(batch_output) 72 | 73 | return outputs 74 | 75 | 76 | @dataclass 77 | class FramePainterPipelineOutput(BaseOutput): 78 | r""" 79 | Output class for sketch-based image editing pipeline. 80 | 81 | Args: 82 | frames (`[List[PIL.Image.Image]`, `np.ndarray`]): 83 | List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, 84 | num_channels)`. 85 | """ 86 | 87 | frames: Union[List[PIL.Image.Image], np.ndarray] 88 | 89 | 90 | class FramePainterPipeline(DiffusionPipeline): 91 | r""" 92 | Pipeline to edit images from an input image and a sketch image using FramePainter. 93 | 94 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods 95 | implemented for all pipelines (downloading, saving, running on a particular device, etc.). 96 | 97 | Args: 98 | vae ([`AutoencoderKL`]): 99 | Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. 100 | image_encoder ([`~transformers.CLIPVisionModelWithProjection`]): 101 | Frozen CLIP image-encoder ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)). 102 | unet ([`UNetSpatioTemporalConditionEdit`]): 103 | A `UNetSpatioTemporalConditionEdit` to denoise the encoded image latents. 104 | scheduler ([`EulerDiscreteScheduler`]): 105 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. 106 | feature_extractor ([`~transformers.CLIPImageProcessor`]): 107 | A `CLIPImageProcessor` to extract features from generated images. 
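        Example:
            An illustrative sketch of how this pipeline is assembled in `FramePainter_node.py`.
            Paths are placeholders: `svd_repo` is the config-only folder bundled with this node,
            the UNet weights are assumed to come from a full Stable Video Diffusion download
            (inside ComfyUI the loader instead converts the checkpoint already loaded by ComfyUI
            via `convert_cf2diffuser`), and the `modules` package is assumed to be importable
            from this repository:

            ```py
            from omegaconf import OmegaConf
            from safetensors.torch import load_file

            from modules.pipelines.pipeline_framepainter import FramePainterPipeline
            from modules.sparse_control_encoder import SparseControlEncoder
            from modules.unet_spatio_temporal_condition_edit import UNetSpatioTemporalConditionEdit
            from modules.attention_processors import MatchingAttnProcessor2_0
            from modules.utils.attention_utils import set_matching_attention, set_matching_attention_processor

            # SVD backbone weights for the UNet; FramePainter weights are loaded on top below.
            unet = UNetSpatioTemporalConditionEdit.from_pretrained(
                "path/to/stable-video-diffusion-img2vid-xt", subfolder="unet", variant="fp16"
            )
            sparse_control_encoder = SparseControlEncoder()
            vae_config = OmegaConf.load("svd_repo/vae/config.json")

            pipeline = FramePainterPipeline.from_pretrained(
                "svd_repo",
                unet=unet,
                sparse_control_encoder=sparse_control_encoder,
                vae_config=vae_config,
            )

            # enable FramePainter's matching (cross-frame) attention
            set_matching_attention(pipeline.unet)
            set_matching_attention_processor(pipeline.unet, MatchingAttnProcessor2_0(batch_size=2))

            # load the FramePainter checkpoints released on Hugging Face
            pipeline.sparse_control_encoder.load_state_dict(
                load_file("encoder_diffusion_pytorch_model.safetensors"), strict=True
            )
            pipeline.unet.load_state_dict(
                load_file("unet_diffusion_pytorch_model.safetensors"), strict=True
            )
            ```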
108 | """ 109 | 110 | #model_cpu_offload_seq = "image_encoder->unet->vae" 111 | #model_cpu_offload_seq = "unet->vae" 112 | model_cpu_offload_seq = "unet" 113 | _callback_tensor_inputs = ["latents"] 114 | 115 | def __init__( 116 | self, 117 | #vae: AutoencoderKLTemporalDecoder, 118 | #image_encoder: CLIPVisionModelWithProjection, 119 | unet: UNetSpatioTemporalConditionEdit, 120 | sparse_control_encoder: SparseControlEncoder, 121 | scheduler: EulerDiscreteScheduler, 122 | feature_extractor: CLIPImageProcessor, 123 | vae_config=None, 124 | ): 125 | super().__init__() 126 | 127 | self.register_modules( 128 | #vae=vae, 129 | # image_encoder=image_encoder, 130 | sparse_control_encoder=sparse_control_encoder, 131 | unet=unet, 132 | scheduler=scheduler, 133 | feature_extractor=feature_extractor, 134 | ) 135 | self.vae_config = vae_config 136 | self.vae_scale_factor = 2 ** (len(self.vae_config.block_out_channels) - 1) 137 | self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) 138 | 139 | def _encode_image(self, image_embs, device, num_videos_per_prompt, do_classifier_free_guidance): 140 | #dtype = next(self.image_encoder.parameters()).dtype 141 | #dtype=image_embs.dtype 142 | # if not isinstance(image, torch.Tensor): 143 | # image = self.image_processor.pil_to_numpy(image) 144 | # image = self.image_processor.numpy_to_pt(image) 145 | 146 | # # We normalize the image before resizing to match with the original implementation. 147 | # # Then we unnormalize it after resizing. 148 | # image = image * 2.0 - 1.0 149 | # image = _resize_with_antialiasing(image, (224, 224)) 150 | # image = (image + 1.0) / 2.0 151 | 152 | # # Normalize the image with for CLIP input 153 | # image = self.feature_extractor( 154 | # images=image, 155 | # do_normalize=True, 156 | # do_center_crop=False, 157 | # do_resize=False, 158 | # do_rescale=False, 159 | # return_tensors="pt", 160 | # ).pixel_values 161 | 162 | #image = image.to(device=device, dtype=dtype) 163 | #image_embeddings = self.image_encoder(image).image_embeds 164 | image_embeddings = image_embs.unsqueeze(1) 165 | 166 | # duplicate image embeddings for each generation per prompt, using mps friendly method 167 | bs_embed, seq_len, _ = image_embeddings.shape 168 | image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1) 169 | image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) 170 | 171 | if do_classifier_free_guidance: 172 | negative_image_embeddings = torch.zeros_like(image_embeddings) 173 | 174 | # For classifier free guidance, we need to do two forward passes. 175 | # Here we concatenate the unconditional and text embeddings into a single batch 176 | # to avoid doing two forward passes 177 | image_embeddings = torch.cat([negative_image_embeddings, image_embeddings]) 178 | 179 | return image_embeddings 180 | 181 | def _encode_vae_image( 182 | self, 183 | image: torch.Tensor, 184 | device, 185 | num_videos_per_prompt, 186 | do_classifier_free_guidance, 187 | ): 188 | image = image.to(device=device) 189 | image_latents = self.vae.encode(image).latent_dist.mode() 190 | 191 | if do_classifier_free_guidance: 192 | negative_image_latents = torch.zeros_like(image_latents) 193 | 194 | # For classifier free guidance, we need to do two forward passes. 
195 | # Here we concatenate the unconditional and text embeddings into a single batch 196 | # to avoid doing two forward passes 197 | image_latents = torch.cat([negative_image_latents, image_latents]) 198 | 199 | # duplicate image_latents for each generation per prompt, using mps friendly method 200 | image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1) 201 | 202 | return image_latents 203 | 204 | def _get_add_time_ids( 205 | self, 206 | fps, 207 | motion_bucket_id, 208 | noise_aug_strength, 209 | dtype, 210 | batch_size, 211 | num_videos_per_prompt, 212 | do_classifier_free_guidance, 213 | ): 214 | add_time_ids = [fps, motion_bucket_id, noise_aug_strength] 215 | 216 | passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids) 217 | expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features 218 | 219 | if expected_add_embed_dim != passed_add_embed_dim: 220 | raise ValueError( 221 | f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 222 | ) 223 | 224 | add_time_ids = torch.tensor([add_time_ids], dtype=dtype) 225 | add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1) 226 | 227 | if do_classifier_free_guidance: 228 | add_time_ids = torch.cat([add_time_ids, add_time_ids]) 229 | 230 | return add_time_ids 231 | 232 | def decode_latents(self, latents, num_frames, decode_chunk_size=14): 233 | # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width] 234 | latents = latents.flatten(0, 1) 235 | 236 | latents = 1 / self.vae.config.scaling_factor * latents 237 | 238 | # decode decode_chunk_size frames at a time to avoid OOM 239 | frames = [] 240 | for i in range(0, latents.shape[0], decode_chunk_size): 241 | num_frames_in = latents[i : i + decode_chunk_size].shape[0] 242 | decode_kwargs = {} 243 | decode_kwargs["num_frames"] = num_frames_in 244 | frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample 245 | frames.append(frame) 246 | frames = torch.cat(frames, dim=0) 247 | 248 | # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width] 249 | frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4) 250 | 251 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 252 | frames = frames.float() 253 | return frames 254 | 255 | def check_inputs(self, image, height, width): 256 | if ( 257 | not isinstance(image, torch.Tensor) 258 | and not isinstance(image, PIL.Image.Image) 259 | and not isinstance(image, list) 260 | ): 261 | raise ValueError( 262 | "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" 263 | f" {type(image)}" 264 | ) 265 | 266 | if height % 8 != 0 or width % 8 != 0: 267 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 268 | 269 | def prepare_latents( 270 | self, 271 | batch_size, 272 | num_frames, 273 | num_channels_latents, 274 | height, 275 | width, 276 | dtype, 277 | device, 278 | generator, 279 | latents=None, 280 | ): 281 | shape = ( 282 | batch_size, 283 | num_frames, 284 | num_channels_latents // 2, 285 | height // self.vae_scale_factor, 286 | width // self.vae_scale_factor, 287 | ) 288 | if isinstance(generator, list) and len(generator) != 
batch_size: 289 | raise ValueError( 290 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 291 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 292 | ) 293 | 294 | if latents is None: 295 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 296 | else: 297 | latents = latents.to(device) 298 | 299 | # scale the initial noise by the standard deviation required by the scheduler 300 | latents = latents * self.scheduler.init_noise_sigma 301 | return latents 302 | 303 | @property 304 | def guidance_scale(self): 305 | return self._guidance_scale 306 | 307 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 308 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 309 | # corresponds to doing no classifier free guidance. 310 | @property 311 | def do_classifier_free_guidance(self): 312 | return self._guidance_scale >= 1 and self.unet.config.time_cond_proj_dim is None 313 | 314 | @property 315 | def num_timesteps(self): 316 | return self._num_timesteps 317 | 318 | @torch.no_grad() 319 | def __call__( 320 | self, 321 | image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor], 322 | edit_condition:Optional[torch.FloatTensor] = None, 323 | height: int = 576, 324 | width: int = 1024, 325 | num_inference_steps: int = 25, 326 | guidance_scale = 3.0, 327 | fps: int = 7, 328 | motion_bucket_id: int = 127, 329 | noise_aug_strength: int = 0.02, 330 | decode_chunk_size: Optional[int] = 4, 331 | num_videos_per_prompt: Optional[int] = 1, 332 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 333 | latents: Optional[torch.FloatTensor] = None, 334 | output_type: Optional[str] = "pil", 335 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, 336 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 337 | return_dict: bool = True, 338 | edit_cond_scale=1.0, 339 | batch_size=1, 340 | guidance_scale_decay="inv_square", 341 | image_embs=None, 342 | input_latents=None, 343 | 344 | ): 345 | r""" 346 | The call function to the pipeline for generation. 347 | 348 | Args: 349 | image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): 350 | Image or images to guide image generation. If you provide a tensor, it needs to be compatible with 351 | [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json). 352 | height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): 353 | The height in pixels of the generated image. 354 | width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): 355 | The width in pixels of the generated image. 356 | num_frames (`int`, *optional*): 357 | The number of video frames to generate. Defaults to 14 for `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt` 358 | num_inference_steps (`int`, *optional*, defaults to 25): 359 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 360 | expense of slower inference. This parameter is modulated by `strength`. 361 | min_guidance_scale (`float`, *optional*, defaults to 1.0): 362 | The minimum guidance scale. Used for the classifier free guidance with first frame. 
363 | max_guidance_scale (`float`, *optional*, defaults to 3.0): 364 | The maximum guidance scale. Used for the classifier free guidance with last frame. 365 | fps (`int`, *optional*, defaults to 7): 366 | Frames per second. The rate at which the generated images shall be exported to a video after generation. 367 | Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training. 368 | motion_bucket_id (`int`, *optional*, defaults to 127): 369 | The motion bucket ID. Used as conditioning for the generation. The higher the number the more motion will be in the video. 370 | noise_aug_strength (`int`, *optional*, defaults to 0.02): 371 | The amount of noise added to the init image, the higher it is the less the video will look like the init image. Increase it for more motion. 372 | decode_chunk_size (`int`, *optional*): 373 | The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency 374 | between frames, but also the higher the memory consumption. By default, the decoder will decode all frames at once 375 | for maximal quality. Reduce `decode_chunk_size` to reduce memory usage. 376 | num_videos_per_prompt (`int`, *optional*, defaults to 1): 377 | The number of images to generate per prompt. 378 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 379 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make 380 | generation deterministic. 381 | latents (`torch.FloatTensor`, *optional*): 382 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image 383 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 384 | tensor is generated by sampling using the supplied random `generator`. 385 | output_type (`str`, *optional*, defaults to `"pil"`): 386 | The output format of the generated image. Choose between `PIL.Image` or `np.array`. 387 | callback_on_step_end (`Callable`, *optional*): 388 | A function that calls at the end of each denoising steps during the inference. The function is called 389 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, 390 | callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by 391 | `callback_on_step_end_tensor_inputs`. 392 | callback_on_step_end_tensor_inputs (`List`, *optional*): 393 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list 394 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the 395 | `._callback_tensor_inputs` attribute of your pipeline class. 396 | return_dict (`bool`, *optional*, defaults to `True`): 397 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a 398 | plain tuple. 399 | 400 | Returns: 401 | [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`: 402 | If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is returned, 403 | otherwise a `tuple` is returned where the first element is a list of list with the generated frames. 
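            Note (specific to this repository): with `return_dict=True` the pipeline returns a
            `FramePainterPipelineOutput` (defined above). When `output_type="latent"`, as the
            ComfyUI sampler node does, its `frames` field holds the raw latents divided by the
            hard-coded VAE scaling factor 0.18215 instead of decoded images.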
404 | 405 | Examples: 406 | 407 | ```py 408 | from diffusers import StableVideoDiffusionPipeline 409 | from diffusers.utils import load_image, export_to_video 410 | 411 | pipe = StableVideoDiffusionPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16") 412 | pipe.to("cuda") 413 | 414 | image = load_image("https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200") 415 | image = image.resize((1024, 576)) 416 | 417 | frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0] 418 | export_to_video(frames, "generated.mp4", fps=7) 419 | ``` 420 | """ 421 | # 0. Default height and width to unet 422 | height = height or self.unet.config.sample_size * self.vae_scale_factor 423 | width = width or self.unet.config.sample_size * self.vae_scale_factor 424 | 425 | num_frames = 2 426 | 427 | # 1. Check inputs. Raise error if not correct 428 | self.check_inputs(image, height, width) 429 | 430 | device = self._execution_device 431 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 432 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 433 | # corresponds to doing no classifier free guidance. 434 | do_classifier_free_guidance = guidance_scale >= 1.0 435 | 436 | # 3. Encode input image 437 | image_embeddings = self._encode_image(image_embs, device, num_videos_per_prompt, do_classifier_free_guidance) 438 | 439 | # NOTE: Stable Diffusion Video was conditioned on fps - 1, which 440 | # is why it is reduced here. 441 | # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188 442 | fps = fps - 1 443 | 444 | # 4. Encode input image using VAE 445 | if input_latents is None: 446 | image = self.image_processor.preprocess(image, height=height, width=width) 447 | noise = randn_tensor(image.shape, generator=generator, device=image.device, dtype=image.dtype) 448 | image = image + noise_aug_strength * noise # 449 | 450 | needs_upcasting = (self.vae.dtype == torch.float16 or self.vae.dtype == torch.bfloat16) and self.vae.config.force_upcast 451 | if needs_upcasting: 452 | self_vae_dtype = self.vae.dtype 453 | self.vae.to(dtype=torch.float32) 454 | 455 | image_latents = self._encode_vae_image(image, device, num_videos_per_prompt, do_classifier_free_guidance) 456 | else: 457 | image_latents = input_latents 458 | image_latents = image_latents.to(image_embeddings.dtype) 459 | #print(image_latents.shape,123)#torch.Size([2, 4, 64, 64]) 460 | # cast back to fp16 if needed 461 | if input_latents is None and needs_upcasting: 462 | self.vae.to(dtype=self_vae_dtype) 463 | 464 | # Repeat the image latents for each frame so we can concatenate them with the noise 465 | # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width] 466 | image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1) 467 | 468 | # 5. Get Added Time IDs 469 | added_time_ids = self._get_add_time_ids( 470 | fps, 471 | motion_bucket_id, 472 | noise_aug_strength, 473 | image_embeddings.dtype, 474 | batch_size, 475 | num_videos_per_prompt, 476 | do_classifier_free_guidance, 477 | ) 478 | added_time_ids = added_time_ids.to(device) 479 | 480 | # 4. Prepare timesteps 481 | self.scheduler.set_timesteps(num_inference_steps, device=device) 482 | timesteps = self.scheduler.timesteps 483 | 484 | # 5. 
Prepare latent variables 485 | num_channels_latents = self.unet.config.in_channels 486 | 487 | latents = self.prepare_latents( 488 | batch_size * num_videos_per_prompt, 489 | num_frames, 490 | num_channels_latents, 491 | height, 492 | width, 493 | image_embeddings.dtype, 494 | device, 495 | generator, 496 | latents, 497 | ) 498 | 499 | # prepare edit condition 500 | if not isinstance(edit_condition, torch.Tensor): 501 | edit_condition = self.image_processor.preprocess(edit_condition, height=height, width=width) 502 | edit_condition = (edit_condition + 1.0) / 2 503 | edit_condition = edit_condition.unsqueeze(0) 504 | #print(edit_condition.shape,234) #torch.Size([1, 2, 3, 512, 512]) 505 | if do_classifier_free_guidance: 506 | edit_condition = torch.cat([edit_condition] * 2) 507 | #print(edit_condition.shape,2345) #torch.Size([2, 2, 3, 512, 512]) 508 | edit_condition = edit_condition.to(device, latents.dtype) 509 | 510 | self._guidance_scale = guidance_scale 511 | 512 | noise_aug_strength = 0.02 #"¯\_(ツ)_/¯ 513 | added_time_ids = _get_add_time_ids( 514 | noise_aug_strength, 515 | image_embeddings.dtype, 516 | batch_size, 517 | 6, 518 | 128, 519 | unet=self.unet, 520 | ) 521 | if do_classifier_free_guidance: 522 | added_time_ids = torch.cat([added_time_ids] * 2) 523 | added_time_ids = added_time_ids.to(latents.device) 524 | 525 | # 8. Denoising loop 526 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 527 | self._num_timesteps = len(timesteps) 528 | with self.progress_bar(total=num_inference_steps) as progress_bar: 529 | for i, t in enumerate(timesteps): 530 | # expand the latents if we are doing classifier free guidance 531 | 532 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents ##Size([2, 2, 4, 64, 64]) torch.Size([1, 2, 4, 64, 64]) 533 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 534 | 535 | edit_output = self.sparse_control_encoder( 536 | edit_condition, 537 | t, 538 | ) 539 | if do_classifier_free_guidance: 540 | N = edit_output['output'].shape[0] 541 | edit_output['scale'] = torch.tensor(edit_output['scale']).to(latent_model_input).repeat(N)[:, None, None, None] 542 | edit_output['scale'][:N // 2] *= 0 543 | edit_output['scale'][N // 2:] *= edit_cond_scale 544 | 545 | # Concatenate image_latents over channels dimention 546 | #print(latent_model_input.shape,123,image_latents.shape) #([2,2, 4, 64, 64]) 123 torch.Size([2, 2, 4, 64, 64]) 547 | latent_model_input = torch.cat([latent_model_input, image_latents], dim=2) 548 | 549 | # predict the noise residual 550 | noise_pred = self.unet( 551 | latent_model_input, 552 | t, 553 | encoder_hidden_states=image_embeddings, 554 | added_time_ids=added_time_ids, 555 | conditional_controls=edit_output, 556 | return_dict=False, 557 | )[0] 558 | 559 | # perform guidance 560 | if guidance_scale_decay == "none": 561 | cur_guidance_scale_points = self.guidance_scale 562 | elif guidance_scale_decay == "inv_square": 563 | cur_guidance_scale_points = \ 564 | (self.guidance_scale-1.0) * (1.0 - i/len(timesteps)) ** 2 + 1 565 | else: 566 | raise NotImplementedError("decay schedule not implemented") 567 | 568 | if do_classifier_free_guidance: 569 | noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) 570 | noise_pred = noise_pred_uncond + cur_guidance_scale_points * (noise_pred_cond - noise_pred_uncond) 571 | 572 | # compute the previous noisy sample x_t -> x_t-1 573 | latents = self.scheduler.step(noise_pred, t, latents).prev_sample 574 | 575 | if 
callback_on_step_end is not None: 576 | callback_kwargs = {} 577 | for k in callback_on_step_end_tensor_inputs: 578 | callback_kwargs[k] = locals()[k] 579 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 580 | 581 | latents = callback_outputs.pop("latents", latents) 582 | 583 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 584 | progress_bar.update() 585 | 586 | if not output_type == "latent": 587 | # cast back to fp16 if needed 588 | if needs_upcasting: 589 | self.vae.to(dtype=self_vae_dtype) 590 | frames = self.decode_latents(latents, num_frames, decode_chunk_size) 591 | frames = tensor2vid(frames, self.image_processor, output_type=output_type) 592 | else: 593 | frames = latents / 0.18215 594 | #print(frames.shape) #([1, 2, 4, 64, 64]) 595 | self.maybe_free_model_hooks() 596 | 597 | if not return_dict: 598 | return frames 599 | 600 | return FramePainterPipelineOutput(frames=frames) 601 | 602 | 603 | # resizing utils 604 | # TODO: clean up later 605 | def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True): 606 | 607 | if input.ndim == 3: 608 | input = input.unsqueeze(0) # Add a batch dimension 609 | 610 | h, w = input.shape[-2:] 611 | factors = (h / size[0], w / size[1]) 612 | 613 | # First, we have to determine sigma 614 | # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171 615 | sigmas = ( 616 | max((factors[0] - 1.0) / 2.0, 0.001), 617 | max((factors[1] - 1.0) / 2.0, 0.001), 618 | ) 619 | 620 | # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma 621 | # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206 622 | # But they do it in the 2 passes, which gives better results. 
Let's try 2 sigmas for now 623 | ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3)) 624 | 625 | # Make sure it is odd 626 | if (ks[0] % 2) == 0: 627 | ks = ks[0] + 1, ks[1] 628 | 629 | if (ks[1] % 2) == 0: 630 | ks = ks[0], ks[1] + 1 631 | 632 | input = _gaussian_blur2d(input, ks, sigmas) 633 | 634 | output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners) 635 | return output 636 | 637 | 638 | def _compute_padding(kernel_size): 639 | """Compute padding tuple.""" 640 | # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom) 641 | # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad 642 | if len(kernel_size) < 2: 643 | raise AssertionError(kernel_size) 644 | computed = [k - 1 for k in kernel_size] 645 | 646 | # for even kernels we need to do asymmetric padding :( 647 | out_padding = 2 * len(kernel_size) * [0] 648 | 649 | for i in range(len(kernel_size)): 650 | computed_tmp = computed[-(i + 1)] 651 | 652 | pad_front = computed_tmp // 2 653 | pad_rear = computed_tmp - pad_front 654 | 655 | out_padding[2 * i + 0] = pad_front 656 | out_padding[2 * i + 1] = pad_rear 657 | 658 | return out_padding 659 | 660 | 661 | def _filter2d(input, kernel): 662 | # prepare kernel 663 | b, c, h, w = input.shape 664 | tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype) 665 | 666 | tmp_kernel = tmp_kernel.expand(-1, c, -1, -1) 667 | 668 | height, width = tmp_kernel.shape[-2:] 669 | 670 | padding_shape: list[int] = _compute_padding([height, width]) 671 | input = torch.nn.functional.pad(input, padding_shape, mode="reflect") 672 | 673 | # kernel and input tensor reshape to align element-wise or batch-wise params 674 | tmp_kernel = tmp_kernel.reshape(-1, 1, height, width) 675 | input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1)) 676 | 677 | # convolve the tensor with the kernel. 
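    # Note: together with the reshapes above, `groups=tmp_kernel.size(0)` turns this into a
    # depthwise-style convolution: each kernel row is applied to exactly one channel slice, so the
    # Gaussian blur filters every channel independently without mixing channels.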
678 | output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1) 679 | 680 | out = output.view(b, c, h, w) 681 | return out 682 | 683 | 684 | def _gaussian(window_size: int, sigma): 685 | if isinstance(sigma, float): 686 | sigma = torch.tensor([[sigma]]) 687 | 688 | batch_size = sigma.shape[0] 689 | 690 | x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1) 691 | 692 | if window_size % 2 == 0: 693 | x = x + 0.5 694 | 695 | gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0))) 696 | 697 | return gauss / gauss.sum(-1, keepdim=True) 698 | 699 | 700 | def _gaussian_blur2d(input, kernel_size, sigma): 701 | if isinstance(sigma, tuple): 702 | sigma = torch.tensor([sigma], dtype=input.dtype) 703 | else: 704 | sigma = sigma.to(dtype=input.dtype) 705 | 706 | ky, kx = int(kernel_size[0]), int(kernel_size[1]) 707 | bs = sigma.shape[0] 708 | kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1)) 709 | kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1)) 710 | out_x = _filter2d(input, kernel_x[..., None, :]) 711 | out = _filter2d(out_x, kernel_y[..., None]) 712 | 713 | return out 714 | -------------------------------------------------------------------------------- /modules/sparse_control_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple, Union 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from diffusers.configuration_utils import ConfigMixin, register_to_config 7 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps 8 | from diffusers.models.modeling_utils import ModelMixin 9 | from diffusers.models.resnet import Downsample2D, ResnetBlock2D 10 | 11 | 12 | ACTIVATION_FUNCTIONS = { 13 | "swish": nn.SiLU(), 14 | "silu": nn.SiLU(), 15 | "mish": nn.Mish(), 16 | "gelu": nn.GELU(), 17 | "relu": nn.ReLU(), 18 | } 19 | 20 | class SparseControlEncoder(ModelMixin, ConfigMixin): 21 | _supports_gradient_checkpointing = True 22 | 23 | @register_to_config 24 | def __init__( 25 | self, 26 | input_dim = 3, 27 | time_embed_dim = 256, 28 | in_channels = [128, 128], 29 | out_channels = [128, 256], 30 | groups = [4, 8], 31 | non_linearity: str = "relu", 32 | ): 33 | super().__init__() 34 | self.nonlinearity = ACTIVATION_FUNCTIONS[non_linearity] 35 | 36 | self.time_proj = Timesteps(128, True, downscale_freq_shift=0) 37 | self.time_embedding = TimestepEmbedding(128, time_embed_dim) 38 | self.embedding = nn.Sequential( 39 | nn.Conv2d(input_dim, 64, kernel_size=3, stride=2, padding=1), 40 | nn.GroupNorm(2, 64), 41 | self.nonlinearity, 42 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), 43 | nn.GroupNorm(2, 64), 44 | self.nonlinearity, 45 | nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), 46 | nn.GroupNorm(2, 128), 47 | self.nonlinearity, 48 | ) 49 | 50 | self.down_res = nn.ModuleList() 51 | self.down_sample = nn.ModuleList() 52 | for i in range(len(in_channels)): 53 | self.down_res.append( 54 | ResnetBlock2D( 55 | in_channels=in_channels[i], 56 | out_channels=out_channels[i], 57 | temb_channels=time_embed_dim, 58 | groups=groups[i] 59 | ), 60 | ) 61 | self.down_sample.append( 62 | Downsample2D( 63 | out_channels[i], 64 | use_conv=True, 65 | out_channels=out_channels[i], 66 | padding=1, 67 | name="op", 68 | ) 69 | ) 70 | 71 | self.mid_convs = nn.ModuleList() 72 | self.mid_convs.append(nn.Sequential( 73 | nn.Conv2d( 74 | in_channels=out_channels[-1], 75 | out_channels=out_channels[-1], 76 | 
kernel_size=3, 77 | stride=1, 78 | padding=1 79 | ), 80 | self.nonlinearity, 81 | nn.GroupNorm(8, out_channels[-1]), 82 | nn.Conv2d( 83 | in_channels=out_channels[-1], 84 | out_channels=out_channels[-1], 85 | kernel_size=3, 86 | stride=1, 87 | padding=1 88 | ), 89 | nn.GroupNorm(8, out_channels[-1]), 90 | )) 91 | self.mid_convs.append( 92 | nn.Conv2d( 93 | in_channels=out_channels[-1], 94 | out_channels=320, 95 | kernel_size=1, 96 | stride=1, 97 | )) 98 | 99 | self.scale = 1. 100 | 101 | def _set_gradient_checkpointing(self, module, value=False): 102 | if hasattr(module, "gradient_checkpointing"): 103 | module.gradient_checkpointing = value 104 | 105 | # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking 106 | def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: 107 | """ 108 | Sets the attention processor to use [feed forward 109 | chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). 110 | 111 | Parameters: 112 | chunk_size (`int`, *optional*): 113 | The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually 114 | over each tensor of dim=`dim`. 115 | dim (`int`, *optional*, defaults to `0`): 116 | The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) 117 | or dim=1 (sequence length). 118 | """ 119 | if dim not in [0, 1]: 120 | raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") 121 | 122 | # By default chunk size is 1 123 | chunk_size = chunk_size or 1 124 | 125 | def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): 126 | if hasattr(module, "set_chunk_feed_forward"): 127 | module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) 128 | 129 | for child in module.children(): 130 | fn_recursive_feed_forward(child, chunk_size, dim) 131 | 132 | for module in self.children(): 133 | fn_recursive_feed_forward(module, chunk_size, dim) 134 | 135 | def forward( 136 | self, 137 | sample: torch.FloatTensor, 138 | timestep: Union[torch.Tensor, float, int], 139 | ): 140 | 141 | timesteps = timestep 142 | if not torch.is_tensor(timesteps): 143 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 144 | # This would be a good case for the `match` statement (Python 3.10+) 145 | is_mps = sample.device.type == "mps" 146 | if isinstance(timestep, float): 147 | dtype = torch.float32 if is_mps else torch.float64 148 | else: 149 | dtype = torch.int32 if is_mps else torch.int64 150 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 151 | elif len(timesteps.shape) == 0: 152 | timesteps = timesteps[None].to(sample.device) 153 | 154 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 155 | batch_size, num_frames = sample.shape[:2] 156 | timesteps = timesteps.expand(batch_size) 157 | 158 | t_emb = self.time_proj(timesteps) 159 | 160 | # `Timesteps` does not contain any weights and will always return f32 tensors 161 | # but time_embedding might actually be running in fp16. so we need to cast here. 162 | # there might be better ways to encapsulate this. 
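        # Shape reference (a sketch assuming the 512x512, two-frame case driven by the ComfyUI node):
        #   sample:           [batch, 2, 3, 512, 512] -> flattened below to [batch*2, 3, 512, 512]
        #   self.embedding:   stride-2 conv stack -> [batch*2, 128, 256, 256]
        #   down_res/sample:  two ResnetBlock2D + Downsample2D stages -> [batch*2, 256, 64, 64]
        #   mid_convs:        3x3 residual block + 1x1 projection -> [batch*2, 320, 64, 64],
        #                     i.e. the same spatial size as the UNet's first-stage latents.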
163 | t_emb = t_emb.to(dtype=sample.dtype) 164 | 165 | emb_batch = self.time_embedding(t_emb) 166 | 167 | # Flatten the batch and frames dimensions 168 | # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width] 169 | sample = sample.flatten(0, 1) 170 | # Repeat the embeddings num_video_frames times 171 | # emb: [batch, channels] -> [batch * frames, channels] 172 | emb = emb_batch.repeat_interleave(num_frames, dim=0) 173 | 174 | sample = self.embedding(sample) 175 | 176 | for res, downsample in zip(self.down_res, self.down_sample): 177 | sample = res(sample, emb) 178 | sample = downsample(sample, emb) 179 | 180 | sample = self.mid_convs[0](sample) + sample 181 | sample = self.mid_convs[1](sample) 182 | 183 | return { 184 | 'output': sample, 185 | 'scale': self.scale, 186 | } 187 | 188 | -------------------------------------------------------------------------------- /modules/transformer_temporal.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | 4 | from typing import Any, Dict, Optional 5 | 6 | from diffusers.models.attention import BasicTransformerBlock, _chunked_feed_forward 7 | 8 | 9 | class MatchingBasicTransformerBlock(BasicTransformerBlock): 10 | r""" 11 | A Matching Transformer block. 12 | 13 | Parameters: 14 | dim (`int`): The number of channels in the input and output. 15 | num_attention_heads (`int`): The number of heads to use for multi-head attention. 16 | attention_head_dim (`int`): The number of channels in each head. 17 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 18 | cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. 19 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 20 | num_embeds_ada_norm (: 21 | obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. 22 | attention_bias (: 23 | obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. 24 | only_cross_attention (`bool`, *optional*): 25 | Whether to use only cross-attention layers. In this case two cross attention layers are used. 26 | double_self_attention (`bool`, *optional*): 27 | Whether to use two self-attention layers. In this case no cross attention layers are used. 28 | upcast_attention (`bool`, *optional*): 29 | Whether to upcast the attention computation to float32. This is useful for mixed precision training. 30 | norm_elementwise_affine (`bool`, *optional*, defaults to `True`): 31 | Whether to use learnable elementwise affine parameters for normalization. 32 | norm_type (`str`, *optional*, defaults to `"layer_norm"`): 33 | The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. 34 | final_dropout (`bool` *optional*, defaults to False): 35 | Whether to apply a final dropout after the last feed-forward layer. 36 | attention_type (`str`, *optional*, defaults to `"default"`): 37 | The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. 38 | positional_embeddings (`str`, *optional*, defaults to `None`): 39 | The type of positional embeddings to apply to. 40 | num_positional_embeddings (`int`, *optional*, defaults to `None`): 41 | The maximum number of positional embeddings to apply. 
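    Note:
        Beyond the standard `BasicTransformerBlock` arguments above, this subclass clones `attn1`
        into a `sparse_attention` branch whose output projection is zero-initialized, and adds that
        branch's output to the regular self-attention output in `forward`. At initialization the
        branch therefore contributes nothing; it only takes effect once trained FramePainter
        weights are loaded.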
42 | """ 43 | 44 | def __init__( 45 | self, *args, **kwargs 46 | ): 47 | super().__init__(*args, **kwargs) 48 | self.sparse_attention = copy.deepcopy(self.attn1) 49 | if hasattr(self.sparse_attention, "to_out"): 50 | self.sparse_attention.to_out[0].weight.data.fill_(0.0) 51 | if self.sparse_attention.to_out[0].bias is not None: 52 | self.sparse_attention.to_out[0].bias.data.fill_(0.0) 53 | 54 | 55 | def forward( 56 | self, 57 | hidden_states: torch.Tensor, 58 | attention_mask: Optional[torch.Tensor] = None, 59 | encoder_hidden_states: Optional[torch.Tensor] = None, 60 | encoder_attention_mask: Optional[torch.Tensor] = None, 61 | timestep: Optional[torch.LongTensor] = None, 62 | cross_attention_kwargs: Dict[str, Any] = None, 63 | class_labels: Optional[torch.LongTensor] = None, 64 | added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, 65 | ) -> torch.Tensor: 66 | 67 | # Notice that normalization is always applied before the real computation in the following blocks. 68 | # 0. Self-Attention 69 | batch_size = hidden_states.shape[0] 70 | 71 | if self.norm_type == "ada_norm": 72 | norm_hidden_states = self.norm1(hidden_states, timestep) 73 | elif self.norm_type == "ada_norm_zero": 74 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( 75 | hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype 76 | ) 77 | elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: 78 | norm_hidden_states = self.norm1(hidden_states) 79 | elif self.norm_type == "ada_norm_continuous": 80 | norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) 81 | elif self.norm_type == "ada_norm_single": 82 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( 83 | self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) 84 | ).chunk(6, dim=1) 85 | norm_hidden_states = self.norm1(hidden_states) 86 | norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa 87 | else: 88 | raise ValueError("Incorrect norm used") 89 | 90 | if self.pos_embed is not None: 91 | norm_hidden_states = self.pos_embed(norm_hidden_states) 92 | 93 | # 1. Prepare GLIGEN inputs 94 | cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} 95 | gligen_kwargs = cross_attention_kwargs.pop("gligen", None) 96 | 97 | attn_output = self.attn1( 98 | norm_hidden_states, 99 | encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, 100 | attention_mask=attention_mask, 101 | **cross_attention_kwargs, 102 | ) 103 | 104 | if self.sparse_attention is not None: 105 | sparse_attn_output = self.sparse_attention( 106 | norm_hidden_states, 107 | encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, 108 | attention_mask=attention_mask, 109 | **cross_attention_kwargs, 110 | ) 111 | attn_output = attn_output + sparse_attn_output 112 | 113 | if self.norm_type == "ada_norm_zero": 114 | attn_output = gate_msa.unsqueeze(1) * attn_output 115 | elif self.norm_type == "ada_norm_single": 116 | attn_output = gate_msa * attn_output 117 | 118 | hidden_states = attn_output + hidden_states 119 | if hidden_states.ndim == 4: 120 | hidden_states = hidden_states.squeeze(1) 121 | 122 | # 1.2 GLIGEN Control 123 | if gligen_kwargs is not None: 124 | hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) 125 | 126 | # 3. 
Cross-Attention 127 | if self.attn2 is not None: 128 | if self.norm_type == "ada_norm": 129 | norm_hidden_states = self.norm2(hidden_states, timestep) 130 | elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: 131 | norm_hidden_states = self.norm2(hidden_states) 132 | elif self.norm_type == "ada_norm_single": 133 | # For PixArt norm2 isn't applied here: 134 | # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 135 | norm_hidden_states = hidden_states 136 | elif self.norm_type == "ada_norm_continuous": 137 | norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) 138 | else: 139 | raise ValueError("Incorrect norm") 140 | 141 | if self.pos_embed is not None and self.norm_type != "ada_norm_single": 142 | norm_hidden_states = self.pos_embed(norm_hidden_states) 143 | 144 | attn_output = self.attn2( 145 | norm_hidden_states, 146 | encoder_hidden_states=encoder_hidden_states, 147 | attention_mask=encoder_attention_mask, 148 | **cross_attention_kwargs, 149 | ) 150 | hidden_states = attn_output + hidden_states 151 | 152 | # 4. Feed-forward 153 | # i2vgen doesn't have this norm 🤷‍♂️ 154 | if self.norm_type == "ada_norm_continuous": 155 | norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) 156 | elif not self.norm_type == "ada_norm_single": 157 | norm_hidden_states = self.norm3(hidden_states) 158 | 159 | if self.norm_type == "ada_norm_zero": 160 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 161 | 162 | if self.norm_type == "ada_norm_single": 163 | norm_hidden_states = self.norm2(hidden_states) 164 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp 165 | 166 | if self._chunk_size is not None: 167 | # "feed_forward_chunk_size" can be used to save memory 168 | ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) 169 | else: 170 | ff_output = self.ff(norm_hidden_states) 171 | 172 | if self.norm_type == "ada_norm_zero": 173 | ff_output = gate_mlp.unsqueeze(1) * ff_output 174 | elif self.norm_type == "ada_norm_single": 175 | ff_output = gate_mlp * ff_output 176 | 177 | hidden_states = ff_output + hidden_states 178 | if hidden_states.ndim == 4: 179 | hidden_states = hidden_states.squeeze(1) 180 | 181 | return hidden_states 182 | -------------------------------------------------------------------------------- /modules/unet_spatio_temporal_condition_edit.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict, Optional, Tuple, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from diffusers.configuration_utils import ConfigMixin, register_to_config 8 | from diffusers.loaders import UNet2DConditionLoadersMixin 9 | from diffusers.utils import BaseOutput, logging 10 | from diffusers.models.attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor 11 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps 12 | from diffusers.models.modeling_utils import ModelMixin 13 | from diffusers.models.unets.unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block 14 | import torch.nn.functional as F 15 | 16 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 17 | 18 | 19 | @dataclass 20 | class UNetSpatioTemporalConditionOutput(BaseOutput): 21 | """ 22 | 
The output of [`UNetSpatioTemporalConditionModel`]. 23 | 24 | Args: 25 | sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`): 26 | The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. 27 | """ 28 | 29 | sample: torch.FloatTensor = None 30 | 31 | 32 | class UNetSpatioTemporalConditionEdit(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): 33 | r""" 34 | A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, and a timestep and returns a sample 35 | shaped output. 36 | 37 | This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented 38 | for all models (such as downloading or saving). 39 | 40 | Parameters: 41 | sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): 42 | Height and width of input/output sample. 43 | in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample. 44 | out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. 45 | down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`): 46 | The tuple of downsample blocks to use. 47 | up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`): 48 | The tuple of upsample blocks to use. 49 | block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): 50 | The tuple of output channels for each block. 51 | addition_time_embed_dim: (`int`, defaults to 256): 52 | Dimension to to encode the additional time ids. 53 | projection_class_embeddings_input_dim (`int`, defaults to 768): 54 | The dimension of the projection of encoded `added_time_ids`. 55 | layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. 56 | cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): 57 | The dimension of the cross attention features. 58 | transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): 59 | The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for 60 | [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`], 61 | [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`]. 62 | num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`): 63 | The number of attention heads. 64 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 
65 | """ 66 | 67 | _supports_gradient_checkpointing = True 68 | 69 | @register_to_config 70 | def __init__( 71 | self, 72 | sample_size: Optional[int] = None, 73 | in_channels: int = 8, 74 | out_channels: int = 4, 75 | down_block_types: Tuple[str] = ( 76 | "CrossAttnDownBlockSpatioTemporal", 77 | "CrossAttnDownBlockSpatioTemporal", 78 | "CrossAttnDownBlockSpatioTemporal", 79 | "DownBlockSpatioTemporal", 80 | ), 81 | up_block_types: Tuple[str] = ( 82 | "UpBlockSpatioTemporal", 83 | "CrossAttnUpBlockSpatioTemporal", 84 | "CrossAttnUpBlockSpatioTemporal", 85 | "CrossAttnUpBlockSpatioTemporal", 86 | ), 87 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280), 88 | addition_time_embed_dim: int = 256, 89 | projection_class_embeddings_input_dim: int = 768, 90 | layers_per_block: Union[int, Tuple[int]] = 2, 91 | cross_attention_dim: Union[int, Tuple[int]] = 1024, 92 | transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, 93 | num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20), 94 | num_frames: int = 25, 95 | upcast_attention: bool = False, 96 | ): 97 | super().__init__() 98 | 99 | self.sample_size = sample_size 100 | 101 | # Check inputs 102 | if len(down_block_types) != len(up_block_types): 103 | raise ValueError( 104 | f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." 105 | ) 106 | 107 | if len(block_out_channels) != len(down_block_types): 108 | raise ValueError( 109 | f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." 110 | ) 111 | 112 | if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): 113 | raise ValueError( 114 | f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 115 | ) 116 | 117 | if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): 118 | raise ValueError( 119 | f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." 120 | ) 121 | 122 | if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): 123 | raise ValueError( 124 | f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." 
125 | ) 126 | 127 | # input 128 | self.conv_in = nn.Conv2d( 129 | in_channels, 130 | block_out_channels[0], 131 | kernel_size=3, 132 | padding=1, 133 | ) 134 | 135 | # time 136 | time_embed_dim = block_out_channels[0] * 4 137 | 138 | self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0) 139 | timestep_input_dim = block_out_channels[0] 140 | 141 | self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 142 | 143 | self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0) 144 | self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) 145 | 146 | self.down_blocks = nn.ModuleList([]) 147 | self.up_blocks = nn.ModuleList([]) 148 | 149 | if isinstance(num_attention_heads, int): 150 | num_attention_heads = (num_attention_heads,) * len(down_block_types) 151 | 152 | if isinstance(cross_attention_dim, int): 153 | cross_attention_dim = (cross_attention_dim,) * len(down_block_types) 154 | 155 | if isinstance(layers_per_block, int): 156 | layers_per_block = [layers_per_block] * len(down_block_types) 157 | 158 | if isinstance(transformer_layers_per_block, int): 159 | transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) 160 | 161 | blocks_time_embed_dim = time_embed_dim 162 | 163 | # down 164 | output_channel = block_out_channels[0] 165 | for i, down_block_type in enumerate(down_block_types): 166 | input_channel = output_channel 167 | output_channel = block_out_channels[i] 168 | is_final_block = i == len(block_out_channels) - 1 169 | 170 | down_block = get_down_block( 171 | down_block_type, 172 | num_layers=layers_per_block[i], 173 | transformer_layers_per_block=transformer_layers_per_block[i], 174 | in_channels=input_channel, 175 | out_channels=output_channel, 176 | temb_channels=blocks_time_embed_dim, 177 | add_downsample=not is_final_block, 178 | resnet_eps=1e-5, 179 | cross_attention_dim=cross_attention_dim[i], 180 | num_attention_heads=num_attention_heads[i], 181 | resnet_act_fn="silu", 182 | upcast_attention=upcast_attention, 183 | ) 184 | self.down_blocks.append(down_block) 185 | 186 | # mid 187 | self.mid_block = UNetMidBlockSpatioTemporal( 188 | block_out_channels[-1], 189 | temb_channels=blocks_time_embed_dim, 190 | transformer_layers_per_block=transformer_layers_per_block[-1], 191 | cross_attention_dim=cross_attention_dim[-1], 192 | num_attention_heads=num_attention_heads[-1], 193 | ) 194 | 195 | # count how many layers upsample the images 196 | self.num_upsamplers = 0 197 | 198 | # up 199 | reversed_block_out_channels = list(reversed(block_out_channels)) 200 | reversed_num_attention_heads = list(reversed(num_attention_heads)) 201 | reversed_layers_per_block = list(reversed(layers_per_block)) 202 | reversed_cross_attention_dim = list(reversed(cross_attention_dim)) 203 | reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) 204 | 205 | output_channel = reversed_block_out_channels[0] 206 | for i, up_block_type in enumerate(up_block_types): 207 | is_final_block = i == len(block_out_channels) - 1 208 | 209 | prev_output_channel = output_channel 210 | output_channel = reversed_block_out_channels[i] 211 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] 212 | 213 | # add upsample block for all BUT final layer 214 | if not is_final_block: 215 | add_upsample = True 216 | self.num_upsamplers += 1 217 | else: 218 | add_upsample = False 219 | 220 | up_block = get_up_block( 221 | up_block_type, 222 
| num_layers=reversed_layers_per_block[i] + 1, 223 | transformer_layers_per_block=reversed_transformer_layers_per_block[i], 224 | in_channels=input_channel, 225 | out_channels=output_channel, 226 | prev_output_channel=prev_output_channel, 227 | temb_channels=blocks_time_embed_dim, 228 | add_upsample=add_upsample, 229 | resnet_eps=1e-5, 230 | resolution_idx=i, 231 | cross_attention_dim=reversed_cross_attention_dim[i], 232 | num_attention_heads=reversed_num_attention_heads[i], 233 | resnet_act_fn="silu", 234 | upcast_attention=upcast_attention, 235 | ) 236 | self.up_blocks.append(up_block) 237 | prev_output_channel = output_channel 238 | 239 | # out 240 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5) 241 | self.conv_act = nn.SiLU() 242 | 243 | self.conv_out = nn.Conv2d( 244 | block_out_channels[0], 245 | out_channels, 246 | kernel_size=3, 247 | padding=1, 248 | ) 249 | 250 | @property 251 | def attn_processors(self) -> Dict[str, AttentionProcessor]: 252 | r""" 253 | Returns: 254 | `dict` of attention processors: A dictionary containing all attention processors used in the model with 255 | indexed by its weight name. 256 | """ 257 | # set recursively 258 | processors = {} 259 | 260 | def fn_recursive_add_processors( 261 | name: str, 262 | module: torch.nn.Module, 263 | processors: Dict[str, AttentionProcessor], 264 | ): 265 | if hasattr(module, "get_processor"): 266 | processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) 267 | 268 | for sub_name, child in module.named_children(): 269 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 270 | 271 | return processors 272 | 273 | for name, module in self.named_children(): 274 | fn_recursive_add_processors(name, module, processors) 275 | 276 | return processors 277 | 278 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): 279 | r""" 280 | Sets the attention processor to use to compute attention. 281 | 282 | Parameters: 283 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 284 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 285 | for **all** `Attention` layers. 286 | 287 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 288 | processor. This is strongly recommended when setting trainable attention processors. 289 | 290 | """ 291 | count = len(self.attn_processors.keys()) 292 | 293 | if isinstance(processor, dict) and len(processor) != count: 294 | raise ValueError( 295 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 296 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 297 | ) 298 | 299 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 300 | if hasattr(module, "set_processor"): 301 | if not isinstance(processor, dict): 302 | module.set_processor(processor) 303 | else: 304 | module.set_processor(processor.pop(f"{name}.processor")) 305 | 306 | for sub_name, child in module.named_children(): 307 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 308 | 309 | for name, module in self.named_children(): 310 | fn_recursive_attn_processor(name, module, processor) 311 | 312 | def set_default_attn_processor(self): 313 | """ 314 | Disables custom attention processors and sets the default attention implementation. 
315 | """ 316 | if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): 317 | processor = AttnProcessor() 318 | else: 319 | raise ValueError( 320 | f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" 321 | ) 322 | 323 | self.set_attn_processor(processor) 324 | 325 | def _set_gradient_checkpointing(self, module, value=False): 326 | if hasattr(module, "gradient_checkpointing"): 327 | module.gradient_checkpointing = value 328 | 329 | # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking 330 | def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: 331 | """ 332 | Sets the attention processor to use [feed forward 333 | chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). 334 | 335 | Parameters: 336 | chunk_size (`int`, *optional*): 337 | The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually 338 | over each tensor of dim=`dim`. 339 | dim (`int`, *optional*, defaults to `0`): 340 | The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) 341 | or dim=1 (sequence length). 342 | """ 343 | if dim not in [0, 1]: 344 | raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") 345 | 346 | # By default chunk size is 1 347 | chunk_size = chunk_size or 1 348 | 349 | def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): 350 | if hasattr(module, "set_chunk_feed_forward"): 351 | module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) 352 | 353 | for child in module.children(): 354 | fn_recursive_feed_forward(child, chunk_size, dim) 355 | 356 | for module in self.children(): 357 | fn_recursive_feed_forward(module, chunk_size, dim) 358 | 359 | def forward( 360 | self, 361 | sample: torch.FloatTensor, 362 | timestep: Union[torch.Tensor, float, int], 363 | encoder_hidden_states: torch.Tensor, 364 | down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, 365 | mid_block_additional_residual: Optional[torch.Tensor] = None, 366 | conditional_controls: Optional[torch.Tensor] = None, 367 | return_dict: bool = True, 368 | added_time_ids: torch.Tensor=None, 369 | image_only_indicator: torch.Tensor=None, 370 | ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]: 371 | r""" 372 | The [`UNetSpatioTemporalConditionModel`] forward method. 373 | 374 | Args: 375 | sample (`torch.FloatTensor`): 376 | The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`. 377 | timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. 378 | encoder_hidden_states (`torch.FloatTensor`): 379 | The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`. 380 | added_time_ids: (`torch.FloatTensor`): 381 | The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal 382 | embeddings and added to the time embeddings. 383 | return_dict (`bool`, *optional*, defaults to `True`): 384 | Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain 385 | tuple. 
386 | Returns: 387 | [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`: 388 | If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise 389 | a `tuple` is returned where the first element is the sample tensor. 390 | """ 391 | # 1. time 392 | timesteps = timestep 393 | if not torch.is_tensor(timesteps): 394 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 395 | # This would be a good case for the `match` statement (Python 3.10+) 396 | is_mps = sample.device.type == "mps" 397 | if isinstance(timestep, float): 398 | dtype = torch.float32 if is_mps else torch.float64 399 | else: 400 | dtype = torch.int32 if is_mps else torch.int64 401 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 402 | elif len(timesteps.shape) == 0: 403 | timesteps = timesteps[None].to(sample.device) 404 | 405 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 406 | batch_size, num_frames = sample.shape[:2] 407 | timesteps = timesteps.expand(batch_size) 408 | 409 | t_emb = self.time_proj(timesteps) 410 | 411 | # `Timesteps` does not contain any weights and will always return f32 tensors 412 | # but time_embedding might actually be running in fp16. so we need to cast here. 413 | # there might be better ways to encapsulate this. 414 | t_emb = t_emb.to(dtype=sample.dtype) 415 | 416 | emb = self.time_embedding(t_emb) 417 | 418 | time_embeds = self.add_time_proj(added_time_ids.flatten()) 419 | time_embeds = time_embeds.reshape((batch_size, -1)) 420 | time_embeds = time_embeds.to(emb.dtype) 421 | aug_emb = self.add_embedding(time_embeds) 422 | emb = emb + aug_emb 423 | 424 | # Flatten the batch and frames dimensions 425 | # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width] 426 | sample = sample.flatten(0, 1) 427 | # Repeat the embeddings num_video_frames times 428 | # emb: [batch, channels] -> [batch * frames, channels] 429 | emb = emb.repeat_interleave(num_frames, dim=0) 430 | # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels] 431 | encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0) 432 | 433 | # 2. 
pre-process 434 | sample = self.conv_in(sample) 435 | if image_only_indicator is None: 436 | image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device) 437 | 438 | down_block_res_samples = (sample,) 439 | for idx,downsample_block in enumerate(self.down_blocks): 440 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 441 | sample, res_samples = downsample_block( 442 | hidden_states=sample, 443 | temb=emb, 444 | encoder_hidden_states=encoder_hidden_states, 445 | image_only_indicator=image_only_indicator, 446 | ) 447 | else: 448 | sample, res_samples = downsample_block( 449 | hidden_states=sample, 450 | temb=emb, 451 | image_only_indicator=image_only_indicator, 452 | ) 453 | 454 | down_block_res_samples += res_samples 455 | 456 | if idx == 0 and conditional_controls is not None: 457 | scale = conditional_controls['scale'] 458 | conditional_controls = conditional_controls['output'] 459 | mean_latents, std_latents = torch.mean(sample, dim=(1, 2, 3), keepdim=True), torch.std(sample, dim=(1, 2, 3), keepdim=True) 460 | mean_control, std_control = torch.mean(conditional_controls, dim=(1, 2, 3), keepdim=True), torch.std(conditional_controls, dim=(1, 2, 3), keepdim=True) 461 | conditional_controls = (conditional_controls - mean_control) * (std_latents / (std_control + 1e-5)) + mean_latents 462 | conditional_controls = F.adaptive_avg_pool2d(conditional_controls, sample.shape[-2:]) 463 | sample = sample + conditional_controls * scale * 0.2 464 | 465 | if down_block_additional_residuals is not None: 466 | new_down_block_res_samples = () 467 | 468 | for down_block_res_sample, down_block_additional_residual in zip( 469 | down_block_res_samples, down_block_additional_residuals 470 | ): 471 | down_block_res_sample = down_block_res_sample + down_block_additional_residual 472 | new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) 473 | 474 | down_block_res_samples = new_down_block_res_samples 475 | 476 | # 4. mid 477 | sample = self.mid_block( 478 | hidden_states=sample, 479 | temb=emb, 480 | encoder_hidden_states=encoder_hidden_states, 481 | image_only_indicator=image_only_indicator, 482 | ) 483 | 484 | # 5. up 485 | for i, upsample_block in enumerate(self.up_blocks): 486 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 487 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 488 | 489 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 490 | sample = upsample_block( 491 | hidden_states=sample, 492 | temb=emb, 493 | res_hidden_states_tuple=res_samples, 494 | encoder_hidden_states=encoder_hidden_states, 495 | image_only_indicator=image_only_indicator, 496 | ) 497 | else: 498 | sample = upsample_block( 499 | hidden_states=sample, 500 | temb=emb, 501 | res_hidden_states_tuple=res_samples, 502 | image_only_indicator=image_only_indicator, 503 | ) 504 | 505 | # 6. post-process 506 | sample = self.conv_norm_out(sample) 507 | sample = self.conv_act(sample) 508 | sample = self.conv_out(sample) 509 | 510 | # 7. 
Reshape back to original shape 511 | sample = sample.reshape(batch_size, num_frames, *sample.shape[1:]) 512 | 513 | if not return_dict: 514 | return (sample,) 515 | 516 | return UNetSpatioTemporalConditionOutput(sample=sample) 517 | -------------------------------------------------------------------------------- /modules/utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /modules/utils/attention_utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from diffusers.models.attention import BasicTransformerBlock 3 | from diffusers.models.transformers.transformer_temporal import TransformerSpatioTemporalModel 4 | from ..transformer_temporal import MatchingBasicTransformerBlock 5 | 6 | 7 | def set_matching_attention(unet): 8 | for name, module in unet.named_children(): 9 | if isinstance(module, BasicTransformerBlock): 10 | new_module = MatchingBasicTransformerBlock( 11 | dim = module.dim, 12 | num_attention_heads = module.num_attention_heads, 13 | attention_head_dim = module.attention_head_dim, 14 | dropout = module.dropout, 15 | cross_attention_dim = module.cross_attention_dim, 16 | activation_fn = module.activation_fn, 17 | num_embeds_ada_norm = module.num_embeds_ada_norm, 18 | attention_bias = module.attention_bias, 19 | only_cross_attention = module.only_cross_attention, 20 | double_self_attention = module.double_self_attention, 21 | norm_elementwise_affine = module.norm_elementwise_affine, 22 | norm_type = module.norm_type, 23 | positional_embeddings = module.positional_embeddings, 24 | num_positional_embeddings = module.num_positional_embeddings, 25 | ) 26 | new_module.load_state_dict(module.state_dict(),strict=False) 27 | setattr(unet, name, new_module) 28 | else: 29 | set_matching_attention(module) 30 | return unet 31 | 32 | def set_matching_attention_processor(unet, attn_processor): 33 | for block in unet.down_blocks: 34 | if hasattr(block, "attentions"): 35 | for attn in block.attentions: 36 | if isinstance(attn, TransformerSpatioTemporalModel): 37 | for a_block in attn.transformer_blocks: 38 | if isinstance(a_block, MatchingBasicTransformerBlock): 39 | a_block.sparse_attention.processor = attn_processor 40 | 41 | for attn in unet.mid_block.attentions: 42 | if isinstance(attn, TransformerSpatioTemporalModel): 43 | for a_block in attn.transformer_blocks: 44 | if isinstance(a_block, MatchingBasicTransformerBlock): 45 | a_block.sparse_attention.processor = attn_processor 46 | 47 | for block in unet.up_blocks: 48 | if hasattr(block, "attentions"): 49 | for attn in block.attentions: 50 | if isinstance(attn, TransformerSpatioTemporalModel): 51 | for a_block in attn.transformer_blocks: 52 | if isinstance(a_block, MatchingBasicTransformerBlock): 53 | a_block.sparse_attention.processor = attn_processor 54 | return unet 55 | -------------------------------------------------------------------------------- /modules/utils/scheduling_euler_discrete_karras_fix.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import math 16 | from dataclasses import dataclass 17 | from typing import List, Optional, Tuple, Union 18 | 19 | import numpy as np 20 | import torch 21 | 22 | from diffusers.configuration_utils import ConfigMixin, register_to_config 23 | from diffusers.utils import BaseOutput, logging 24 | from diffusers.utils.torch_utils import randn_tensor 25 | from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin 26 | import torch.nn.functional as F 27 | 28 | 29 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 30 | 31 | 32 | @dataclass 33 | # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerDiscrete 34 | class EulerDiscreteSchedulerOutput(BaseOutput): 35 | """ 36 | Output class for the scheduler's `step` function output. 37 | 38 | Args: 39 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 40 | Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the 41 | denoising loop. 42 | pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 43 | The predicted denoised sample `(x_{0})` based on the model output from the current timestep. 44 | `pred_original_sample` can be used to preview progress or for guidance. 45 | """ 46 | 47 | prev_sample: torch.FloatTensor 48 | pred_original_sample: Optional[torch.FloatTensor] = None 49 | 50 | 51 | # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar 52 | def betas_for_alpha_bar( 53 | num_diffusion_timesteps, 54 | max_beta=0.999, 55 | alpha_transform_type="cosine", 56 | ): 57 | """ 58 | Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of 59 | (1-beta) over time from t = [0,1]. 60 | 61 | Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up 62 | to that part of the diffusion process. 63 | 64 | 65 | Args: 66 | num_diffusion_timesteps (`int`): the number of betas to produce. 67 | max_beta (`float`): the maximum beta to use; use values lower than 1 to 68 | prevent singularities. 69 | alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. 
70 | Choose from `cosine` or `exp` 71 | 72 | Returns: 73 | betas (`np.ndarray`): the betas used by the scheduler to step the model outputs 74 | """ 75 | if alpha_transform_type == "cosine": 76 | 77 | def alpha_bar_fn(t): 78 | return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 79 | 80 | elif alpha_transform_type == "exp": 81 | 82 | def alpha_bar_fn(t): 83 | return math.exp(t * -12.0) 84 | 85 | else: 86 | raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") 87 | 88 | betas = [] 89 | for i in range(num_diffusion_timesteps): 90 | t1 = i / num_diffusion_timesteps 91 | t2 = (i + 1) / num_diffusion_timesteps 92 | betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) 93 | return torch.tensor(betas, dtype=torch.float32) 94 | 95 | 96 | # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr 97 | def rescale_zero_terminal_snr(betas): 98 | """ 99 | Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) 100 | 101 | 102 | Args: 103 | betas (`torch.FloatTensor`): 104 | the betas that the scheduler is being initialized with. 105 | 106 | Returns: 107 | `torch.FloatTensor`: rescaled betas with zero terminal SNR 108 | """ 109 | # Convert betas to alphas_bar_sqrt 110 | alphas = 1.0 - betas 111 | alphas_cumprod = torch.cumprod(alphas, dim=0) 112 | alphas_bar_sqrt = alphas_cumprod.sqrt() 113 | 114 | # Store old values. 115 | alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() 116 | alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() 117 | 118 | # Shift so the last timestep is zero. 119 | alphas_bar_sqrt -= alphas_bar_sqrt_T 120 | 121 | # Scale so the first timestep is back to the old value. 122 | alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) 123 | 124 | # Convert alphas_bar_sqrt to betas 125 | alphas_bar = alphas_bar_sqrt**2 # Revert sqrt 126 | alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod 127 | alphas = torch.cat([alphas_bar[0:1], alphas]) 128 | betas = 1 - alphas 129 | 130 | return betas 131 | 132 | 133 | class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): 134 | """ 135 | Euler scheduler. 136 | 137 | This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic 138 | methods the library implements for all schedulers such as loading and saving. 139 | 140 | Args: 141 | num_train_timesteps (`int`, defaults to 1000): 142 | The number of diffusion steps to train the model. 143 | beta_start (`float`, defaults to 0.0001): 144 | The starting `beta` value of inference. 145 | beta_end (`float`, defaults to 0.02): 146 | The final `beta` value. 147 | beta_schedule (`str`, defaults to `"linear"`): 148 | The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from 149 | `linear` or `scaled_linear`. 150 | trained_betas (`np.ndarray`, *optional*): 151 | Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. 152 | prediction_type (`str`, defaults to `epsilon`, *optional*): 153 | Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), 154 | `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen 155 | Video](https://imagen.research.google/video/paper.pdf) paper). 156 | interpolation_type(`str`, defaults to `"linear"`, *optional*): 157 | The interpolation type to compute intermediate sigmas for the scheduler denoising steps. 
Should be on of 158 | `"linear"` or `"log_linear"`. 159 | use_karras_sigmas (`bool`, *optional*, defaults to `False`): 160 | Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, 161 | the sigmas are determined according to a sequence of noise levels {σi}. 162 | timestep_spacing (`str`, defaults to `"linspace"`): 163 | The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and 164 | Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. 165 | steps_offset (`int`, defaults to 0): 166 | An offset added to the inference steps. You can use a combination of `offset=1` and 167 | `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable 168 | Diffusion. 169 | rescale_betas_zero_snr (`bool`, defaults to `False`): 170 | Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and 171 | dark samples instead of limiting it to samples with medium brightness. Loosely related to 172 | [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). 173 | """ 174 | 175 | _compatibles = [e.name for e in KarrasDiffusionSchedulers] 176 | order = 1 177 | 178 | @register_to_config 179 | def __init__( 180 | self, 181 | num_train_timesteps: int = 1000, 182 | beta_start: float = 0.0001, 183 | beta_end: float = 0.02, 184 | beta_schedule: str = "linear", 185 | trained_betas: Optional[Union[np.ndarray, List[float]]] = None, 186 | prediction_type: str = "epsilon", 187 | interpolation_type: str = "linear", 188 | use_karras_sigmas: Optional[bool] = False, 189 | sigma_min: Optional[float] = None, 190 | sigma_max: Optional[float] = None, 191 | timestep_spacing: str = "linspace", 192 | timestep_type: str = "discrete", # can be "discrete" or "continuous" 193 | steps_offset: int = 0, 194 | rescale_betas_zero_snr: bool = False, 195 | ): 196 | if trained_betas is not None: 197 | self.betas = torch.tensor(trained_betas, dtype=torch.float32) 198 | elif beta_schedule == "linear": 199 | self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) 200 | elif beta_schedule == "scaled_linear": 201 | # this schedule is very specific to the latent diffusion model. 
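# (added note: "scaled_linear" spaces the betas linearly in sqrt-space and then squares them, i.e. beta_i = (sqrt(beta_start) + i/(N-1) * (sqrt(beta_end) - sqrt(beta_start)))**2, which rises more gently near i=0 than a plain linear schedule)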
202 | self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 203 | elif beta_schedule == "squaredcos_cap_v2": 204 | # Glide cosine schedule 205 | self.betas = betas_for_alpha_bar(num_train_timesteps) 206 | else: 207 | raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") 208 | 209 | if rescale_betas_zero_snr: 210 | self.betas = rescale_zero_terminal_snr(self.betas) 211 | 212 | self.alphas = 1.0 - self.betas 213 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) 214 | 215 | if rescale_betas_zero_snr: 216 | # Close to 0 without being 0 so first sigma is not inf 217 | # FP16 smallest positive subnormal works well here 218 | self.alphas_cumprod[-1] = 2**-24 219 | 220 | sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) 221 | timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() 222 | 223 | sigmas = sigmas[::-1].copy() 224 | 225 | if self.use_karras_sigmas: 226 | log_sigmas = np.log(sigmas) 227 | sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_train_timesteps) 228 | timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) 229 | 230 | sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) 231 | 232 | # setable values 233 | self.num_inference_steps = None 234 | 235 | # TODO: Support the full EDM scalings for all prediction types and timestep types 236 | if timestep_type == "continuous" and prediction_type == "v_prediction": 237 | self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]) 238 | else: 239 | self.timesteps = torch.from_numpy(timesteps.astype(np.float32)) 240 | 241 | self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) 242 | 243 | self.is_scale_input_called = False 244 | self.use_karras_sigmas = use_karras_sigmas 245 | 246 | self._step_index = None 247 | 248 | @property 249 | def init_noise_sigma(self): 250 | # standard deviation of the initial noise distribution 251 | max_sigma = max(self.sigmas) if isinstance(self.sigmas, list) else self.sigmas.max() 252 | if self.config.timestep_spacing in ["linspace", "trailing"]: 253 | return max_sigma 254 | 255 | return (max_sigma**2 + 1) ** 0.5 256 | 257 | @property 258 | def step_index(self): 259 | """ 260 | The index counter for current timestep. It will increae 1 after each scheduler step. 261 | """ 262 | return self._step_index 263 | 264 | def scale_model_input( 265 | self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] 266 | ) -> torch.FloatTensor: 267 | """ 268 | Ensures interchangeability with schedulers that need to scale the denoising model input depending on the 269 | current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. 270 | 271 | Args: 272 | sample (`torch.FloatTensor`): 273 | The input sample. 274 | timestep (`int`, *optional*): 275 | The current timestep in the diffusion chain. 276 | 277 | Returns: 278 | `torch.FloatTensor`: 279 | A scaled input sample. 280 | """ 281 | if self.step_index is None: 282 | self._init_step_index(timestep) 283 | 284 | sigma = self.sigmas[self.step_index] 285 | sample = sample / ((sigma**2 + 1) ** 0.5) 286 | 287 | self.is_scale_input_called = True 288 | return sample 289 | 290 | def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): 291 | """ 292 | Sets the discrete timesteps used for the diffusion chain (to be run before inference). 
293 | 294 | Args: 295 | num_inference_steps (`int`): 296 | The number of diffusion steps used when generating samples with a pre-trained model. 297 | device (`str` or `torch.device`, *optional*): 298 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 299 | """ 300 | self.num_inference_steps = num_inference_steps 301 | 302 | # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 303 | if self.config.timestep_spacing == "linspace": 304 | timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[ 305 | ::-1 306 | ].copy() 307 | elif self.config.timestep_spacing == "leading": 308 | step_ratio = self.config.num_train_timesteps // self.num_inference_steps 309 | # creates integer timesteps by multiplying by ratio 310 | # casting to int to avoid issues when num_inference_step is power of 3 311 | timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32) 312 | timesteps += self.config.steps_offset 313 | elif self.config.timestep_spacing == "trailing": 314 | step_ratio = self.config.num_train_timesteps / self.num_inference_steps 315 | # creates integer timesteps by multiplying by ratio 316 | # casting to int to avoid issues when num_inference_step is power of 3 317 | timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32) 318 | timesteps -= 1 319 | else: 320 | raise ValueError( 321 | f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 322 | ) 323 | 324 | sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) 325 | log_sigmas = np.log(sigmas) 326 | 327 | if self.config.interpolation_type == "linear": 328 | sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) 329 | elif self.config.interpolation_type == "log_linear": 330 | sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp().numpy() 331 | else: 332 | raise ValueError( 333 | f"{self.config.interpolation_type} is not implemented. 
Please specify interpolation_type to either" 334 | " 'linear' or 'log_linear'" 335 | ) 336 | 337 | if self.use_karras_sigmas: 338 | sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) 339 | timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) 340 | 341 | sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) 342 | 343 | # TODO: Support the full EDM scalings for all prediction types and timestep types 344 | if self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction": 345 | self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]).to(device=device) 346 | else: 347 | self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device) 348 | 349 | self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) 350 | self._step_index = None 351 | 352 | def _sigma_to_t(self, sigma, log_sigmas): 353 | # get log sigma 354 | log_sigma = np.log(np.maximum(sigma, 1e-10)) 355 | 356 | # get distribution 357 | dists = log_sigma - log_sigmas[:, np.newaxis] 358 | 359 | # get sigmas range 360 | low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) 361 | high_idx = low_idx + 1 362 | 363 | low = log_sigmas[low_idx] 364 | high = log_sigmas[high_idx] 365 | 366 | # interpolate sigmas 367 | w = (low - log_sigma) / (low - high) 368 | w = np.clip(w, 0, 1) 369 | 370 | # transform interpolation to time range 371 | t = (1 - w) * low_idx + w * high_idx 372 | t = t.reshape(sigma.shape) 373 | return t 374 | 375 | # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17 376 | def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: 377 | """Constructs the noise schedule of Karras et al. (2022).""" 378 | 379 | # Hack to make sure that other schedulers which copy this function don't break 380 | # TODO: Add this logic to the other schedulers 381 | if hasattr(self.config, "sigma_min"): 382 | sigma_min = self.config.sigma_min 383 | else: 384 | sigma_min = None 385 | 386 | if hasattr(self.config, "sigma_max"): 387 | sigma_max = self.config.sigma_max 388 | else: 389 | sigma_max = None 390 | 391 | sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() 392 | sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() 393 | 394 | rho = 7.0 # 7.0 is the value used in the paper 395 | ramp = np.linspace(0, 1, num_inference_steps) 396 | min_inv_rho = sigma_min ** (1 / rho) 397 | max_inv_rho = sigma_max ** (1 / rho) 398 | sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho 399 | return sigmas 400 | 401 | def _init_step_index(self, timestep): 402 | if isinstance(timestep, torch.Tensor): 403 | timestep = timestep.to(self.timesteps.device) 404 | 405 | index_candidates = (self.timesteps == timestep).nonzero() 406 | 407 | # The sigma index that is taken for the **very** first `step` 408 | # is always the second index (or the last index if there is only 1) 409 | # This way we can ensure we don't accidentally skip a sigma in 410 | # case we start in the middle of the denoising schedule (e.g. 
for image-to-image) 411 | if len(index_candidates) > 1: 412 | step_index = index_candidates[1] 413 | else: 414 | step_index = index_candidates[0] 415 | 416 | self._step_index = step_index.item() 417 | 418 | def step( 419 | self, 420 | model_output: torch.FloatTensor, 421 | timestep: Union[float, torch.FloatTensor], 422 | sample: torch.FloatTensor, 423 | s_churn: float = 0.0, 424 | s_tmin: float = 0.0, 425 | s_tmax: float = float("inf"), 426 | s_noise: float = 1.0, 427 | generator: Optional[torch.Generator] = None, 428 | return_dict: bool = True, 429 | ) -> Union[EulerDiscreteSchedulerOutput, Tuple]: 430 | """ 431 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion 432 | process from the learned model outputs (most often the predicted noise). 433 | 434 | Args: 435 | model_output (`torch.FloatTensor`): 436 | The direct output from learned diffusion model. 437 | timestep (`float`): 438 | The current discrete timestep in the diffusion chain. 439 | sample (`torch.FloatTensor`): 440 | A current instance of a sample created by the diffusion process. 441 | s_churn (`float`): 442 | s_tmin (`float`): 443 | s_tmax (`float`): 444 | s_noise (`float`, defaults to 1.0): 445 | Scaling factor for noise added to the sample. 446 | generator (`torch.Generator`, *optional*): 447 | A random number generator. 448 | return_dict (`bool`): 449 | Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or 450 | tuple. 451 | 452 | Returns: 453 | [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: 454 | If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is 455 | returned, otherwise a tuple is returned where the first element is the sample tensor. 456 | """ 457 | 458 | if ( 459 | isinstance(timestep, int) 460 | or isinstance(timestep, torch.IntTensor) 461 | or isinstance(timestep, torch.LongTensor) 462 | ): 463 | raise ValueError( 464 | ( 465 | "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" 466 | " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" 467 | " one of the `scheduler.timesteps` as a timestep." 468 | ), 469 | ) 470 | 471 | if not self.is_scale_input_called: 472 | logger.warning( 473 | "The `scale_model_input` function should be called before `step` to ensure correct denoising. " 474 | "See `StableDiffusionPipeline` for a usage example." 475 | ) 476 | 477 | if self.step_index is None: 478 | self._init_step_index(timestep) 479 | 480 | # Upcast to avoid precision issues when computing prev_sample 481 | sample = sample.to(torch.float32) 482 | 483 | sigma = self.sigmas[self.step_index] 484 | 485 | gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 486 | 487 | noise = randn_tensor( 488 | model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator 489 | ) 490 | 491 | eps = noise * s_noise 492 | sigma_hat = sigma * (gamma + 1) 493 | 494 | if gamma > 0: 495 | sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 496 | 497 | # 1. 
compute predicted original sample (x_0) from sigma-scaled predicted noise 498 | # NOTE: "original_sample" should not be an expected prediction_type but is left in for 499 | # backwards compatibility 500 | if self.config.prediction_type == "original_sample" or self.config.prediction_type == "sample": 501 | pred_original_sample = model_output 502 | elif self.config.prediction_type == "epsilon": 503 | pred_original_sample = sample - sigma_hat * model_output 504 | elif self.config.prediction_type == "v_prediction": 505 | # denoised = model_output * c_out + input * c_skip 506 | pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) 507 | else: 508 | raise ValueError( 509 | f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" 510 | ) 511 | 512 | # 2. Convert to an ODE derivative 513 | derivative = (sample - pred_original_sample) / sigma_hat 514 | 515 | dt = self.sigmas[self.step_index + 1] - sigma_hat 516 | 517 | prev_sample = sample + derivative * dt 518 | 519 | # Cast sample back to model compatible dtype 520 | prev_sample = prev_sample.to(model_output.dtype) 521 | 522 | # upon completion increase step index by one 523 | self._step_index += 1 524 | 525 | if not return_dict: 526 | return (prev_sample,) 527 | 528 | return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) 529 | 530 | def add_noise( 531 | self, 532 | original_samples: torch.FloatTensor, 533 | noise: torch.FloatTensor, 534 | timesteps: torch.FloatTensor, 535 | ) -> torch.FloatTensor: 536 | # Make sure sigmas and timesteps have the same device and dtype as original_samples 537 | sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) 538 | if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): 539 | # mps does not support float64 540 | schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) 541 | timesteps = timesteps.to(original_samples.device, dtype=torch.float32) 542 | else: 543 | schedule_timesteps = self.timesteps.to(original_samples.device) 544 | timesteps = timesteps.to(original_samples.device) 545 | 546 | step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] 547 | 548 | sigma = sigmas[step_indices].flatten() 549 | while len(sigma.shape) < len(original_samples.shape): 550 | sigma = sigma.unsqueeze(-1) 551 | 552 | noisy_samples = original_samples + noise * sigma 553 | return noisy_samples 554 | 555 | def __len__(self): 556 | return self.config.num_train_timesteps 557 | -------------------------------------------------------------------------------- /node_utils.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | import os 4 | import torch 5 | from PIL import Image 6 | import numpy as np 7 | import cv2 8 | import gc 9 | from comfy.utils import common_upscale,ProgressBar 10 | from huggingface_hub import hf_hub_download 11 | import time 12 | cur_path = os.path.dirname(os.path.abspath(__file__)) 13 | device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" 14 | 15 | class timer: 16 | def __init__(self, method_name="timed process"): 17 | self.method = method_name 18 | 19 | def __enter__(self): 20 | self.start = time.time() 21 | print(f"{self.method} starts") 22 | 23 | def __exit__(self, exc_type, exc_val, exc_tb): 24 | end = time.time() 25 | 
print(f"{self.method} took {str(round(end - self.start, 2))}s") 26 | 27 | 28 | def process_image_with_mask(image, mask,width, height): 29 | 30 | mask_tensor=mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) 31 | 32 | input_image_resized = tensor2pil_upscale(image, width, height) 33 | mask_pil=tensor2pil_upscale(mask_tensor, width, height) 34 | 35 | return input_image_resized,mask_pil 36 | 37 | def convert_cf2diffuser(model,unet_config_file): 38 | from diffusers.pipelines.stable_diffusion.convert_from_ckpt import convert_ldm_unet_checkpoint 39 | from diffusers import UNet2DConditionModel 40 | from .modules.unet_spatio_temporal_condition_edit import UNetSpatioTemporalConditionEdit 41 | cf_state_dict = model.diffusion_model.state_dict() 42 | unet_state_dict = model.model_config.process_unet_state_dict_for_saving(cf_state_dict) 43 | unet_config = UNetSpatioTemporalConditionEdit.load_config(unet_config_file) 44 | Unet = UNetSpatioTemporalConditionEdit.from_config(unet_config).to(device, torch.float16) 45 | #cf_state_dict = convert_ldm_unet_checkpoint(unet_state_dict, Unet.config) 46 | Unet.load_state_dict(unet_state_dict, strict=False) 47 | del cf_state_dict 48 | gc.collect() 49 | torch.cuda.empty_cache() 50 | return Unet 51 | 52 | def cv2pil(cv_image): 53 | """ 54 | 将OpenCV图像转换为PIL图像 55 | :param cv_image: OpenCV图像 56 | :return: PIL图像 57 | """ 58 | # 将图像从BGR转换为RGB 59 | rgb_image = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB) 60 | # 使用PIL的Image.fromarray方法将NumPy数组转换为PIL图像 61 | pil_image = Image.fromarray(rgb_image) 62 | return pil_image 63 | 64 | def tensor_to_pil(tensor): 65 | image_np = tensor.squeeze().mul(255).clamp(0, 255).byte().numpy() 66 | image = Image.fromarray(image_np, mode='RGB') 67 | return image 68 | 69 | def tensor2pil_list(image,width,height): 70 | B,_,_,_=image.size() 71 | if B==1: 72 | ref_image_list=[tensor2pil_upscale(image,width,height)] 73 | else: 74 | img_list = list(torch.chunk(image, chunks=B)) 75 | ref_image_list = [tensor2pil_upscale(img,width,height) for img in img_list] 76 | return ref_image_list 77 | 78 | def tensor2pil_upscale(img_tensor, width, height): 79 | samples = img_tensor.movedim(-1, 1) 80 | img = common_upscale(samples, width, height, "nearest-exact", "center") 81 | samples = img.movedim(1, -1) 82 | img_pil = tensor_to_pil(samples) 83 | return img_pil 84 | 85 | def tensor_upscale(img_tensor, width, height): 86 | samples = img_tensor.movedim(-1, 1) 87 | img = common_upscale(samples, width, height, "nearest-exact", "center") 88 | samples = img.movedim(1, -1) 89 | return samples 90 | 91 | 92 | 93 | def tensor2cv(tensor_image,RGB2BGR=True): 94 | if len(tensor_image.shape)==4:#bhwc to hwc 95 | tensor_image=tensor_image.squeeze(0) 96 | if tensor_image.is_cuda: 97 | tensor_image = tensor_image.cpu().detach() 98 | tensor_image=tensor_image.numpy() 99 | #反归一化 100 | maxValue=tensor_image.max() 101 | tensor_image=tensor_image*255/maxValue 102 | img_cv2=np.uint8(tensor_image)#32 to uint8 103 | if RGB2BGR: 104 | img_cv2=cv2.cvtColor(img_cv2,cv2.COLOR_RGB2BGR) 105 | return img_cv2 106 | 107 | def cvargb2tensor(img): 108 | assert type(img) == np.ndarray, 'the img type is {}, but ndarry expected'.format(type(img)) 109 | img = torch.from_numpy(img.transpose((2, 0, 1))) 110 | return img.float().div(255).unsqueeze(0) # 255也可以改为256 111 | 112 | def cv2tensor(img): 113 | assert type(img) == np.ndarray, 'the img type is {}, but ndarry expected'.format(type(img)) 114 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 115 | img = 
def images_generator(img_list: list):
    # First pass: find the most common image size in the list
    sizes = {}
    for image_ in img_list:
        if isinstance(image_, Image.Image):
            count = sizes.get(image_.size, 0)
            sizes[image_.size] = count + 1
        elif isinstance(image_, np.ndarray):
            count = sizes.get(image_.shape[:2][::-1], 0)
            sizes[image_.shape[:2][::-1]] = count + 1
        else:
            raise TypeError("unsupported image list: items must be PIL images or cv2 (numpy) arrays!")
    size = max(sizes.items(), key=lambda x: x[1])[0]
    yield size[0], size[1]

    # Convert any supported image to an (H, W, 3) float array/tensor in [0, 1]
    def load_image(img_in):
        if isinstance(img_in, Image.Image):
            img_in = img_in.convert("RGB")
            i = np.array(img_in, dtype=np.float32)
            i = torch.from_numpy(i).div_(255)
            if i.shape[0] != size[1] or i.shape[1] != size[0]:
                # Resize to the dominant size before stacking
                i = i.movedim(-1, 0).unsqueeze(0)
                i = common_upscale(i, size[0], size[1], "lanczos", "center")
                i = i.squeeze(0).movedim(0, -1).numpy()
            return i
        elif isinstance(img_in, np.ndarray):
            i = cv2.cvtColor(img_in, cv2.COLOR_BGR2RGB).astype(np.float32)
            i = torch.from_numpy(i).div_(255)
            return i
        else:
            raise TypeError("unsupported image list: items must be PIL images or cv2 (numpy) arrays!")

    total_images = len(img_list)
    processed_images = 0
    pbar = ProgressBar(total_images)
    images = map(load_image, img_list)
    prev_image = None
    try:
        prev_image = next(images)
        while True:
            next_image = next(images)
            yield prev_image
            processed_images += 1
            pbar.update_absolute(processed_images, total_images)
            prev_image = next_image
    except StopIteration:
        pass
    if prev_image is not None:
        yield prev_image

def load_images(img_list: list):
    gen = images_generator(img_list)
    (width, height) = next(gen)
    images = torch.from_numpy(np.fromiter(gen, np.dtype((np.float32, (height, width, 3)))))
    if len(images) == 0:
        raise FileNotFoundError("No images could be loaded.")
    return images

def tensor2pil(tensor):
    image_np = tensor.squeeze().mul(255).clamp(0, 255).byte().numpy()
    image = Image.fromarray(image_np, mode='RGB')
    return image

def pil2narry(img):
    narry = torch.from_numpy(np.array(img).astype(np.float32) / 255.0).unsqueeze(0)
    return narry

def equalize_lists(list1, list2):
    """
    Compare the lengths of two lists; if they differ, repeat the shorter one
    until it matches the length of the longer one.

    Args:
        list1 (list): first list
        list2 (list): second list

    Returns:
        tuple: the two lists, now of equal length
    """
    len1 = len(list1)
    len2 = len(list2)

    if len1 == len2:
        pass
    elif len1 < len2:
        print("list1 is shorter than list2, copying list1 to match list2's length.")
        list1.extend(list1 * ((len2 // len1) + 1))  # repeat list1 to cover list2's length
        list1 = list1[:len2]  # truncate so the lengths match exactly
    else:
        print("list2 is shorter than list1, copying list2 to match list1's length.")
        list2.extend(list2 * ((len1 // len2) + 1))  # repeat list2 to cover list1's length
        list2 = list2[:len1]  # truncate so the lengths match exactly

    return list1, list2

def file_exists(directory, filename):
    # Build the full path and check whether the file exists
    file_path = os.path.join(directory, filename)
    return os.path.isfile(file_path)
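
# Hypothetical usage sketch for the batching helpers above (illustration only, not
# called by the FramePainter nodes): mixed-size PIL frames are resized to the most
# common size and stacked into a single [N, H, W, 3] float tensor in [0, 1].
def _example_batch_helpers():
    frames = [Image.new("RGB", (512, 512)),
              Image.new("RGB", (512, 512)),
              Image.new("RGB", (256, 256))]
    batch = load_images(frames)               # -> torch.float32 tensor of shape [3, 512, 512, 3]
    a, b = equalize_lists([1, 2, 3], ["x"])   # -> ([1, 2, 3], ['x', 'x', 'x'])
    return batch, a, b
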
def download_weights(file_dir, repo_id, subfolder="", pt_name=""):
    """Return the local path of a checkpoint, downloading it from the Hugging Face Hub if missing."""
    if subfolder:
        file_path = os.path.join(file_dir, subfolder, pt_name)
        sub_dir = os.path.join(file_dir, subfolder)
        if not os.path.exists(sub_dir):
            os.makedirs(sub_dir)
        if not os.path.exists(file_path):
            file_path = hf_hub_download(
                repo_id=repo_id,
                subfolder=subfolder,
                filename=pt_name,
                local_dir=file_dir,
            )
        return file_path
    else:
        file_path = os.path.join(file_dir, pt_name)
        if not os.path.exists(file_dir):
            os.makedirs(file_dir)
        if not os.path.exists(file_path):
            file_path = hf_hub_download(
                repo_id=repo_id,
                filename=pt_name,
                local_dir=file_dir,
            )
        return file_path
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[project]
name = "comfyui_framepainter"
description = "Official PyTorch implementation of 'FramePainter: Endowing Interactive Image Editing with Video Diffusion Priors'; you can use it in ComfyUI."
version = "1.0.0"
license = {file = "LICENSE"}
dependencies = ["#accelerate==0.31.0", "#datasets==3.0.1", "#decord==0.6.0", "#deepspeed==0.15.1", "#diffusers==0.30.2", "#einops==0.8.0", "#gradio==5.13.1", "#imageio==2.35.1", "#imageio-ffmpeg==0.5.1", "#pillow==10.4.0", "safetensors", "scikit-image", "scikit-learn", "#scipy==1.14.1", "#tokenizers==0.19.1", "#torch==2.4.1", "#torchvision==0.19.1", "#transformers==4.41.1"]

[project.urls]
Repository = "https://github.com/smthemex/ComfyUI_FramePainter"
# Used by Comfy Registry https://comfyregistry.org

[tool.comfy]
PublisherId = "smthemex"
DisplayName = "ComfyUI_FramePainter"
Icon = ""
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
#accelerate==0.31.0
#datasets==3.0.1
#decord==0.6.0
#deepspeed==0.15.1
#diffusers==0.30.2
#einops==0.8.0
#gradio==5.13.1
#imageio==2.35.1
#imageio-ffmpeg==0.5.1
#pillow==10.4.0
safetensors
scikit-image
scikit-learn
#scipy==1.14.1
#tokenizers==0.19.1
#torch==2.4.1
#torchvision==0.19.1
#transformers==4.41.1
--------------------------------------------------------------------------------
/svd_repo/feature_extractor/preprocessor_config.json:
--------------------------------------------------------------------------------
{
  "crop_size": {
    "height": 224,
    "width": 224
  },
  "do_center_crop": true,
  "do_convert_rgb": true,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "feature_extractor_type": "CLIPFeatureExtractor",
  "image_mean": [
    0.48145466,
    0.4578275,
    0.40821073
  ],
  "image_processor_type": "CLIPImageProcessor",
  "image_std": [
    0.26862954,
    0.26130258,
    0.27577711
  ],
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "shortest_edge": 224
  }
}
--------------------------------------------------------------------------------
/svd_repo/model_index.json:
--------------------------------------------------------------------------------
{
  "_class_name": "StableVideoDiffusionPipeline",
  "_diffusers_version": "0.24.0.dev0",
  "_name_or_path": "diffusers/svd-xt",
  "feature_extractor": [
    "transformers",
    "CLIPImageProcessor"
  ],
  "image_encoder": [
    "transformers",
"CLIPVisionModelWithProjection" 12 | ], 13 | "scheduler": [ 14 | "diffusers", 15 | "EulerDiscreteScheduler" 16 | ], 17 | "unet": [ 18 | "diffusers", 19 | "UNetSpatioTemporalConditionModel" 20 | ], 21 | "vae": [ 22 | "diffusers", 23 | "AutoencoderKLTemporalDecoder" 24 | ] 25 | } 26 | -------------------------------------------------------------------------------- /svd_repo/scheduler/scheduler_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "EulerDiscreteScheduler", 3 | "_diffusers_version": "0.24.0.dev0", 4 | "beta_end": 0.012, 5 | "beta_schedule": "scaled_linear", 6 | "beta_start": 0.00085, 7 | "clip_sample": false, 8 | "interpolation_type": "linear", 9 | "num_train_timesteps": 1000, 10 | "prediction_type": "v_prediction", 11 | "set_alpha_to_one": false, 12 | "sigma_max": 700.0, 13 | "sigma_min": 0.002, 14 | "skip_prk_steps": true, 15 | "steps_offset": 1, 16 | "timestep_spacing": "leading", 17 | "timestep_type": "continuous", 18 | "trained_betas": null, 19 | "use_karras_sigmas": true 20 | } 21 | -------------------------------------------------------------------------------- /svd_repo/unet/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "UNetSpatioTemporalConditionModel", 3 | "_diffusers_version": "0.24.0.dev0", 4 | "_name_or_path": "/home/suraj_huggingface_co/.cache/huggingface/hub/models--diffusers--svd-xt/snapshots/9703ded20c957c340781ee710b75660826deb487/unet", 5 | "addition_time_embed_dim": 256, 6 | "block_out_channels": [ 7 | 320, 8 | 640, 9 | 1280, 10 | 1280 11 | ], 12 | "cross_attention_dim": 1024, 13 | "down_block_types": [ 14 | "CrossAttnDownBlockSpatioTemporal", 15 | "CrossAttnDownBlockSpatioTemporal", 16 | "CrossAttnDownBlockSpatioTemporal", 17 | "DownBlockSpatioTemporal" 18 | ], 19 | "in_channels": 8, 20 | "layers_per_block": 2, 21 | "num_attention_heads": [ 22 | 5, 23 | 10, 24 | 20, 25 | 20 26 | ], 27 | "num_frames": 25, 28 | "out_channels": 4, 29 | "projection_class_embeddings_input_dim": 768, 30 | "sample_size": 96, 31 | "transformer_layers_per_block": 1, 32 | "up_block_types": [ 33 | "UpBlockSpatioTemporal", 34 | "CrossAttnUpBlockSpatioTemporal", 35 | "CrossAttnUpBlockSpatioTemporal", 36 | "CrossAttnUpBlockSpatioTemporal" 37 | ] 38 | } 39 | -------------------------------------------------------------------------------- /svd_repo/vae/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "AutoencoderKLTemporalDecoder", 3 | "_diffusers_version": "0.24.0.dev0", 4 | "_name_or_path": "/home/suraj_huggingface_co/.cache/huggingface/hub/models--diffusers--svd-xt/snapshots/9703ded20c957c340781ee710b75660826deb487/vae", 5 | "block_out_channels": [ 6 | 128, 7 | 256, 8 | 512, 9 | 512 10 | ], 11 | "down_block_types": [ 12 | "DownEncoderBlock2D", 13 | "DownEncoderBlock2D", 14 | "DownEncoderBlock2D", 15 | "DownEncoderBlock2D" 16 | ], 17 | "force_upcast": true, 18 | "in_channels": 3, 19 | "latent_channels": 4, 20 | "layers_per_block": 2, 21 | "out_channels": 3, 22 | "sample_size": 768, 23 | "scaling_factor": 0.18215 24 | } 25 | --------------------------------------------------------------------------------