├── exampleA.png ├── assert ├── lq │ ├── lq1.mp4 │ ├── lq2.mp4 │ └── lq3.mp4 ├── method.png └── mask │ └── lq3.png ├── requirements.txt ├── __init__.py ├── svd_repo ├── model_index.json ├── scheduler │ └── scheduler_config.json ├── feature_extractor │ └── preprocessor_config.json ├── vae │ └── config.json └── unet │ └── config.json ├── pyproject.toml ├── config └── infer.yaml ├── src ├── models │ ├── id_proj.py │ ├── model_insightface_360k.py │ └── svfr_adapter │ │ ├── unet_3d_svd_condition_ip.py │ │ └── attention_processor.py ├── utils │ ├── noise_util.py │ └── util.py ├── dataset │ ├── face_align │ │ ├── align.py │ │ └── yoloface.py │ └── dataset.py └── pipelines │ └── pipeline.py ├── LICENSE ├── README.md ├── node_utils.py ├── SVFR_node.py └── infer.py /exampleA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_SVFR/HEAD/exampleA.png -------------------------------------------------------------------------------- /assert/lq/lq1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_SVFR/HEAD/assert/lq/lq1.mp4 -------------------------------------------------------------------------------- /assert/lq/lq2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_SVFR/HEAD/assert/lq/lq2.mp4 -------------------------------------------------------------------------------- /assert/lq/lq3.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_SVFR/HEAD/assert/lq/lq3.mp4 -------------------------------------------------------------------------------- /assert/method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_SVFR/HEAD/assert/method.png -------------------------------------------------------------------------------- /assert/mask/lq3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_SVFR/HEAD/assert/mask/lq3.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | diffusers 3 | moviepy 4 | numpy 5 | omegaconf 6 | opencv-python 7 | scikit-video 8 | transformers -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .SVFR_node import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 3 | 4 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] 5 | -------------------------------------------------------------------------------- /svd_repo/model_index.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "StableVideoDiffusionPipeline", 3 | "_diffusers_version": "0.24.0.dev0", 4 | "_name_or_path": "diffusers/svd-xt", 5 | "feature_extractor": [ 6 | "transformers", 7 | "CLIPImageProcessor" 8 | ], 9 | "image_encoder": [ 10 | "transformers", 11 | "CLIPVisionModelWithProjection" 12 | ], 13 | "scheduler": [ 14 | "diffusers", 15 | "EulerDiscreteScheduler" 16 | ], 17 | "unet": [ 18 | "diffusers", 19 | "UNetSpatioTemporalConditionModel" 20 | ], 21 | "vae": [ 22 | 
"diffusers", 23 | "AutoencoderKLTemporalDecoder" 24 | ] 25 | } 26 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "comfyui_svfr" 3 | description = "SVFR is a unified framework for face video restoration that supports tasks such as BFR, Colorization, Inpainting,you can use it in ComfyUI" 4 | version = "1.0.0" 5 | license = { file = "LICENSE" } 6 | dependencies = ["accelerate", "diffusers", "moviepy", "numpy", "omegaconf", "opencv-python", "scikit-video", "transformers"] 7 | 8 | [project.urls] 9 | Repository = "https://github.com/smthemex/ComfyUI_SVFR" 10 | # Used by Comfy Registry https://comfyregistry.org 11 | 12 | [tool.comfy] 13 | PublisherId = "smthemex" 14 | DisplayName = "ComfyUI_SVFR" 15 | Icon = "" 16 | -------------------------------------------------------------------------------- /svd_repo/scheduler/scheduler_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "EulerDiscreteScheduler", 3 | "_diffusers_version": "0.24.0.dev0", 4 | "beta_end": 0.012, 5 | "beta_schedule": "scaled_linear", 6 | "beta_start": 0.00085, 7 | "clip_sample": false, 8 | "interpolation_type": "linear", 9 | "num_train_timesteps": 1000, 10 | "prediction_type": "v_prediction", 11 | "set_alpha_to_one": false, 12 | "sigma_max": 700.0, 13 | "sigma_min": 0.002, 14 | "skip_prk_steps": true, 15 | "steps_offset": 1, 16 | "timestep_spacing": "leading", 17 | "timestep_type": "continuous", 18 | "trained_betas": null, 19 | "use_karras_sigmas": true 20 | } 21 | -------------------------------------------------------------------------------- /config/infer.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | n_sample_frames: 16 3 | width: 512 4 | height: 512 5 | 6 | pretrained_model_name_or_path: "models/stable-video-diffusion-img2vid-xt" 7 | unet_checkpoint_path: "models/face_restoration/unet.pth" 8 | id_linear_checkpoint_path: "models/face_restoration/id_linear.pth" 9 | net_arcface_checkpoint_path: "models/face_restoration/insightface_glint360k.pth" 10 | # output_dir: 'result' 11 | 12 | 13 | # test config 14 | weight_dtype: 'fp16' 15 | num_inference_steps: 30 16 | decode_chunk_size: 16 17 | overlap: 3 18 | noise_aug_strength: 0.00 19 | min_appearance_guidance_scale: 2.0 20 | max_appearance_guidance_scale: 2.0 21 | i2i_noise_strength: 1.0 -------------------------------------------------------------------------------- /svd_repo/feature_extractor/preprocessor_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "crop_size": { 3 | "height": 224, 4 | "width": 224 5 | }, 6 | "do_center_crop": true, 7 | "do_convert_rgb": true, 8 | "do_normalize": true, 9 | "do_rescale": true, 10 | "do_resize": true, 11 | "feature_extractor_type": "CLIPFeatureExtractor", 12 | "image_mean": [ 13 | 0.48145466, 14 | 0.4578275, 15 | 0.40821073 16 | ], 17 | "image_processor_type": "CLIPImageProcessor", 18 | "image_std": [ 19 | 0.26862954, 20 | 0.26130258, 21 | 0.27577711 22 | ], 23 | "resample": 3, 24 | "rescale_factor": 0.00392156862745098, 25 | "size": { 26 | "shortest_edge": 224 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /src/models/id_proj.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from diffusers import 
ModelMixin 4 | from einops import rearrange 5 | from torch import nn 6 | 7 | class IDProjConvModel(ModelMixin): 8 | def __init__(self, in_channels=2048, out_channels=1024): 9 | super().__init__() 10 | 11 | self.project1024 = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=False) 12 | self.final_norm = torch.nn.LayerNorm(out_channels) 13 | 14 | def forward(self, src_id_features_7_7_1024): 15 | c = self.project1024(src_id_features_7_7_1024) 16 | c = torch.flatten(c, 2) 17 | c = torch.transpose(c, 2, 1) 18 | c = self.final_norm(c) 19 | 20 | return c 21 | -------------------------------------------------------------------------------- /svd_repo/vae/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "AutoencoderKLTemporalDecoder", 3 | "_diffusers_version": "0.24.0.dev0", 4 | "_name_or_path": "/home/suraj_huggingface_co/.cache/huggingface/hub/models--diffusers--svd-xt/snapshots/9703ded20c957c340781ee710b75660826deb487/vae", 5 | "block_out_channels": [ 6 | 128, 7 | 256, 8 | 512, 9 | 512 10 | ], 11 | "down_block_types": [ 12 | "DownEncoderBlock2D", 13 | "DownEncoderBlock2D", 14 | "DownEncoderBlock2D", 15 | "DownEncoderBlock2D" 16 | ], 17 | "force_upcast": true, 18 | "in_channels": 3, 19 | "latent_channels": 4, 20 | "layers_per_block": 2, 21 | "out_channels": 3, 22 | "sample_size": 768, 23 | "scaling_factor": 0.18215 24 | } 25 | -------------------------------------------------------------------------------- /src/utils/noise_util.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union 2 | import torch 3 | 4 | from diffusers.utils.torch_utils import randn_tensor 5 | 6 | def random_noise( 7 | tensor: torch.Tensor = None, 8 | shape: Tuple[int] = None, 9 | dtype: torch.dtype = None, 10 | device: torch.device = None, 11 | generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, 12 | noise_offset: Optional[float] = None, # typical value is 0.1 13 | ) -> torch.Tensor: 14 | if tensor is not None: 15 | shape = tensor.shape 16 | device = tensor.device 17 | dtype = tensor.dtype 18 | if isinstance(device, str): 19 | device = torch.device(device) 20 | noise = randn_tensor(shape, dtype=dtype, device=device, generator=generator) 21 | if noise_offset is not None: 22 | noise += noise_offset * torch.randn( 23 | (tensor.shape[0], tensor.shape[1], 1, 1, 1), device 24 | ) 25 | return noise 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 smthemex 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /svd_repo/unet/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "UNetSpatioTemporalConditionModel", 3 | "_diffusers_version": "0.24.0.dev0", 4 | "_name_or_path": "/home/suraj_huggingface_co/.cache/huggingface/hub/models--diffusers--svd-xt/snapshots/9703ded20c957c340781ee710b75660826deb487/unet", 5 | "addition_time_embed_dim": 256, 6 | "block_out_channels": [ 7 | 320, 8 | 640, 9 | 1280, 10 | 1280 11 | ], 12 | "cross_attention_dim": 1024, 13 | "down_block_types": [ 14 | "CrossAttnDownBlockSpatioTemporal", 15 | "CrossAttnDownBlockSpatioTemporal", 16 | "CrossAttnDownBlockSpatioTemporal", 17 | "DownBlockSpatioTemporal" 18 | ], 19 | "in_channels": 8, 20 | "layers_per_block": 2, 21 | "num_attention_heads": [ 22 | 5, 23 | 10, 24 | 20, 25 | 20 26 | ], 27 | "num_frames": 25, 28 | "out_channels": 4, 29 | "projection_class_embeddings_input_dim": 768, 30 | "sample_size": 96, 31 | "transformer_layers_per_block": 1, 32 | "up_block_types": [ 33 | "UpBlockSpatioTemporal", 34 | "CrossAttnUpBlockSpatioTemporal", 35 | "CrossAttnUpBlockSpatioTemporal", 36 | "CrossAttnUpBlockSpatioTemporal" 37 | ] 38 | } 39 | -------------------------------------------------------------------------------- /src/dataset/face_align/align.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 4 | sys.path.append(BASE_DIR) 5 | import torch 6 | from ..face_align.yoloface import YoloFace 7 | 8 | class AlignImage(object): 9 | def __init__(self, device='cuda', det_path='checkpoints/yoloface_v5m.pt'): 10 | self.facedet = YoloFace(pt_path=det_path, confThreshold=0.5, nmsThreshold=0.45, device=device) 11 | 12 | @torch.no_grad() 13 | def __call__(self, im, maxface=False): 14 | bboxes, kpss, scores = self.facedet.detect(im) 15 | face_num = bboxes.shape[0] 16 | 17 | five_pts_list = [] 18 | scores_list = [] 19 | bboxes_list = [] 20 | for i in range(face_num): 21 | five_pts_list.append(kpss[i].reshape(5,2)) 22 | scores_list.append(scores[i]) 23 | bboxes_list.append(bboxes[i]) 24 | 25 | if maxface and face_num>1: 26 | max_idx = 0 27 | max_area = (bboxes[0, 2])*(bboxes[0, 3]) 28 | for i in range(1, face_num): 29 | area = (bboxes[i,2])*(bboxes[i,3]) 30 | if area>max_area: 31 | max_idx = i 32 | five_pts_list = [five_pts_list[max_idx]] 33 | scores_list = [scores_list[max_idx]] 34 | bboxes_list = [bboxes_list[max_idx]] 35 | 36 | return five_pts_list, scores_list, bboxes_list -------------------------------------------------------------------------------- /src/utils/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | import torch 6 | import torchvision 7 | from einops import rearrange 8 | from PIL import Image 9 | 10 | import imageio 11 | 12 | def seed_everything(seed): 13 | import 
random 14 | 15 | import numpy as np 16 | 17 | torch.manual_seed(seed) 18 | torch.cuda.manual_seed_all(seed) 19 | np.random.seed(seed % (2**32)) 20 | random.seed(seed) 21 | 22 | 23 | def save_videos_from_pil(pil_images, path, fps=8): 24 | save_fmt = Path(path).suffix 25 | os.makedirs(os.path.dirname(path), exist_ok=True) 26 | 27 | if save_fmt == ".mp4": 28 | with imageio.get_writer(path, fps=fps) as writer: 29 | for img in pil_images: 30 | img_array = np.array(img) # Convert PIL Image to numpy array 31 | writer.append_data(img_array) 32 | 33 | elif save_fmt == ".gif": 34 | pil_images[0].save( 35 | fp=path, 36 | format="GIF", 37 | append_images=pil_images[1:], 38 | save_all=True, 39 | duration=(1 / fps * 1000), 40 | loop=0, 41 | optimize=False, 42 | lossless=True 43 | ) 44 | else: 45 | raise ValueError("Unsupported file type. Use .mp4 or .gif.") 46 | 47 | 48 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8): 49 | videos = rearrange(videos, "b c t h w -> t b c h w") 50 | height, width = videos.shape[-2:] 51 | outputs = [] 52 | 53 | for i, x in enumerate(videos): 54 | x = torchvision.utils.make_grid(x, nrow=n_rows) # (c h w) 55 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) # (h w c) 56 | if rescale: 57 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 58 | x = (x * 255).numpy().astype(np.uint8) 59 | x = Image.fromarray(x) 60 | outputs.append(x) 61 | 62 | os.makedirs(os.path.dirname(path), exist_ok=True) 63 | 64 | save_videos_from_pil(outputs, path, fps) 65 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI_SVFR 2 | [SVFR](https://github.com/wangzhiyaoo/SVFR/tree/main) is a unified framework for face video restoration that supports tasks such as BFR, Colorization, and Inpainting; you can use it in ComfyUI. 3 | 4 | # Update 5 | * 24/02/12 Changed the model loading mode to standalone checkpoint files. Using ComfyUI's VAE with this model causes a severe color shift, so only the diffusers loading method is used. 6 | 7 | # 1. Installation 8 | 9 | In the ./ComfyUI/custom_nodes directory, run the following: 10 | ``` 11 | git clone https://github.com/smthemex/ComfyUI_SVFR.git 12 | ``` 13 | # 2. Requirements 14 | ``` 15 | pip install -r requirements.txt 16 | ``` 17 | # 3. Models Required
18 | * 3.1 Download the SVFR checkpoints from [google drive](https://drive.google.com/drive/folders/1nzy9Vk-yA_DwXm1Pm4dyE2o0r7V6_5mn). After decompression, place the models in the following layout: 19 | ``` 20 | ├── Comfyui/models/SVFR/ 21 | | ├── id_linear.pth 22 | | ├── insightface_glint360k.pth 23 | | ├── unet.pth 24 | | ├── yoloface_v5m.pt 25 | ``` 26 | * 3.2 [svd_xt.safetensors](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt) or [svd_xt_1_1.safetensors](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt-1-1) 27 | ``` 28 | ├── Comfyui/models/checkpoints/ 29 | | ├── svd_xt.safetensors or svd_xt_1_1.safetensors 30 | ├── Comfyui/models/vae/ 31 | | ├──svd.ave.fp16.safetensors #renamed from the vae of stabilityai/stable-video-diffusion-img2vid-xt; renaming it is optional 32 | ``` 33 | 34 | # 4 Inference mode 35 | 36 | * "bfr, colorization, inpainting, bfr_color, bfr_color_inpaint"; the inpainting and bfr_color_inpaint modes need a mask (use a ComfyUI mask or a black/white jpg) 37 | 38 | # 5 Example 39 | ![](https://github.com/smthemex/ComfyUI_SVFR/blob/main/exampleA.png) 40 | 41 | # 6 Citation 42 | ``` 43 | @misc{wang2025svfrunifiedframeworkgeneralized, 44 | title={SVFR: A Unified Framework for Generalized Video Face Restoration}, 45 | author={Zhiyao Wang and Xu Chen and Chengming Xu and Junwei Zhu and Xiaobin Hu and Jiangning Zhang and Chengjie Wang and Yuqi Liu and Yiyi Zhou and Rongrong Ji}, 46 | year={2025}, 47 | eprint={2501.01235}, 48 | archivePrefix={arXiv}, 49 | primaryClass={cs.CV}, 50 | url={https://arxiv.org/abs/2501.01235}, 51 | } 52 | ``` 53 | -------------------------------------------------------------------------------- /src/dataset/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import random 4 | from PIL import Image 5 | import torch 6 | from torch.utils.data import Dataset 7 | import torchvision.transforms as transforms 8 | from transformers import CLIPImageProcessor 9 | # import librosa 10 | 11 | import os 12 | import cv2 13 | 14 | mean_face_lm5p_256 = np.array([ 15 | [(30.2946+8)*2+16, 51.6963*2], 16 | [(65.5318+8)*2+16, 51.5014*2], 17 | [(48.0252+8)*2+16, 71.7366*2], 18 | [(33.5493+8)*2+16, 92.3655*2], 19 | [(62.7299+8)*2+16, 92.2041*2], 20 | ], dtype=np.float32) 21 | 22 | def get_affine_transform(target_face_lm5p, mean_lm5p): 23 | mat_warp = np.zeros((2,3)) 24 | A = np.zeros((4,4)) 25 | B = np.zeros((4)) 26 | for i in range(5): 27 | A[0][0] += target_face_lm5p[i][0] * target_face_lm5p[i][0] + target_face_lm5p[i][1] * target_face_lm5p[i][1] 28 | A[0][2] += target_face_lm5p[i][0] 29 | A[0][3] += target_face_lm5p[i][1] 30 | 31 | B[0] += target_face_lm5p[i][0] * mean_lm5p[i][0] + target_face_lm5p[i][1] * mean_lm5p[i][1] #sb[1] += a[i].x*b[i].y - a[i].y*b[i].x; 32 | B[1] += target_face_lm5p[i][0] * mean_lm5p[i][1] - target_face_lm5p[i][1] * mean_lm5p[i][0] 33 | B[2] += mean_lm5p[i][0] 34 | B[3] += mean_lm5p[i][1] 35 | 36 | A[1][1] = A[0][0] 37 | A[2][1] = A[1][2] = -A[0][3] 38 | A[3][1] = A[1][3] = A[2][0] = A[0][2] 39 | A[2][2] = A[3][3] = 5 40 | A[3][0] = A[0][3] 41 | 42 | _, mat23 = cv2.solve(A, B, flags=cv2.DECOMP_SVD) 43 | mat_warp[0][0] = mat23[0] 44 | mat_warp[1][1] = mat23[0] 45 | mat_warp[0][1] = -mat23[1] 46 | mat_warp[1][0] = mat23[1] 47 | mat_warp[0][2] = mat23[2] 48 | mat_warp[1][2] = mat23[3] 49 | 50 | return mat_warp 51 | 52 | def get_union_bbox(bboxes): 53 | bboxes = np.array(bboxes) 54 | min_x = np.min(bboxes[:, 0])
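    # together with the min over y1 and the max over x2/y2 below, this yields the union box
    # covering every face detected across the sampled frames; infer.py then expands it with
    # process_bbox(expand_radio=0.4) before cropping the face region.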
55 | min_y = np.min(bboxes[:, 1]) 56 | max_x = np.max(bboxes[:, 2]) 57 | max_y = np.max(bboxes[:, 3]) 58 | return np.array([min_x, min_y, max_x, max_y]) 59 | 60 | 61 | def process_bbox(bbox, expand_radio, height, width): 62 | 63 | def expand(bbox, ratio, height, width): 64 | 65 | bbox_h = bbox[3] - bbox[1] 66 | bbox_w = bbox[2] - bbox[0] 67 | 68 | expand_x1 = max(bbox[0] - ratio * bbox_w, 0) 69 | expand_y1 = max(bbox[1] - ratio * bbox_h, 0) 70 | expand_x2 = min(bbox[2] + ratio * bbox_w, width) 71 | expand_y2 = min(bbox[3] + ratio * bbox_h, height) 72 | 73 | return [expand_x1,expand_y1,expand_x2,expand_y2] 74 | 75 | def to_square(bbox_src, bbox_expend, height, width): 76 | 77 | h = bbox_expend[3] - bbox_expend[1] 78 | w = bbox_expend[2] - bbox_expend[0] 79 | c_h = (bbox_expend[1] + bbox_expend[3]) / 2 80 | c_w = (bbox_expend[0] + bbox_expend[2]) / 2 81 | 82 | c = min(h, w) / 2 83 | 84 | c_src_h = (bbox_src[1] + bbox_src[3]) / 2 85 | c_src_w = (bbox_src[0] + bbox_src[2]) / 2 86 | 87 | s_h, s_w = 0, 0 88 | if w < h: 89 | d = abs((h - w) / 2) 90 | s_h = min(d, abs(c_src_h-c_h)) 91 | s_h = s_h if c_src_h > c_h else s_h * (-1) 92 | else: 93 | d = abs((h - w) / 2) 94 | s_w = min(d, abs(c_src_w-c_w)) 95 | s_w = s_w if c_src_w > c_w else s_w * (-1) 96 | 97 | 98 | c_h = (bbox_expend[1] + bbox_expend[3]) / 2 + s_h 99 | c_w = (bbox_expend[0] + bbox_expend[2]) / 2 + s_w 100 | 101 | square_x1 = c_w - c 102 | square_y1 = c_h - c 103 | square_x2 = c_w + c 104 | square_y2 = c_h + c 105 | 106 | return [round(square_x1), round(square_y1), round(square_x2), round(square_y2)] 107 | 108 | 109 | bbox_expend = expand(bbox, expand_radio, height=height, width=width) 110 | processed_bbox = to_square(bbox, bbox_expend, height=height, width=width) 111 | 112 | return processed_bbox 113 | 114 | 115 | def crop_resize_img(img, bbox): 116 | x1, y1, x2, y2 = bbox 117 | img = img.crop((x1, y1, x2, y2)) 118 | return img 119 | -------------------------------------------------------------------------------- /node_utils.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | import os 4 | import torch 5 | from PIL import Image 6 | import numpy as np 7 | import cv2 8 | from comfy.utils import common_upscale,ProgressBar 9 | import folder_paths 10 | 11 | weight_dtype = torch.float16 12 | cur_path = os.path.dirname(os.path.abspath(__file__)) 13 | device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" 14 | 15 | 16 | def pil2narry(img): 17 | img = torch.from_numpy(np.array(img).astype(np.float32) / 255.0).unsqueeze(0) 18 | return img 19 | 20 | def narry_list(list_in): 21 | for i in range(len(list_in)): 22 | value = list_in[i] 23 | modified_value = pil2narry(value) 24 | list_in[i] = modified_value 25 | return list_in 26 | 27 | 28 | def gen_img_form_video(tensor): 29 | pil = [] 30 | for x in tensor: 31 | pil[x] = tensor_to_pil(x) 32 | yield pil 33 | 34 | 35 | def phi_list(list_in): 36 | for i in range(len(list_in)): 37 | value = list_in[i] 38 | list_in[i] = value 39 | return list_in 40 | 41 | def tensor_to_pil(tensor): 42 | image_np = tensor.squeeze().mul(255).clamp(0, 255).byte().numpy() 43 | image = Image.fromarray(image_np, mode='RGB') 44 | return image 45 | 46 | def nomarl_upscale(img_tensor, width, height): 47 | samples = img_tensor.movedim(-1, 1) 48 | img = common_upscale(samples, width, height, "nearest-exact", "center") 49 | samples = img.movedim(1, -1) 50 | img_pil = tensor_to_pil(samples) 51 
| return img_pil 52 | 53 | def tensor_upscale(img_tensor, width, height): 54 | samples = img_tensor.movedim(-1, 1) 55 | img = common_upscale(samples, width, height, "nearest-exact", "center") 56 | samples = img.movedim(1, -1) 57 | return samples 58 | 59 | 60 | def tensor2cv(tensor_image): 61 | if len(tensor_image.shape)==4: # b,h,w,c to h,w,c 62 | tensor_image=tensor_image.squeeze(0) 63 | if tensor_image.is_cuda: 64 | tensor_image = tensor_image.cpu() 65 | tensor_image=tensor_image.numpy() 66 | # de-normalize: rescale values back to the 0-255 range 67 | maxValue=tensor_image.max() 68 | tensor_image=tensor_image*255/maxValue 69 | img_cv2=np.uint8(tensor_image) # float32 to uint8 70 | img_cv2=cv2.cvtColor(img_cv2,cv2.COLOR_RGB2BGR) 71 | return img_cv2 72 | 73 | def cvargb2tensor(img): 74 | assert type(img) == np.ndarray, 'the img type is {}, but ndarray expected'.format(type(img)) 75 | img = torch.from_numpy(img.transpose((2, 0, 1))) 76 | return img.float().div(255).unsqueeze(0) # 255 could also be changed to 256 77 | 78 | def cv2tensor(img): 79 | assert type(img) == np.ndarray, 'the img type is {}, but ndarray expected'.format(type(img)) 80 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 81 | img = torch.from_numpy(img.transpose((2, 0, 1))) 82 | return img.float().div(255).unsqueeze(0) # 255 could also be changed to 256 83 | 84 | def images_generator(img_list: list,): 85 | #get img size 86 | sizes = {} 87 | for image_ in img_list: 88 | if isinstance(image_,Image.Image): 89 | count = sizes.get(image_.size, 0) 90 | sizes[image_.size] = count + 1 91 | elif isinstance(image_,np.ndarray): 92 | count = sizes.get(image_.shape[:2][::-1], 0) 93 | sizes[image_.shape[:2][::-1]] = count + 1 94 | else: 95 | raise ValueError("unsupported image list, must be PIL or cv2!") 96 | size = max(sizes.items(), key=lambda x: x[1])[0] 97 | yield size[0], size[1] 98 | 99 | # any to tensor 100 | def load_image(img_in): 101 | if isinstance(img_in, Image.Image): 102 | img_in=img_in.convert("RGB") 103 | i = np.array(img_in, dtype=np.float32) 104 | i = torch.from_numpy(i).div_(255) 105 | if i.shape[0] != size[1] or i.shape[1] != size[0]: 106 | i = i.movedim(-1, 0).unsqueeze(0) # i is already a tensor here, no from_numpy needed 107 | i = common_upscale(i, size[0], size[1], "lanczos", "center") 108 | i = i.squeeze(0).movedim(0, -1).numpy() 109 | return i 110 | elif isinstance(img_in,np.ndarray): 111 | i=cv2.cvtColor(img_in,cv2.COLOR_BGR2RGB).astype(np.float32) 112 | i = torch.from_numpy(i).div_(255) 113 | #print(i.shape) 114 | return i 115 | else: 116 | raise ValueError("unsupported image list, must be PIL, cv2 or tensor!")
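    # note: this generator yields the dominant (width, height) of the input list first, then
    # each frame converted by load_image() to a float32 HWC array in [0, 1]; load_images()
    # below consumes it in exactly that order (size first, frames after).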
117 | 118 | total_images = len(img_list) 119 | processed_images = 0 120 | pbar = ProgressBar(total_images) 121 | images = map(load_image, img_list) 122 | try: 123 | prev_image = next(images) 124 | while True: 125 | next_image = next(images) 126 | yield prev_image 127 | processed_images += 1 128 | pbar.update_absolute(processed_images, total_images) 129 | prev_image = next_image 130 | except StopIteration: 131 | pass 132 | if prev_image is not None: 133 | yield prev_image 134 | 135 | def load_images(img_list: list,): 136 | gen = images_generator(img_list) 137 | (width, height) = next(gen) 138 | images = torch.from_numpy(np.fromiter(gen, np.dtype((np.float32, (height, width, 3))))) 139 | if len(images) == 0: 140 | raise FileNotFoundError(f"No images could be loaded .") 141 | return images 142 | 143 | def tensor2pil(tensor): 144 | image_np = tensor.squeeze().mul(255).clamp(0, 255).byte().numpy() 145 | image = Image.fromarray(image_np, mode='RGB') 146 | return image 147 | 148 | def cf_tensor2cv(tensor,width, height): 149 | d1, _, _, _ = tensor.size() 150 | if d1 > 1: 151 | tensor_list = list(torch.chunk(tensor, chunks=d1)) 152 | tensor = [tensor_list][0] 153 | cr_tensor=tensor_upscale(tensor,width, height) 154 | cv_img=tensor2cv(cr_tensor) 155 | return cv_img 156 | 157 | def tensor2pillist(tensor): 158 | b, _, _, _ = tensor.size() 159 | if b == 1: 160 | img_list = [nomarl_upscale(tensor, 768, 768)] 161 | else: 162 | image_= torch.chunk(tensor, chunks=b) 163 | img_list = [nomarl_upscale(i, 768, 768) for i in image_] # pil 164 | return img_list 165 | 166 | 167 | -------------------------------------------------------------------------------- /src/models/model_insightface_360k.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch import nn 4 | 5 | __all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200', 'getarcface'] 6 | 7 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 8 | """3x3 convolution with padding""" 9 | return nn.Conv2d(in_planes, 10 | out_planes, 11 | kernel_size=3, 12 | stride=stride, 13 | padding=dilation, 14 | groups=groups, 15 | bias=False, 16 | dilation=dilation) 17 | 18 | 19 | def conv1x1(in_planes, out_planes, stride=1): 20 | """1x1 convolution""" 21 | return nn.Conv2d(in_planes, 22 | out_planes, 23 | kernel_size=1, 24 | stride=stride, 25 | bias=False) 26 | 27 | 28 | class IBasicBlock(nn.Module): 29 | expansion = 1 30 | def __init__(self, inplanes, planes, stride=1, downsample=None, 31 | groups=1, base_width=64, dilation=1): 32 | super(IBasicBlock, self).__init__() 33 | if groups != 1 or base_width != 64: 34 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 35 | if dilation > 1: 36 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 37 | self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,) 38 | self.conv1 = conv3x3(inplanes, planes) 39 | self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,) 40 | self.prelu = nn.PReLU(planes) 41 | self.conv2 = conv3x3(planes, planes, stride) 42 | self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,) 43 | self.downsample = downsample 44 | self.stride = stride 45 | 46 | def forward(self, x): 47 | identity = x 48 | out = self.bn1(x) 49 | out = self.conv1(out) 50 | out = self.bn2(out) 51 | out = self.prelu(out) 52 | out = self.conv2(out) 53 | out = self.bn3(out) 54 | if self.downsample is not None: 55 | identity = self.downsample(x) 56 | out += identity 57 | return out 58 | 59 | 60 | class IResNet(nn.Module): 
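    # ArcFace-style IResNet backbone (loaded here with the glint360k weights via getarcface()).
    # forward() returns both the 512-d identity embedding and the raw layer4 feature map;
    # infer.py feeds that feature map into IDProjConvModel to build the identity tokens.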
61 | fc_scale = 7 * 7 62 | def __init__(self, 63 | block, layers, dropout=0, num_features=512, zero_init_residual=False, 64 | groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False): 65 | super(IResNet, self).__init__() 66 | self.fp16 = fp16 67 | self.inplanes = 64 68 | self.dilation = 1 69 | if replace_stride_with_dilation is None: 70 | replace_stride_with_dilation = [False, False, False] 71 | if len(replace_stride_with_dilation) != 3: 72 | raise ValueError("replace_stride_with_dilation should be None " 73 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 74 | self.groups = groups 75 | self.base_width = width_per_group 76 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 77 | self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05) 78 | self.prelu = nn.PReLU(self.inplanes) 79 | self.layer1 = self._make_layer(block, 64, layers[0], stride=2) 80 | self.layer2 = self._make_layer(block, 81 | 128, 82 | layers[1], 83 | stride=2, 84 | dilate=replace_stride_with_dilation[0]) 85 | self.layer3 = self._make_layer(block, 86 | 256, 87 | layers[2], 88 | stride=2, 89 | dilate=replace_stride_with_dilation[1]) 90 | self.layer4 = self._make_layer(block, 91 | 512, 92 | layers[3], 93 | stride=2, 94 | dilate=replace_stride_with_dilation[2]) 95 | self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,) 96 | self.dropout = nn.Dropout(p=dropout, inplace=True) 97 | self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features) 98 | self.features = nn.BatchNorm1d(num_features, eps=1e-05) 99 | nn.init.constant_(self.features.weight, 1.0) 100 | self.features.weight.requires_grad = False 101 | 102 | for m in self.modules(): 103 | if isinstance(m, nn.Conv2d): 104 | nn.init.normal_(m.weight, 0, 0.1) 105 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 106 | nn.init.constant_(m.weight, 1) 107 | nn.init.constant_(m.bias, 0) 108 | 109 | if zero_init_residual: 110 | for m in self.modules(): 111 | if isinstance(m, IBasicBlock): 112 | nn.init.constant_(m.bn2.weight, 0) 113 | 114 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 115 | downsample = None 116 | previous_dilation = self.dilation 117 | if dilate: 118 | self.dilation *= stride 119 | stride = 1 120 | if stride != 1 or self.inplanes != planes * block.expansion: 121 | downsample = nn.Sequential( 122 | conv1x1(self.inplanes, planes * block.expansion, stride), 123 | nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ), 124 | ) 125 | layers = [] 126 | layers.append( 127 | block(self.inplanes, planes, stride, downsample, self.groups, 128 | self.base_width, previous_dilation)) 129 | self.inplanes = planes * block.expansion 130 | for _ in range(1, blocks): 131 | layers.append( 132 | block(self.inplanes, 133 | planes, 134 | groups=self.groups, 135 | base_width=self.base_width, 136 | dilation=self.dilation)) 137 | 138 | return nn.Sequential(*layers) 139 | 140 | def forward(self, x): 141 | # with torch.cuda.amp.autocast(self.fp16): 142 | x = self.conv1(x) 143 | x = self.bn1(x) 144 | x = self.prelu(x) 145 | x = self.layer1(x) 146 | x = self.layer2(x) 147 | x = self.layer3(x) 148 | x = self.layer4(x) 149 | layer4_res = x 150 | x = self.bn2(x) 151 | x = torch.flatten(x, 1) 152 | x = self.dropout(x) 153 | x = self.fc(x.float() if self.fp16 else x) 154 | y = self.features(x) 155 | return y,layer4_res 156 | 157 | 158 | def _iresnet(arch, block, layers, pretrained, progress, **kwargs): 159 | model = IResNet(block, layers, **kwargs) 160 | if pretrained: 161 | raise 
ValueError() 162 | return model 163 | 164 | 165 | def iresnet18(pretrained=False, progress=True, **kwargs): 166 | return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained, 167 | progress, **kwargs) 168 | 169 | 170 | def iresnet34(pretrained=False, progress=True, **kwargs): 171 | return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained, 172 | progress, **kwargs) 173 | 174 | 175 | def iresnet50(pretrained=False, progress=True, **kwargs): 176 | return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained, 177 | progress, **kwargs) 178 | 179 | 180 | def iresnet100(pretrained=False, progress=True, **kwargs): 181 | return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained, 182 | progress, **kwargs) 183 | 184 | 185 | def iresnet200(pretrained=False, progress=True, **kwargs): 186 | return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained, 187 | progress, **kwargs) 188 | 189 | 190 | def getarcface(pretrained=None): 191 | model = iresnet100() 192 | for param in model.parameters(): 193 | param.requires_grad=False 194 | 195 | if pretrained is not None and os.path.exists(pretrained): 196 | info = model.load_state_dict(torch.load(pretrained, map_location=lambda storage, loc: storage)) 197 | # print('insightface_glint360k', info) 198 | return model.eval() 199 | 200 | 201 | if __name__=='__main__': 202 | ckpt = 'pretrained/insightface_glint360k.pth' 203 | arcface = getarcface(ckpt) -------------------------------------------------------------------------------- /SVFR_node.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | import numpy as np 4 | import os 5 | import torch 6 | from .infer import main_loader,main_sampler 7 | from .node_utils import nomarl_upscale,tensor_upscale 8 | 9 | import folder_paths 10 | 11 | MAX_SEED = np.iinfo(np.int32).max 12 | device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" 13 | current_path = os.path.dirname(os.path.abspath(__file__)) 14 | 15 | weigths_SVFR_current_path = os.path.join(folder_paths.models_dir, "SVFR") 16 | if not os.path.exists(weigths_SVFR_current_path): 17 | os.makedirs(weigths_SVFR_current_path) 18 | 19 | folder_paths.add_model_folder_path("SVFR", weigths_SVFR_current_path) 20 | 21 | 22 | class SVFR_LoadModel: 23 | def __init__(self): 24 | pass 25 | 26 | @classmethod 27 | def INPUT_TYPES(cls): 28 | yolo_ckpt_list = [i for i in folder_paths.get_filename_list("SVFR") if 29 | "yolo" in i] 30 | insightface_ckpt_list = [i for i in folder_paths.get_filename_list("SVFR") if 31 | "insightface" in i] 32 | unet_ckpt_list = [i for i in folder_paths.get_filename_list("SVFR") if 33 | "unet" in i] 34 | id_ckpt_list = [i for i in folder_paths.get_filename_list("SVFR") if 35 | "id" in i] 36 | return { 37 | "required": { 38 | "checkpoints": (folder_paths.get_filename_list("checkpoints"),), 39 | "vae": (folder_paths.get_filename_list("vae"),), 40 | "unet": (["none"] + unet_ckpt_list,), 41 | "yolo_ckpt": (["none"] + yolo_ckpt_list,), 42 | "id_ckpt": (["none"] + id_ckpt_list,), 43 | "insightface": (["none"] + insightface_ckpt_list,), 44 | "dtype": (["fp16","bf16","fp32"],), 45 | } 46 | } 47 | 48 | RETURN_TYPES = ("MODEL_SVFR",) 49 | RETURN_NAMES = ("model",) 50 | FUNCTION = "main_loader" 51 | CATEGORY = "SVFR" 52 | 53 | def main_loader(self,checkpoints,vae, unet, yolo_ckpt, id_ckpt, insightface,dtype): 54 | 55 | I2V_repo=os.path.join(current_path, "svd_repo") 56 | if dtype == "fp16": 
57 | weight_dtype = torch.float16 58 | elif dtype == "fp32": 59 | weight_dtype = torch.float32 60 | else: 61 | weight_dtype = torch.bfloat16 62 | 63 | if unet == "none" or yolo_ckpt == "none" or id_ckpt == "none" or insightface == "none": 64 | raise "need choice ckpt in menu" 65 | else: 66 | unet_path=folder_paths.get_full_path("SVFR",unet) 67 | det_path=folder_paths.get_full_path("SVFR",yolo_ckpt) 68 | id_path=folder_paths.get_full_path("SVFR",id_ckpt) 69 | face_path=folder_paths.get_full_path("SVFR",insightface) 70 | UNET= folder_paths.get_full_path("checkpoints",checkpoints) 71 | vae=folder_paths.get_full_path("vae",vae) 72 | pipe,id_linear,net_arcface,align_instance=main_loader(weight_dtype,I2V_repo,UNET,vae,unet_path,det_path,id_path,face_path,device,dtype) 73 | print("****** Load model is done.******") 74 | return ( 75 | {"pipe": pipe, "id_linear": id_linear, "net_arcface": net_arcface, "align_instance": align_instance, "weight_dtype": weight_dtype},) 76 | 77 | 78 | class SVFR_Sampler: 79 | @classmethod 80 | def INPUT_TYPES(s): 81 | 82 | return { 83 | "required": { 84 | "image": ("IMAGE",), # [B,H,W,C], C=3,B>1 85 | "model": ("MODEL_SVFR",), 86 | "seed": ("INT", {"default": 77, "min": 0, "max": MAX_SEED}), 87 | "width": ("INT", {"default": 512, "min": 128, "max": 2048, "step": 64, "display": "number"}), 88 | "height": ("INT", {"default": 512, "min": 128, "max": 2048, "step": 64, "display": "number"}), 89 | "decode_chunk_size": ("INT", {"default":16, "min": 4, "max": 128, "step": 4,}), 90 | "n_sample_frames": ("INT", {"default": 16, "min": 8, "max": 100, "step": 1,}), 91 | "steps": ("INT", {"default": 50, "min": 1, "max": 4096, "step": 1, "display": "number"}), 92 | "noise_aug_strength": ("FLOAT", {"default": 0.00, "min": 0.00, "max": 1.00, "step": 0.01, "round": 0.001}), 93 | "overlap": ("INT", {"default":3, "min": 1, "max": 64}), 94 | "min_appearance_guidance_scale": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 10.0, "step": 0.1, "round": 0.01}), 95 | "max_appearance_guidance_scale": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 10.0, "step": 0.1, "round": 0.01}), 96 | "i2i_noise_strength": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 1.0, "step": 0.1, "round": 0.01}), 97 | "infer_mode":(["bfr","colorization","inpainting","bfr_color","bfr_color_inpaint"],), 98 | "save_video": ("BOOLEAN", {"default": False},), 99 | "crop_face_region": ("BOOLEAN", {"default": True},), 100 | }, 101 | "optional": {"mask": ("MASK",), 102 | }, 103 | } 104 | 105 | RETURN_TYPES = ("IMAGE",) 106 | RETURN_NAMES = ("images",) 107 | FUNCTION = "sampler_main" 108 | CATEGORY = "SVFR" 109 | 110 | def sampler_main(self, image, model, seed, width, height,decode_chunk_size,n_sample_frames 111 | , steps,noise_aug_strength, overlap,min_appearance_guidance_scale,max_appearance_guidance_scale, i2i_noise_strength,infer_mode,save_video,crop_face_region,**kwargs): 112 | 113 | pipe = model.get("pipe") 114 | id_linear = model.get("id_linear") 115 | net_arcface = model.get("net_arcface") 116 | align_instance = model.get("align_instance") 117 | weight_dtype = model.get("weight_dtype") 118 | mask=kwargs.get("mask") 119 | 120 | if isinstance(mask,torch.Tensor): #缩放至图片尺寸 121 | if mask.shape[-1]==64 and mask.shape[-2]==64: 122 | raise "input mask is not a useful,looks like a default comfyUI mask" 123 | 124 | if len(mask.shape) == 3: # 1,h,w 125 | mask_array = mask.squeeze().mul(255).clamp(0, 255).byte().numpy() 126 | elif len(mask.shape)==2 and mask.shape[0]!=1: # h,w 127 | mask_array=mask.mul(255).clamp(0, 255).byte().numpy() 
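            # both branches above reduce the incoming ComfyUI MASK (values in 0..1) to a uint8
            # H x W array; it is binarized to 0/255 right after this if/else, and white pixels
            # mark the region repainted by the inpainting tasks (see main_sampler).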
128 | else: 129 | raise "check input mask's shape" 130 | mask_array=np.where(mask_array > 0, 255, 0) 131 | else: 132 | mask_array=None 133 | video_len, _, _, _ = image.size() 134 | if video_len < 8: 135 | raise "input video has not much frames below 8 frame,change your input video!" 136 | else: 137 | tensor_list = list(torch.chunk(image, chunks=video_len)) 138 | input_frames_pil = [nomarl_upscale(i,width, height) for i in tensor_list] # tensor to pil 139 | 140 | if infer_mode=="bfr": 141 | task_ids=[0] 142 | elif infer_mode=="colorization": 143 | task_ids = [1] 144 | elif infer_mode=="inpainting": 145 | task_ids = [2] 146 | elif infer_mode=="bfr_color": 147 | task_ids = [0,1] 148 | else: 149 | task_ids = [0, 1,2 ] 150 | 151 | if not isinstance(mask_array,np.ndarray) and 2 in task_ids: 152 | raise "If use inpainting need link a mask or a batch mask in the front." 153 | 154 | print("******** start infer *********") 155 | 156 | images=main_sampler(pipe, align_instance, net_arcface, id_linear, folder_paths.get_output_directory(), weight_dtype, 157 | seed,input_frames_pil,task_ids,mask_array,save_video,decode_chunk_size,noise_aug_strength, 158 | min_appearance_guidance_scale,max_appearance_guidance_scale, 159 | overlap,i2i_noise_strength,steps,n_sample_frames,device,crop_face_region) 160 | 161 | #model.to("cpu")#显存不会自动释放,手动迁移,不然很容易OOM 162 | torch.cuda.empty_cache() 163 | return (images,) 164 | 165 | 166 | class SVFR_img2mask: 167 | @classmethod 168 | def INPUT_TYPES(s): 169 | return { 170 | "required": { 171 | "image": ("IMAGE",), # [B,H,W,C], C=3,B=1 172 | "threshold": ("INT", {"default": 0, "min": 0, "max": 254, "step": 0, "display": "number"}), 173 | "center_crop": ("BOOLEAN", {"default": False},), 174 | "width": ("INT", {"default": 512, "min": 128, "max": 2048, "step": 64, "display": "number"}), 175 | "height": ("INT", {"default": 512, "min": 128, "max": 2048, "step": 64, "display": "number"}), 176 | } 177 | } 178 | RETURN_TYPES = ("MASK",) 179 | RETURN_NAMES = ("mask",) 180 | FUNCTION = "main" 181 | CATEGORY = "SVFR" 182 | 183 | def main(self, image,threshold,center_crop,width,height): 184 | if center_crop: 185 | image=tensor_upscale(image, width, height) 186 | np_img=image.squeeze().mul(255).clamp(0, 255).byte().numpy() 187 | np_img=np.mean(np_img, axis=2).astype(np.uint8) 188 | 189 | black_threshold = 50 # 黑色阈值,小于这个值的像素被认为是黑色 190 | white_threshold = 200 # 白色阈值,大于这个值的像素被认为是白色 191 | black_pixels = np.sum(np_img < black_threshold) # 计算黑色像素的数量 192 | white_pixels = np.sum(np_img > white_threshold) # 计算白色像素的数量 193 | if black_pixels>white_pixels: #黑多白少,按白色为遮罩 194 | out = np.where(np_img > threshold, 255, 0).astype(np.float32) / 255.0 195 | else: 196 | out = np.where(np_img > threshold, 0, 255).astype(np.float32) / 255.0 197 | return (torch.from_numpy(out).unsqueeze(0),) 198 | 199 | 200 | NODE_CLASS_MAPPINGS = { 201 | "SVFR_LoadModel": SVFR_LoadModel, 202 | "SVFR_Sampler": SVFR_Sampler, 203 | "SVFR_img2mask":SVFR_img2mask, 204 | 205 | } 206 | NODE_DISPLAY_NAME_MAPPINGS = { 207 | "SVFR_LoadModel": "SVFR_LoadModel", 208 | "SVFR_Sampler": "SVFR_Sampler", 209 | "SVFR_img2mask":"SVFR_img2mask" 210 | } 211 | -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | import warnings 4 | import os 5 | import numpy as np 6 | import safetensors.torch 7 | import torch 8 | import torch.utils.checkpoint 9 | from PIL import Image 10 | from 
diffusers import AutoencoderKLTemporalDecoder 11 | from diffusers.schedulers import EulerDiscreteScheduler 12 | import torchvision.transforms as transforms 13 | import torch.nn.functional as F 14 | from torchvision.utils import save_image 15 | import random 16 | import cv2 17 | import safetensors 18 | 19 | import gc 20 | # pipeline 21 | from .src.pipelines.pipeline import LQ2VideoLongSVDPipeline 22 | from .src.utils.util import save_videos_grid, seed_everything 23 | from .src.models.id_proj import IDProjConvModel 24 | from .src.models import model_insightface_360k 25 | from .src.dataset.face_align.align import AlignImage 26 | from .src.models.svfr_adapter.unet_3d_svd_condition_ip import UNet3DConditionSVDModel 27 | from .src.dataset.dataset import get_affine_transform, mean_face_lm5p_256,get_union_bbox, process_bbox, crop_resize_img 28 | from .node_utils import tensor2pil 29 | warnings.filterwarnings("ignore") 30 | 31 | 32 | def main_loader(weight_dtype, repo,UNET,VAE, unet_path, det_path, id_path, face_path,device,dtype): 33 | 34 | 35 | val_noise_scheduler = EulerDiscreteScheduler.from_pretrained(repo, subfolder="scheduler") 36 | 37 | input_unet_dic=safetensors.torch.load_file(UNET) 38 | unet_config=UNet3DConditionSVDModel.load_config(os.path.join(repo, "unet")) 39 | unet=UNet3DConditionSVDModel.from_config(unet_config).to(weight_dtype) 40 | unet.load_state_dict(input_unet_dic, strict=False) 41 | 42 | 43 | align_instance = AlignImage(device, det_path=det_path) 44 | 45 | input_vae_dic=safetensors.torch.load_file(VAE) 46 | vae_config=AutoencoderKLTemporalDecoder.load_config(os.path.join(repo, "vae")) 47 | vae=AutoencoderKLTemporalDecoder.from_config(vae_config).to(weight_dtype) 48 | vae.load_state_dict(input_vae_dic, strict=False) 49 | 50 | import torch.nn as nn 51 | class InflatedConv3d(nn.Conv2d): 52 | def forward(self, x): 53 | x = super().forward(x) 54 | return x 55 | 56 | # Add ref channel 57 | old_weights = unet.conv_in.weight 58 | old_bias = unet.conv_in.bias 59 | new_conv1 = InflatedConv3d( 60 | 12, 61 | old_weights.shape[0], 62 | kernel_size=unet.conv_in.kernel_size, 63 | stride=unet.conv_in.stride, 64 | padding=unet.conv_in.padding, 65 | bias=True if old_bias is not None else False, 66 | ) 67 | param = torch.zeros((320, 4, 3, 3), requires_grad=True) 68 | new_conv1.weight = torch.nn.Parameter(torch.cat((old_weights, param), dim=1)) 69 | if old_bias is not None: 70 | new_conv1.bias = old_bias 71 | unet.conv_in = new_conv1 72 | unet.config["in_channels"] = 12 73 | unet.config.in_channels = 12 74 | 75 | id_linear = IDProjConvModel(in_channels=512, out_channels=1024).to(device=device) 76 | 77 | pre_unet_dict=torch.load(unet_path, map_location="cpu") 78 | pre_linear_dict=torch.load(id_path, map_location="cpu") 79 | 80 | # load pretrained weights 81 | unet.load_state_dict(pre_unet_dict,strict=True,) 82 | 83 | id_linear.load_state_dict(pre_linear_dict,strict=True,) 84 | 85 | net_arcface = model_insightface_360k.getarcface(face_path).eval().to(device=device) 86 | 87 | #image_encoder.to(weight_dtype) 88 | vae.to(weight_dtype) 89 | unet.to(weight_dtype) 90 | id_linear.to(weight_dtype) 91 | net_arcface.requires_grad_(False).to(weight_dtype) 92 | del input_unet_dic, input_vae_dic, pre_unet_dict, pre_linear_dict 93 | gc.collect() 94 | torch.cuda.empty_cache() 95 | pipe = LQ2VideoLongSVDPipeline( 96 | unet=unet, 97 | #image_encoder=image_encoder, 98 | vae=vae, 99 | scheduler=val_noise_scheduler, 100 | feature_extractor=None, 101 | ) 102 | pipe = pipe.to(device, dtype=unet.dtype) 103 | 104 | 
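    # at this point the SVD pipeline has been assembled from the patched UNet (conv_in widened
    # to 12 input channels), the temporal-decoder VAE and the Euler scheduler; it is returned
    # together with the ID projection layer, the ArcFace net and the YOLO face aligner, which
    # SVFR_LoadModel packs into a single dict for the sampler node.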
return pipe, id_linear, net_arcface, align_instance 105 | 106 | 107 | def main_sampler(pipe,align_instance, net_arcface, id_linear, save_dir, weight_dtype, seed, input_frames_pil, task_ids, 108 | mask_array, 109 | save_video, decode_chunk_size, noise_aug_strength, min_appearance_guidance_scale, 110 | max_appearance_guidance_scale, 111 | overlap, i2i_noise_strength, steps, n_sample_frames,device,crop_face_region): 112 | 113 | 114 | to_tensor = transforms.Compose([ 115 | transforms.ToTensor(), 116 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 117 | ]) 118 | 119 | seed_everything(seed) 120 | 121 | if 2 in task_ids and isinstance(mask_array, np.ndarray): 122 | white_positions = mask_array == 255 123 | 124 | print('task_ids:', task_ids) 125 | task_prompt = [0, 0, 0] 126 | for i in range(3): 127 | if i in task_ids: 128 | task_prompt[i] = 1 129 | print("task_prompt:", task_prompt) 130 | 131 | files_prefix = ''.join(random.choice("0123456789") for _ in range(5)) 132 | video_name = f"infer_{files_prefix}" 133 | # print(video_name) 134 | 135 | if os.path.exists(os.path.join(save_dir, "result_frames", video_name[:-4])): 136 | print(os.path.join(save_dir, "result_frames", video_name[:-4])) 137 | # continue 138 | 139 | #import decord 140 | #cap = decord.VideoReader(input_frames_pil, fault_tol=1) 141 | 142 | total_frames = len(input_frames_pil) 143 | T = total_frames # 144 | print("total_frames:", total_frames) 145 | step = 1 146 | drive_idx_start = 0 147 | drive_idx_list = list(range(drive_idx_start, drive_idx_start + T * step, step)) 148 | assert len(drive_idx_list) == T 149 | 150 | if crop_face_region: 151 | # Crop faces from the video for further processing 152 | bbox_list = [] 153 | frame_interval = 5 154 | for frame_count, drive_idx in enumerate(drive_idx_list): 155 | if frame_count % frame_interval != 0: 156 | continue 157 | frame = np.array(input_frames_pil[drive_idx]) 158 | _, _, bboxes_list = align_instance(frame[:,:,[2,1,0]], maxface=True) 159 | if bboxes_list==[]: 160 | continue 161 | x1, y1, ww, hh = bboxes_list[0] 162 | x2, y2 = x1 + ww, y1 + hh 163 | bbox = [x1, y1, x2, y2] 164 | bbox_list.append(bbox) 165 | bbox = get_union_bbox(bbox_list) 166 | bbox_s = process_bbox(bbox, expand_radio=0.4, height=frame.shape[0], width=frame.shape[1]) 167 | 168 | 169 | imSameIDs = [] 170 | vid_gt = [] 171 | width,height=input_frames_pil[0].size 172 | for i, drive_idx in enumerate(drive_idx_list): 173 | imSameID = input_frames_pil[drive_idx] 174 | #imSameID = Image.fromarray(frame) 175 | if crop_face_region: 176 | imSameID = crop_resize_img(imSameID, bbox_s) 177 | imSameID = imSameID.resize((width,height)) 178 | if 1 in task_ids: 179 | imSameID = imSameID.convert("L") # Convert to grayscale 180 | imSameID = imSameID.convert("RGB") 181 | image_array = np.array(imSameID) 182 | if 2 in task_ids and isinstance(mask_array, np.ndarray): 183 | image_array[white_positions] = [255, 255, 255] # mask for inpainting task 184 | vid_gt.append(np.float32(image_array / 255.)) 185 | imSameIDs.append(imSameID) 186 | 187 | vid_lq = [(torch.from_numpy(frame).permute(2, 0, 1) - 0.5) / 0.5 for frame in vid_gt] # torch.Size([3, 512, 512]) 188 | 189 | val_data = dict( 190 | pixel_values_vid_lq=torch.stack(vid_lq, dim=0), 191 | # pixel_values_ref_img=self.to_tensor(target_image), 192 | # pixel_values_ref_concat_img=self.to_tensor(imSrc2), 193 | task_ids=task_ids, 194 | task_id_input=torch.tensor(task_prompt), 195 | total_frames=total_frames, 196 | ) 197 | 198 | window_overlap = 0 199 | inter_frame_list = 
get_overlap_slide_window_indices(val_data["total_frames"], n_sample_frames, window_overlap) 200 | 201 | lq_frames = val_data["pixel_values_vid_lq"] 202 | task_ids = val_data["task_ids"] 203 | task_id_input = val_data["task_id_input"] 204 | height, width = val_data["pixel_values_vid_lq"].shape[-2:] 205 | 206 | print("Generating the first clip...") 207 | output = pipe( 208 | lq_frames[inter_frame_list[0]].to(device).to(weight_dtype), # lq 209 | None, # ref concat 210 | torch.zeros((1, len(inter_frame_list[0]), 49, 1024)).to(device).to(weight_dtype), # encoder_hidden_states 211 | task_id_input.to(device).to(weight_dtype), 212 | height=height, 213 | width=width, 214 | num_frames=len(inter_frame_list[0]), 215 | decode_chunk_size=decode_chunk_size, 216 | noise_aug_strength=noise_aug_strength, 217 | min_guidance_scale=min_appearance_guidance_scale, 218 | max_guidance_scale=max_appearance_guidance_scale, 219 | overlap=overlap, 220 | frames_per_batch=len(inter_frame_list[0]), 221 | num_inference_steps=steps, 222 | i2i_noise_strength=i2i_noise_strength, 223 | ) 224 | video = output.frames 225 | ref_img_tensor = video[0][:, -1] 226 | 227 | ref_img = (video[0][:, -1] * 0.5 + 0.5).clamp(0, 1) * 255. 228 | ref_img = ref_img.permute(1, 2, 0).cpu().numpy().astype(np.uint8) 229 | pts5 = align_instance(ref_img[:, :, [2, 1, 0]], maxface=True)[0][0] 230 | 231 | warp_mat = get_affine_transform(pts5, mean_face_lm5p_256 * height / 256) 232 | ref_img = cv2.warpAffine(np.array(Image.fromarray(ref_img)), warp_mat, (height, width), flags=cv2.INTER_CUBIC) 233 | ref_img = to_tensor(ref_img).to(device).to(weight_dtype) 234 | 235 | save_image(ref_img * 0.5 + 0.5, f"{save_dir}/ref_img_align.png") 236 | 237 | ref_img = F.interpolate(ref_img.unsqueeze(0)[:, :, 0:224, 16:240], size=[112, 112], mode='bilinear') 238 | _, id_feature_conv = net_arcface(ref_img) 239 | id_embedding = id_linear(id_feature_conv) 240 | 241 | print('Generating all video clips...') 242 | video = pipe( 243 | lq_frames.to(device).to(weight_dtype), # lq 244 | ref_img_tensor.to(device).to(weight_dtype), 245 | id_embedding.unsqueeze(1).repeat(1, len(lq_frames), 1, 1).to(device).to(weight_dtype), # encoder_hidden_states 246 | task_id_input.to(device).to(weight_dtype), 247 | height=height, 248 | width=width, 249 | num_frames=val_data["total_frames"], # frame_num, 250 | decode_chunk_size=decode_chunk_size, 251 | noise_aug_strength=noise_aug_strength, 252 | min_guidance_scale=min_appearance_guidance_scale, 253 | max_guidance_scale=max_appearance_guidance_scale, 254 | overlap=overlap, 255 | frames_per_batch=n_sample_frames, 256 | num_inference_steps=steps, 257 | i2i_noise_strength=i2i_noise_strength, 258 | ).frames 259 | 260 | 261 | video = (video * 0.5 + 0.5).clamp(0, 1) 262 | video = torch.cat([video.to(device=device)], dim=0).cpu() # torch.Size([1, 3, 160, 512, 512]) 263 | 264 | 265 | if save_video: 266 | save_videos_grid(video, f"{save_dir}/{video_name[:-4]}_{seed}.mp4", n_rows=1, fps=25) 267 | 268 | # if restore_frames: 269 | # video = video.squeeze(0) 270 | # os.makedirs(os.path.join(save_dir, "result_frames", f"{video_name[:-4]}_{seed}"),exist_ok=True) 271 | # print(os.path.join(save_dir, "result_frames", video_name[:-4])) 272 | # for i in range(video.shape[1]): 273 | # save_frames_path = os.path.join(f"{save_dir}/result_frames", f"{video_name[:-4]}_{seed}", f'{i:08d}.png') 274 | # save_image(video[:,i], save_frames_path) 275 | 276 | return video.squeeze(0).permute(1, 2, 3, 0) # bcthw to B,H,W,C 277 | 278 | 279 | def 
get_overlap_slide_window_indices(video_length, window_size, window_overlap): 280 | inter_frame_list = [] 281 | for j in range(0, video_length, window_size - window_overlap): 282 | inter_frame_list.append([e % video_length for e in range(j, min(j + window_size, video_length))]) 283 | 284 | return inter_frame_list 285 | 286 | -------------------------------------------------------------------------------- /src/dataset/face_align/yoloface.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import os 3 | import cv2 4 | import numpy as np 5 | import torch 6 | import torchvision 7 | 8 | 9 | def xyxy2xywh(x): 10 | # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right 11 | y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) 12 | y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center 13 | y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center 14 | y[:, 2] = x[:, 2] - x[:, 0] # width 15 | y[:, 3] = x[:, 3] - x[:, 1] # height 16 | return y 17 | 18 | 19 | def xywh2xyxy(x): 20 | # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right 21 | y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) 22 | y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x 23 | y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y 24 | y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x 25 | y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y 26 | return y 27 | 28 | 29 | def box_iou(box1, box2): 30 | # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py 31 | """ 32 | Return intersection-over-union (Jaccard index) of boxes. 33 | Both sets of boxes are expected to be in (x1, y1, x2, y2) format. 34 | Arguments: 35 | box1 (Tensor[N, 4]) 36 | box2 (Tensor[M, 4]) 37 | Returns: 38 | iou (Tensor[N, M]): the NxM matrix containing the pairwise 39 | IoU values for every element in boxes1 and boxes2 40 | """ 41 | 42 | def box_area(box): 43 | # box = 4xn 44 | return (box[2] - box[0]) * (box[3] - box[1]) 45 | 46 | area1 = box_area(box1.T) 47 | area2 = box_area(box2.T) 48 | 49 | # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2) 50 | inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - 51 | torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2) 52 | # iou = inter / (area1 + area2 - inter) 53 | return inter / (area1[:, None] + area2 - inter) 54 | 55 | 56 | def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None): 57 | # Rescale coords (xyxy) from img1_shape to img0_shape 58 | if ratio_pad is None: # calculate from img0_shape 59 | gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new 60 | pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding 61 | else: 62 | gain = ratio_pad[0][0] 63 | pad = ratio_pad[1] 64 | 65 | coords[:, [0, 2]] -= pad[0] # x padding 66 | coords[:, [1, 3]] -= pad[1] # y padding 67 | coords[:, :4] /= gain 68 | clip_coords(coords, img0_shape) 69 | return coords 70 | 71 | 72 | def clip_coords(boxes, img_shape): 73 | # Clip bounding xyxy bounding boxes to image shape (height, width) 74 | boxes[:, 0].clamp_(0, img_shape[1]) # x1 75 | boxes[:, 1].clamp_(0, img_shape[0]) # y1 76 | boxes[:, 2].clamp_(0, img_shape[1]) # x2 77 | boxes[:, 3].clamp_(0, img_shape[0]) # y2 78 | 79 | 80 | def scale_coords_landmarks(img1_shape, coords, img0_shape, ratio_pad=None): 81 | # Rescale coords (xyxy) from img1_shape to img0_shape 82 | if ratio_pad is None: # calculate from img0_shape 83 | gain = 
min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new 84 | pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding 85 | else: 86 | gain = ratio_pad[0][0] 87 | pad = ratio_pad[1] 88 | 89 | coords[:, [0, 2, 4, 6, 8]] -= pad[0] # x padding 90 | coords[:, [1, 3, 5, 7, 9]] -= pad[1] # y padding 91 | coords[:, :10] /= gain 92 | #clip_coords(coords, img0_shape) 93 | coords[:, 0].clamp_(0, img0_shape[1]) # x1 94 | coords[:, 1].clamp_(0, img0_shape[0]) # y1 95 | coords[:, 2].clamp_(0, img0_shape[1]) # x2 96 | coords[:, 3].clamp_(0, img0_shape[0]) # y2 97 | coords[:, 4].clamp_(0, img0_shape[1]) # x3 98 | coords[:, 5].clamp_(0, img0_shape[0]) # y3 99 | coords[:, 6].clamp_(0, img0_shape[1]) # x4 100 | coords[:, 7].clamp_(0, img0_shape[0]) # y4 101 | coords[:, 8].clamp_(0, img0_shape[1]) # x5 102 | coords[:, 9].clamp_(0, img0_shape[0]) # y5 103 | return coords 104 | 105 | 106 | def show_results(img, xywh, conf, landmarks, class_num): 107 | h,w,c = img.shape 108 | tl = 1 or round(0.002 * (h + w) / 2) + 1 # line/font thickness 109 | x1 = int(xywh[0] * w - 0.5 * xywh[2] * w) 110 | y1 = int(xywh[1] * h - 0.5 * xywh[3] * h) 111 | x2 = int(xywh[0] * w + 0.5 * xywh[2] * w) 112 | y2 = int(xywh[1] * h + 0.5 * xywh[3] * h) 113 | cv2.rectangle(img, (x1,y1), (x2, y2), (0,255,0), thickness=tl, lineType=cv2.LINE_AA) 114 | 115 | clors = [(255,0,0),(0,255,0),(0,0,255),(255,255,0),(0,255,255)] 116 | 117 | for i in range(5): 118 | point_x = int(landmarks[2 * i] * w) 119 | point_y = int(landmarks[2 * i + 1] * h) 120 | cv2.circle(img, (point_x, point_y), tl+1, clors[i], -1) 121 | 122 | tf = max(tl - 1, 1) # font thickness 123 | label = str(conf)[:5] 124 | cv2.putText(img, label, (x1, y1 - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA) 125 | return img 126 | 127 | 128 | def make_divisible(x, divisor): 129 | # Returns x evenly divisible by divisor 130 | return (x // divisor) * divisor 131 | 132 | 133 | def non_max_suppression_face(prediction, conf_thres=0.5, iou_thres=0.45, classes=None, agnostic=False, labels=()): 134 | """Performs Non-Maximum Suppression (NMS) on inference results 135 | Returns: 136 | detections with shape: nx6 (x1, y1, x2, y2, conf, cls) 137 | """ 138 | 139 | nc = prediction.shape[2] - 15 # number of classes 140 | xc = prediction[..., 4] > conf_thres # candidates 141 | 142 | # Settings 143 | min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height 144 | # time_limit = 10.0 # seconds to quit after 145 | redundant = True # require redundant detections 146 | multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img) 147 | merge = False # use merge-NMS 148 | 149 | # t = time.time() 150 | output = [torch.zeros((0, 16), device=prediction.device)] * prediction.shape[0] 151 | for xi, x in enumerate(prediction): # image index, image inference 152 | # Apply constraints 153 | # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height 154 | x = x[xc[xi]] # confidence 155 | 156 | # Cat apriori labels if autolabelling 157 | if labels and len(labels[xi]): 158 | l = labels[xi] 159 | v = torch.zeros((len(l), nc + 15), device=x.device) 160 | v[:, :4] = l[:, 1:5] # box 161 | v[:, 4] = 1.0 # conf 162 | v[range(len(l)), l[:, 0].long() + 15] = 1.0 # cls 163 | x = torch.cat((x, v), 0) 164 | 165 | # If none remain process next image 166 | if not x.shape[0]: 167 | continue 168 | 169 | # Compute conf 170 | x[:, 15:] *= x[:, 4:5] # conf = obj_conf * cls_conf 171 | 172 | # Box 
(center x, center y, width, height) to (x1, y1, x2, y2) 173 | box = xywh2xyxy(x[:, :4]) 174 | 175 | # Detections matrix nx6 (xyxy, conf, landmarks, cls) 176 | if multi_label: 177 | i, j = (x[:, 15:] > conf_thres).nonzero(as_tuple=False).T 178 | x = torch.cat((box[i], x[i, j + 15, None], x[i, 5:15] ,j[:, None].float()), 1) 179 | else: # best class only 180 | conf, j = x[:, 15:].max(1, keepdim=True) 181 | x = torch.cat((box, conf, x[:, 5:15], j.float()), 1)[conf.view(-1) > conf_thres] 182 | 183 | # Filter by class 184 | if classes is not None: 185 | x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)] 186 | 187 | # If none remain process next image 188 | n = x.shape[0] # number of boxes 189 | if not n: 190 | continue 191 | 192 | # Batched NMS 193 | c = x[:, 15:16] * (0 if agnostic else max_wh) # classes 194 | boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores 195 | i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS 196 | #if i.shape[0] > max_det: # limit detections 197 | # i = i[:max_det] 198 | if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean) 199 | # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4) 200 | iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix 201 | weights = iou * scores[None] # box weights 202 | x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes 203 | if redundant: 204 | i = i[iou.sum(1) > 1] # require redundancy 205 | 206 | output[xi] = x[i] 207 | # if (time.time() - t) > time_limit: 208 | # break # time limit exceeded 209 | 210 | return output 211 | 212 | 213 | class YoloFace(): 214 | def __init__(self, pt_path='checkpoints/yolov5m-face.pt', confThreshold=0.5, nmsThreshold=0.45, device='cuda'): 215 | assert os.path.exists(pt_path) 216 | 217 | self.inpSize = 416 218 | self.conf_thres = confThreshold 219 | self.iou_thres = nmsThreshold 220 | self.test_device = torch.device(device if torch.cuda.is_available() else "cpu") 221 | self.model = torch.jit.load(pt_path).to(self.test_device) 222 | self.last_w = 416 223 | self.last_h = 416 224 | self.grids = None 225 | 226 | @torch.no_grad() 227 | def detect(self, srcimg): 228 | # t0=time.time() 229 | 230 | h0, w0 = srcimg.shape[:2] # orig hw 231 | r = self.inpSize / min(h0, w0) # resize image to img_size 232 | h1 = int(h0*r+31)//32*32 233 | w1 = int(w0*r+31)//32*32 234 | 235 | img = cv2.resize(srcimg, (w1,h1), interpolation=cv2.INTER_LINEAR) 236 | 237 | # Convert 238 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # BGR to RGB 239 | 240 | # Run inference 241 | img = torch.from_numpy(img).to(self.test_device).permute(2,0,1) 242 | img = img.float()/255 # uint8 to fp16/32 0-1 243 | if img.ndimension() == 3: 244 | img = img.unsqueeze(0) 245 | 246 | # Inference 247 | if h1 != self.last_h or w1 != self.last_w or self.grids is None: 248 | grids = [] 249 | for scale in [8,16,32]: 250 | ny = h1//scale 251 | nx = w1//scale 252 | yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)]) 253 | grid = torch.stack((xv, yv), 2).view((1,1,ny, nx, 2)).float() 254 | grids.append(grid.to(self.test_device)) 255 | self.grids = grids 256 | self.last_w = w1 257 | self.last_h = h1 258 | 259 | pred = self.model(img, self.grids).cpu() 260 | 261 | # Apply NMS 262 | det = non_max_suppression_face(pred, self.conf_thres, self.iou_thres)[0] 263 | # Process detections 264 | # det = pred[0] 265 | bboxes = np.zeros((det.shape[0], 4)) 266 | kpss = np.zeros((det.shape[0], 5, 2)) 267 | scores = np.zeros((det.shape[0])) 268 | # gn = torch.tensor([w0, h0, 
w0, h0]).to(pred) # normalization gain whwh 269 | # gn_lks = torch.tensor([w0, h0, w0, h0, w0, h0, w0, h0, w0, h0]).to(pred) # normalization gain landmarks 270 | det = det.cpu().numpy() 271 | 272 | for j in range(det.shape[0]): 273 | # xywh = (xyxy2xywh(det[j, :4].view(1, 4)) / gn).view(4).cpu().numpy() 274 | bboxes[j, 0] = det[j, 0] * w0/w1 275 | bboxes[j, 1] = det[j, 1] * h0/h1 276 | bboxes[j, 2] = det[j, 2] * w0/w1 - bboxes[j, 0] 277 | bboxes[j, 3] = det[j, 3] * h0/h1 - bboxes[j, 1] 278 | scores[j] = det[j, 4] 279 | # landmarks = (det[j, 5:15].view(1, 10) / gn_lks).view(5,2).cpu().numpy() 280 | kpss[j, :, :] = det[j, 5:15].reshape(5, 2) * np.array([[w0/w1,h0/h1]]) 281 | # class_num = det[j, 15].cpu().numpy() 282 | # orgimg = show_results(orgimg, xywh, conf, landmarks, class_num) 283 | return bboxes, kpss, scores 284 | 285 | 286 | 287 | if __name__ == '__main__': 288 | import time 289 | 290 | imgpath = 'test.png' 291 | 292 | yoloface = YoloFace(pt_path='../checkpoints/yoloface_v5m.pt') 293 | srcimg = cv2.imread(imgpath) 294 | 295 | #warpup 296 | bboxes, kpss, scores = yoloface.detect(srcimg) 297 | bboxes, kpss, scores = yoloface.detect(srcimg) 298 | bboxes, kpss, scores = yoloface.detect(srcimg) 299 | 300 | t1 = time.time() 301 | for _ in range(10): 302 | bboxes, kpss, scores = yoloface.detect(srcimg) 303 | t2 = time.time() 304 | print('total time: {} ms'.format((t2 - t1) * 1000)) 305 | for i in range(bboxes.shape[0]): 306 | xmin, ymin, xamx, ymax = int(bboxes[i, 0]), int(bboxes[i, 1]), int(bboxes[i, 0] + bboxes[i, 2]), int(bboxes[i, 1] + bboxes[i, 3]) 307 | cv2.rectangle(srcimg, (xmin, ymin), (xamx, ymax), (0, 0, 255), thickness=2) 308 | for j in range(5): 309 | cv2.circle(srcimg, (int(kpss[i, j, 0]), int(kpss[i, j, 1])), 1, (0, 255, 0), thickness=5) 310 | cv2.imwrite('test_yoloface.jpg', srcimg) -------------------------------------------------------------------------------- /src/models/svfr_adapter/unet_3d_svd_condition_ip.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict, Optional, Tuple, Union, Any 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from diffusers.configuration_utils import ConfigMixin, register_to_config 9 | from diffusers.loaders import UNet2DConditionLoadersMixin 10 | from diffusers.utils import BaseOutput, logging 11 | from diffusers.models.attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor 12 | 13 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps 14 | from diffusers.models.modeling_utils import ModelMixin 15 | from ..svfr_adapter.unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block 16 | from ..svfr_adapter.attention_processor import AttnProcessor2_0, AttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterAttnProcessor 17 | 18 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 19 | 20 | @dataclass 21 | class UNet3DConditionSVDOutput(BaseOutput): 22 | """ 23 | The output of [`UNet3DConditionSVDModel`]. 24 | 25 | Args: 26 | sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`): 27 | The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. 
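            A hedged access sketch (variable names are placeholders, not taken from this repository):
            `out = unet(noisy_latents, t, context, added_time_ids=task_ids)` returns this dataclass,
            and `out.sample` holds the denoised latents in the shape described above.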
28 | """ 29 | 30 | sample: torch.FloatTensor = None 31 | 32 | 33 | class UNet3DConditionSVDModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): 34 | r""" 35 | A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, and a timestep and returns a sample 36 | shaped output. 37 | 38 | This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented 39 | for all models (such as downloading or saving). 40 | 41 | Parameters: 42 | sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): 43 | Height and width of input/output sample. 44 | in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample. 45 | out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. 46 | down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`): 47 | The tuple of downsample blocks to use. 48 | up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`): 49 | The tuple of upsample blocks to use. 50 | block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): 51 | The tuple of output channels for each block. 52 | addition_time_embed_dim: (`int`, defaults to 256): 53 | Dimension to to encode the additional time ids. 54 | projection_class_embeddings_input_dim (`int`, defaults to 768): 55 | The dimension of the projection of encoded `added_time_ids`. 56 | layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. 57 | cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): 58 | The dimension of the cross attention features. 59 | transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): 60 | The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for 61 | [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`], 62 | [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`]. 63 | num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`): 64 | The number of attention heads. 65 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 
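        Example:
            A minimal construction sketch (illustrative only; the keyword values simply restate the
            defaults above, and importing the repository as the `src` package is an assumption):

                from src.models.svfr_adapter.unet_3d_svd_condition_ip import UNet3DConditionSVDModel

                unet = UNet3DConditionSVDModel(
                    in_channels=8,             # noisy latents concatenated with the LQ condition
                    out_channels=4,
                    cross_attention_dim=1024,
                    num_frames=25,
                )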
66 | """ 67 | 68 | _supports_gradient_checkpointing = True 69 | 70 | @register_to_config 71 | def __init__( 72 | self, 73 | sample_size: Optional[int] = None, 74 | in_channels: int = 8, 75 | out_channels: int = 4, 76 | down_block_types: Tuple[str] = ( 77 | "CrossAttnDownBlockSpatioTemporal", 78 | "CrossAttnDownBlockSpatioTemporal", 79 | "CrossAttnDownBlockSpatioTemporal", 80 | "DownBlockSpatioTemporal", 81 | ), 82 | up_block_types: Tuple[str] = ( 83 | "UpBlockSpatioTemporal", 84 | "CrossAttnUpBlockSpatioTemporal", 85 | "CrossAttnUpBlockSpatioTemporal", 86 | "CrossAttnUpBlockSpatioTemporal", 87 | ), 88 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280), 89 | addition_time_embed_dim: int = 256, 90 | projection_class_embeddings_input_dim: int = 768, 91 | layers_per_block: Union[int, Tuple[int]] = 2, 92 | cross_attention_dim: Union[int, Tuple[int]] = 1024, 93 | transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, 94 | num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20), 95 | num_frames: int = 25, 96 | ): 97 | super().__init__() 98 | 99 | self.sample_size = sample_size 100 | 101 | # Check inputs 102 | if len(down_block_types) != len(up_block_types): 103 | raise ValueError( 104 | f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." 105 | ) 106 | 107 | if len(block_out_channels) != len(down_block_types): 108 | raise ValueError( 109 | f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." 110 | ) 111 | 112 | if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): 113 | raise ValueError( 114 | f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 115 | ) 116 | 117 | if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): 118 | raise ValueError( 119 | f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." 120 | ) 121 | 122 | if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): 123 | raise ValueError( 124 | f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." 
125 | ) 126 | 127 | # input 128 | self.conv_in = nn.Conv2d( 129 | in_channels, 130 | block_out_channels[0], 131 | kernel_size=3, 132 | padding=1, 133 | ) 134 | 135 | # time 136 | time_embed_dim = block_out_channels[0] * 4 137 | 138 | self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0) 139 | timestep_input_dim = block_out_channels[0] 140 | 141 | self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 142 | 143 | self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0) 144 | self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) 145 | 146 | self.down_blocks = nn.ModuleList([]) 147 | self.up_blocks = nn.ModuleList([]) 148 | 149 | if isinstance(num_attention_heads, int): 150 | num_attention_heads = (num_attention_heads,) * len(down_block_types) 151 | 152 | if isinstance(cross_attention_dim, int): 153 | cross_attention_dim = (cross_attention_dim,) * len(down_block_types) 154 | 155 | if isinstance(layers_per_block, int): 156 | layers_per_block = [layers_per_block] * len(down_block_types) 157 | 158 | if isinstance(transformer_layers_per_block, int): 159 | transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) 160 | 161 | blocks_time_embed_dim = time_embed_dim 162 | 163 | # down 164 | output_channel = block_out_channels[0] 165 | for i, down_block_type in enumerate(down_block_types): 166 | input_channel = output_channel 167 | output_channel = block_out_channels[i] 168 | is_final_block = i == len(block_out_channels) - 1 169 | 170 | down_block = get_down_block( 171 | down_block_type, 172 | num_layers=layers_per_block[i], 173 | transformer_layers_per_block=transformer_layers_per_block[i], 174 | in_channels=input_channel, 175 | out_channels=output_channel, 176 | temb_channels=blocks_time_embed_dim, 177 | add_downsample=not is_final_block, 178 | resnet_eps=1e-5, 179 | cross_attention_dim=cross_attention_dim[i], 180 | num_attention_heads=num_attention_heads[i], 181 | resnet_act_fn="silu", 182 | ) 183 | self.down_blocks.append(down_block) 184 | 185 | # mid 186 | self.mid_block = UNetMidBlockSpatioTemporal( 187 | block_out_channels[-1], 188 | temb_channels=blocks_time_embed_dim, 189 | transformer_layers_per_block=transformer_layers_per_block[-1], 190 | cross_attention_dim=cross_attention_dim[-1], 191 | num_attention_heads=num_attention_heads[-1], 192 | ) 193 | 194 | # count how many layers upsample the images 195 | self.num_upsamplers = 0 196 | 197 | # up 198 | reversed_block_out_channels = list(reversed(block_out_channels)) 199 | reversed_num_attention_heads = list(reversed(num_attention_heads)) 200 | reversed_layers_per_block = list(reversed(layers_per_block)) 201 | reversed_cross_attention_dim = list(reversed(cross_attention_dim)) 202 | reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) 203 | 204 | output_channel = reversed_block_out_channels[0] 205 | for i, up_block_type in enumerate(up_block_types): 206 | is_final_block = i == len(block_out_channels) - 1 207 | 208 | prev_output_channel = output_channel 209 | output_channel = reversed_block_out_channels[i] 210 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] 211 | 212 | # add upsample block for all BUT final layer 213 | if not is_final_block: 214 | add_upsample = True 215 | self.num_upsamplers += 1 216 | else: 217 | add_upsample = False 218 | 219 | up_block = get_up_block( 220 | up_block_type, 221 | 
num_layers=reversed_layers_per_block[i] + 1, 222 | transformer_layers_per_block=reversed_transformer_layers_per_block[i], 223 | in_channels=input_channel, 224 | out_channels=output_channel, 225 | prev_output_channel=prev_output_channel, 226 | temb_channels=blocks_time_embed_dim, 227 | add_upsample=add_upsample, 228 | resnet_eps=1e-5, 229 | resolution_idx=i, 230 | cross_attention_dim=reversed_cross_attention_dim[i], 231 | num_attention_heads=reversed_num_attention_heads[i], 232 | resnet_act_fn="silu", 233 | ) 234 | self.up_blocks.append(up_block) 235 | prev_output_channel = output_channel 236 | 237 | # out 238 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5) 239 | self.conv_act = nn.SiLU() 240 | 241 | self.conv_out = nn.Conv2d( 242 | block_out_channels[0], 243 | out_channels, 244 | kernel_size=3, 245 | padding=1, 246 | ) 247 | 248 | @property 249 | def attn_processors(self) -> Dict[str, AttentionProcessor]: 250 | r""" 251 | Returns: 252 | `dict` of attention processors: A dictionary containing all attention processors used in the model with 253 | indexed by its weight name. 254 | """ 255 | # set recursively 256 | processors = {} 257 | 258 | def fn_recursive_add_processors( 259 | name: str, 260 | module: torch.nn.Module, 261 | processors: Dict[str, AttentionProcessor], 262 | ): 263 | if hasattr(module, "get_processor"): 264 | processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) 265 | 266 | for sub_name, child in module.named_children(): 267 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 268 | 269 | return processors 270 | 271 | for name, module in self.named_children(): 272 | fn_recursive_add_processors(name, module, processors) 273 | 274 | return processors 275 | 276 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): 277 | r""" 278 | Sets the attention processor to use to compute attention. 279 | 280 | Parameters: 281 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 282 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 283 | for **all** `Attention` layers. 284 | 285 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 286 | processor. This is strongly recommended when setting trainable attention processors. 287 | 288 | """ 289 | count = len(self.attn_processors.keys()) 290 | 291 | if isinstance(processor, dict) and len(processor) != count: 292 | raise ValueError( 293 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 294 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 295 | ) 296 | 297 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 298 | if hasattr(module, "set_processor"): 299 | if not isinstance(processor, dict): 300 | module.set_processor(processor) 301 | else: 302 | module.set_processor(processor.pop(f"{name}.processor")) 303 | 304 | for sub_name, child in module.named_children(): 305 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 306 | 307 | for name, module in self.named_children(): 308 | fn_recursive_attn_processor(name, module, processor) 309 | 310 | def set_default_attn_processor(self): 311 | """ 312 | Disables custom attention processors and sets the default attention implementation. 
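        Hedged usage note: `unet.set_default_attn_processor()` only succeeds while every installed
        processor is one of the stock diffusers cross-attention processors; once `init_ip_adapters`
        has swapped in the adapter classes defined in this repository, the guard below raises a
        `ValueError`, so reset processors before installing adapters rather than after.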
313 | """ 314 | if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): 315 | processor = AttnProcessor() 316 | else: 317 | raise ValueError( 318 | f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" 319 | ) 320 | 321 | self.set_attn_processor(processor) 322 | 323 | def _set_gradient_checkpointing(self, module, value=False): 324 | if hasattr(module, "gradient_checkpointing"): 325 | module.gradient_checkpointing = value 326 | 327 | # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking 328 | def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: 329 | """ 330 | Sets the attention processor to use [feed forward 331 | chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). 332 | 333 | Parameters: 334 | chunk_size (`int`, *optional*): 335 | The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually 336 | over each tensor of dim=`dim`. 337 | dim (`int`, *optional*, defaults to `0`): 338 | The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) 339 | or dim=1 (sequence length). 340 | """ 341 | if dim not in [0, 1]: 342 | raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") 343 | 344 | # By default chunk size is 1 345 | chunk_size = chunk_size or 1 346 | 347 | def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): 348 | if hasattr(module, "set_chunk_feed_forward"): 349 | module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) 350 | 351 | for child in module.children(): 352 | fn_recursive_feed_forward(child, chunk_size, dim) 353 | 354 | for module in self.children(): 355 | fn_recursive_feed_forward(module, chunk_size, dim) 356 | 357 | def forward( 358 | self, 359 | sample: torch.FloatTensor, 360 | timestep: Union[torch.Tensor, float, int], 361 | encoder_hidden_states: torch.Tensor, 362 | down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, 363 | mid_block_additional_residual: Optional[torch.Tensor] = None, 364 | return_dict: bool = True, 365 | added_time_ids: torch.Tensor=None, 366 | pose_cond_fea: Optional[torch.Tensor] = None, 367 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 368 | ) -> Union[UNet3DConditionSVDOutput, Tuple]: 369 | r""" 370 | The [`UNetSpatioTemporalConditionModel`] forward method. 371 | 372 | Args: 373 | sample (`torch.FloatTensor`): 374 | The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`. 375 | timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. 376 | encoder_hidden_states (`torch.FloatTensor`): 377 | The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`. 378 | added_time_ids: (`torch.FloatTensor`): 379 | The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal 380 | embeddings and added to the time embeddings. 381 | return_dict (`bool`, *optional*, defaults to `True`): 382 | Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain 383 | tuple. 
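                Example:
                    A call-pattern sketch only (the tensors are placeholders shaped per the argument
                    descriptions above; `unet` is an instantiated `UNet3DConditionSVDModel`):

                        out = unet(
                            noisy_latents,                     # (batch, num_frames, in_channels, height, width)
                            timestep,
                            encoder_hidden_states=id_tokens,   # features with `cross_attention_dim` channels
                            added_time_ids=task_id_input,      # encoded via `add_time_proj` / `add_embedding`
                        )
                        denoised = out.sample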
384 | Returns: 385 | [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`: 386 | If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise 387 | a `tuple` is returned where the first element is the sample tensor. 388 | """ 389 | # 1. time 390 | timesteps = timestep 391 | if not torch.is_tensor(timesteps): 392 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 393 | # This would be a good case for the `match` statement (Python 3.10+) 394 | is_mps = sample.device.type == "mps" 395 | if isinstance(timestep, float): 396 | dtype = torch.float32 if is_mps else torch.float64 397 | else: 398 | dtype = torch.int32 if is_mps else torch.int64 399 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 400 | elif len(timesteps.shape) == 0: 401 | timesteps = timesteps[None].to(sample.device) 402 | 403 | batch_size, num_frames = sample.shape[:2] 404 | timesteps = timesteps.expand(batch_size) 405 | 406 | t_emb = self.time_proj(timesteps) 407 | t_emb = t_emb.to(dtype=sample.dtype) 408 | emb = self.time_embedding(t_emb) 409 | 410 | time_embeds = self.add_time_proj(added_time_ids.flatten()) 411 | time_embeds = time_embeds.reshape((batch_size, -1)) 412 | time_embeds = time_embeds.to(emb.dtype) 413 | aug_emb = self.add_embedding(time_embeds) 414 | emb = emb + aug_emb 415 | 416 | sample = sample.flatten(0, 1) 417 | emb = emb.repeat_interleave(num_frames, dim=0) 418 | 419 | # 2. pre-process 420 | sample = self.conv_in(sample) 421 | 422 | if pose_cond_fea is not None: 423 | sample = sample + pose_cond_fea.flatten(0, 1) 424 | 425 | image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device) 426 | 427 | down_block_res_samples = (sample,) 428 | for downsample_block in self.down_blocks: 429 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 430 | sample, res_samples = downsample_block( 431 | hidden_states=sample, 432 | temb=emb, 433 | encoder_hidden_states=encoder_hidden_states, 434 | cross_attention_kwargs=cross_attention_kwargs, 435 | image_only_indicator=image_only_indicator, 436 | ) 437 | else: 438 | sample, res_samples = downsample_block( 439 | hidden_states=sample, 440 | temb=emb, 441 | image_only_indicator=image_only_indicator, 442 | ) 443 | 444 | down_block_res_samples += res_samples 445 | 446 | 447 | # 4. mid 448 | sample = self.mid_block( 449 | hidden_states=sample, 450 | temb=emb, 451 | encoder_hidden_states=encoder_hidden_states, 452 | image_only_indicator=image_only_indicator, 453 | cross_attention_kwargs=cross_attention_kwargs, 454 | 455 | ) 456 | 457 | 458 | # 5. up 459 | for i, upsample_block in enumerate(self.up_blocks): 460 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 461 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 462 | 463 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 464 | sample = upsample_block( 465 | hidden_states=sample, 466 | temb=emb, 467 | res_hidden_states_tuple=res_samples, 468 | encoder_hidden_states=encoder_hidden_states, 469 | cross_attention_kwargs=cross_attention_kwargs, 470 | image_only_indicator=image_only_indicator, 471 | ) 472 | else: 473 | sample = upsample_block( 474 | hidden_states=sample, 475 | temb=emb, 476 | res_hidden_states_tuple=res_samples, 477 | image_only_indicator=image_only_indicator, 478 | ) 479 | 480 | # 6. 
post-process 481 | sample = self.conv_norm_out(sample) 482 | sample = self.conv_act(sample) 483 | sample = self.conv_out(sample) 484 | 485 | # 7. Reshape back to original shape 486 | sample = sample.reshape(batch_size, num_frames, *sample.shape[1:]) 487 | 488 | if not return_dict: 489 | return (sample,) 490 | 491 | return UNet3DConditionSVDOutput(sample=sample) 492 | 493 | 494 | 495 | def init_ip_adapters(unet, num_adapter_embeds=[], scale=1.0): 496 | # init adapter modules 497 | attn_procs = {} 498 | unet_sd = unet.state_dict() 499 | for name in unet.attn_processors.keys(): 500 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim 501 | if name.startswith("mid_block"): 502 | hidden_size = unet.config.block_out_channels[-1] 503 | elif name.startswith("up_blocks"): 504 | block_id = int(name[len("up_blocks.")]) 505 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id] 506 | elif name.startswith("down_blocks"): 507 | block_id = int(name[len("down_blocks.")]) 508 | hidden_size = unet.config.block_out_channels[block_id] 509 | # if cross_attention_dim is None or "temporal_transformer_blocks" in name: 510 | if cross_attention_dim is None: 511 | attn_processor_class = ( 512 | AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor 513 | ) 514 | attn_procs[name] = attn_processor_class() 515 | else: 516 | attn_processor_class = ( 517 | IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor 518 | ) 519 | 520 | attn_procs[name] = attn_processor_class( 521 | hidden_size=hidden_size, 522 | cross_attention_dim=cross_attention_dim, 523 | num_tokens=num_adapter_embeds, 524 | scale=scale 525 | ) 526 | 527 | layer_name = name.split(".processor")[0] 528 | weights = {} 529 | for i in range(len(num_adapter_embeds)): 530 | weights.update({f"to_k_ip.{i}.weight": unet_sd[layer_name + ".to_k.weight"]}) 531 | weights.update({f"to_v_ip.{i}.weight": unet_sd[layer_name + ".to_v.weight"]}) 532 | 533 | attn_procs[name].load_state_dict(weights) 534 | unet.set_attn_processor(attn_procs) 535 | adapter_modules = torch.nn.ModuleList(unet.attn_processors.values()) 536 | return adapter_modules 537 | -------------------------------------------------------------------------------- /src/models/svfr_adapter/attention_processor.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import math 3 | from typing import Callable, List, Optional, Union 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | from diffusers.image_processor import IPAdapterMaskProcessor 10 | from diffusers.utils import deprecate, logging 11 | from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available 12 | from diffusers.utils.torch_utils import maybe_allow_in_graph 13 | from diffusers.models.attention_processor import Attention 14 | 15 | 16 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 17 | 18 | if is_torch_npu_available(): 19 | import torch_npu 20 | 21 | if is_xformers_available(): 22 | import xformers 23 | import xformers.ops 24 | else: 25 | xformers = None 26 | 27 | class AttnProcessor: 28 | r""" 29 | Default processor for performing attention-related computations. 
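        In this repository `init_ip_adapters` assigns this class (or `AttnProcessor2_0` below when
        `torch.nn.functional.scaled_dot_product_attention` is available) to the self-attention layers
        whose names end in `attn1.processor`, i.e. the layers that receive no image-prompt tokens.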
30 | """ 31 | 32 | def __call__( 33 | self, 34 | attn: Attention, 35 | hidden_states: torch.Tensor, 36 | encoder_hidden_states: Optional[torch.Tensor] = None, 37 | attention_mask: Optional[torch.Tensor] = None, 38 | temb: Optional[torch.Tensor] = None, 39 | *args, 40 | **kwargs, 41 | ) -> torch.Tensor: 42 | if len(args) > 0 or kwargs.get("scale", None) is not None: 43 | deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 44 | deprecate("scale", "1.0.0", deprecation_message) 45 | 46 | residual = hidden_states 47 | 48 | if attn.spatial_norm is not None: 49 | hidden_states = attn.spatial_norm(hidden_states, temb) 50 | 51 | input_ndim = hidden_states.ndim 52 | 53 | if input_ndim == 4: 54 | batch_size, channel, height, width = hidden_states.shape 55 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 56 | 57 | batch_size, sequence_length, _ = ( 58 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 59 | ) 60 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 61 | 62 | if attn.group_norm is not None: 63 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 64 | 65 | query = attn.to_q(hidden_states) 66 | 67 | if encoder_hidden_states is None: 68 | encoder_hidden_states = hidden_states 69 | elif attn.norm_cross: 70 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 71 | 72 | key = attn.to_k(encoder_hidden_states) 73 | value = attn.to_v(encoder_hidden_states) 74 | 75 | query = attn.head_to_batch_dim(query) 76 | key = attn.head_to_batch_dim(key) 77 | value = attn.head_to_batch_dim(value) 78 | 79 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 80 | hidden_states = torch.bmm(attention_probs, value) 81 | hidden_states = attn.batch_to_head_dim(hidden_states) 82 | 83 | # linear proj 84 | hidden_states = attn.to_out[0](hidden_states) 85 | # dropout 86 | hidden_states = attn.to_out[1](hidden_states) 87 | 88 | if input_ndim == 4: 89 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 90 | 91 | if attn.residual_connection: 92 | hidden_states = hidden_states + residual 93 | 94 | hidden_states = hidden_states / attn.rescale_output_factor 95 | 96 | return hidden_states 97 | 98 | class AttnProcessor2_0(nn.Module): 99 | r""" 100 | Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 101 | """ 102 | 103 | def __init__(self): 104 | super().__init__() 105 | if not hasattr(F, "scaled_dot_product_attention"): 106 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") 107 | 108 | def __call__( 109 | self, 110 | attn: Attention, 111 | hidden_states: torch.Tensor, 112 | encoder_hidden_states: Optional[torch.Tensor] = None, 113 | attention_mask: Optional[torch.Tensor] = None, 114 | temb: Optional[torch.Tensor] = None, 115 | ip_adapter_masks: Optional[torch.Tensor] = None, 116 | *args, 117 | **kwargs, 118 | ) -> torch.Tensor: 119 | if len(args) > 0 or kwargs.get("scale", None) is not None: 120 | deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. 
`scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 121 | deprecate("scale", "1.0.0", deprecation_message) 122 | 123 | residual = hidden_states 124 | if attn.spatial_norm is not None: 125 | hidden_states = attn.spatial_norm(hidden_states, temb) 126 | 127 | input_ndim = hidden_states.ndim 128 | 129 | if input_ndim == 4: 130 | batch_size, channel, height, width = hidden_states.shape 131 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 132 | 133 | batch_size, sequence_length, _ = ( 134 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 135 | ) 136 | 137 | if attention_mask is not None: 138 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 139 | # scaled_dot_product_attention expects attention_mask shape to be 140 | # (batch, heads, source_length, target_length) 141 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) 142 | 143 | if attn.group_norm is not None: 144 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 145 | 146 | query = attn.to_q(hidden_states) 147 | 148 | if encoder_hidden_states is None: 149 | encoder_hidden_states = hidden_states 150 | elif attn.norm_cross: 151 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 152 | 153 | key = attn.to_k(encoder_hidden_states) 154 | value = attn.to_v(encoder_hidden_states) 155 | 156 | inner_dim = key.shape[-1] 157 | head_dim = inner_dim // attn.heads 158 | 159 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 160 | 161 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 162 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 163 | 164 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 165 | # TODO: add support for attn.scale when we move to Torch 2.1 166 | hidden_states = F.scaled_dot_product_attention( 167 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 168 | ) 169 | 170 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 171 | hidden_states = hidden_states.to(query.dtype) 172 | 173 | # linear proj 174 | hidden_states = attn.to_out[0](hidden_states) 175 | # dropout 176 | hidden_states = attn.to_out[1](hidden_states) 177 | 178 | if input_ndim == 4: 179 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 180 | 181 | if attn.residual_connection: 182 | hidden_states = hidden_states + residual 183 | 184 | hidden_states = hidden_states / attn.rescale_output_factor 185 | 186 | return hidden_states 187 | 188 | class IPAdapterAttnProcessor(nn.Module): 189 | r""" 190 | Attention processor for Multiple IP-Adapters. 191 | 192 | Args: 193 | hidden_size (`int`): 194 | The hidden size of the attention layer. 195 | cross_attention_dim (`int`): 196 | The number of channels in the `encoder_hidden_states`. 197 | num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`): 198 | The context length of the image features. 199 | scale (`float` or List[`float`], defaults to 1.0): 200 | the weight scale of image prompt. 
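        Example:
            Constructed by `init_ip_adapters` roughly as follows (the concrete numbers are
            illustrative assumptions, not values read from a checkpoint):

                proc = IPAdapterAttnProcessor(
                    hidden_size=1280,            # block_out_channels entry for this layer
                    cross_attention_dim=1024,    # unet.config.cross_attention_dim
                    num_tokens=[49],             # context length per adapter
                    scale=1.0,
                )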
201 | """ 202 | 203 | def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0): 204 | super().__init__() 205 | 206 | self.hidden_size = hidden_size 207 | self.cross_attention_dim = cross_attention_dim 208 | 209 | if not isinstance(num_tokens, (tuple, list)): 210 | num_tokens = [num_tokens] 211 | self.num_tokens = num_tokens 212 | 213 | if not isinstance(scale, list): 214 | scale = [scale] * len(num_tokens) 215 | if len(scale) != len(num_tokens): 216 | raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.") 217 | self.scale = scale 218 | 219 | self.to_k_ip = nn.ModuleList( 220 | [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))] 221 | ) 222 | self.to_v_ip = nn.ModuleList( 223 | [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))] 224 | ) 225 | 226 | def __call__( 227 | self, 228 | attn: Attention, 229 | hidden_states: torch.Tensor, 230 | encoder_hidden_states: Optional[torch.Tensor] = None, 231 | attention_mask: Optional[torch.Tensor] = None, 232 | temb: Optional[torch.Tensor] = None, 233 | scale: float = 1.0, 234 | ip_adapter_masks: Optional[torch.Tensor] = None, 235 | ): 236 | residual = hidden_states 237 | 238 | # separate ip_hidden_states from encoder_hidden_states 239 | if encoder_hidden_states is not None: 240 | if isinstance(encoder_hidden_states, tuple): 241 | encoder_hidden_states, ip_hidden_states = encoder_hidden_states 242 | else: 243 | deprecation_message = ( 244 | "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release." 245 | " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning." 246 | ) 247 | deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False) 248 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0] 249 | encoder_hidden_states, ip_hidden_states = ( 250 | encoder_hidden_states[:, :end_pos, :], 251 | [encoder_hidden_states[:, end_pos:, :]], 252 | ) 253 | 254 | if attn.spatial_norm is not None: 255 | hidden_states = attn.spatial_norm(hidden_states, temb) 256 | 257 | input_ndim = hidden_states.ndim 258 | 259 | if input_ndim == 4: 260 | batch_size, channel, height, width = hidden_states.shape 261 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 262 | 263 | batch_size, sequence_length, _ = ( 264 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 265 | ) 266 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 267 | 268 | if attn.group_norm is not None: 269 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 270 | 271 | query = attn.to_q(hidden_states) 272 | 273 | if encoder_hidden_states is None: 274 | encoder_hidden_states = hidden_states 275 | elif attn.norm_cross: 276 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 277 | 278 | key = attn.to_k(encoder_hidden_states) 279 | value = attn.to_v(encoder_hidden_states) 280 | 281 | query = attn.head_to_batch_dim(query) 282 | key = attn.head_to_batch_dim(key) 283 | value = attn.head_to_batch_dim(value) 284 | 285 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 286 | hidden_states = torch.bmm(attention_probs, value) 287 | hidden_states = attn.batch_to_head_dim(hidden_states) 288 | 289 | if ip_adapter_masks is not None: 290 | if 
not isinstance(ip_adapter_masks, List): 291 | # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width] 292 | ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1)) 293 | if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)): 294 | raise ValueError( 295 | f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match " 296 | f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states " 297 | f"({len(ip_hidden_states)})" 298 | ) 299 | else: 300 | for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)): 301 | if not isinstance(mask, torch.Tensor) or mask.ndim != 4: 302 | raise ValueError( 303 | "Each element of the ip_adapter_masks array should be a tensor with shape " 304 | "[1, num_images_for_ip_adapter, height, width]." 305 | " Please use `IPAdapterMaskProcessor` to preprocess your mask" 306 | ) 307 | if mask.shape[1] != ip_state.shape[1]: 308 | raise ValueError( 309 | f"Number of masks ({mask.shape[1]}) does not match " 310 | f"number of ip images ({ip_state.shape[1]}) at index {index}" 311 | ) 312 | if isinstance(scale, list) and not len(scale) == mask.shape[1]: 313 | raise ValueError( 314 | f"Number of masks ({mask.shape[1]}) does not match " 315 | f"number of scales ({len(scale)}) at index {index}" 316 | ) 317 | else: 318 | ip_adapter_masks = [None] * len(self.scale) 319 | 320 | # for ip-adapter 321 | for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( 322 | ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks 323 | ): 324 | skip = False 325 | if isinstance(scale, list): 326 | if all(s == 0 for s in scale): 327 | skip = True 328 | elif scale == 0: 329 | skip = True 330 | if not skip: 331 | if mask is not None: 332 | if not isinstance(scale, list): 333 | scale = [scale] * mask.shape[1] 334 | 335 | current_num_images = mask.shape[1] 336 | for i in range(current_num_images): 337 | ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :]) 338 | ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :]) 339 | 340 | ip_key = attn.head_to_batch_dim(ip_key) 341 | ip_value = attn.head_to_batch_dim(ip_value) 342 | 343 | ip_attention_probs = attn.get_attention_scores(query, ip_key, None) 344 | _current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) 345 | _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states) 346 | 347 | mask_downsample = IPAdapterMaskProcessor.downsample( 348 | mask[:, i, :, :], 349 | batch_size, 350 | _current_ip_hidden_states.shape[1], 351 | _current_ip_hidden_states.shape[2], 352 | ) 353 | 354 | mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) 355 | 356 | hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample) 357 | else: 358 | ip_key = to_k_ip(current_ip_hidden_states) 359 | ip_value = to_v_ip(current_ip_hidden_states) 360 | 361 | ip_key = attn.head_to_batch_dim(ip_key) 362 | ip_value = attn.head_to_batch_dim(ip_value) 363 | 364 | ip_attention_probs = attn.get_attention_scores(query, ip_key, None) 365 | current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) 366 | current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states) 367 | 368 | hidden_states = hidden_states + scale * current_ip_hidden_states 369 | 370 | # linear proj 371 | hidden_states = attn.to_out[0](hidden_states) 372 | # dropout 373 | hidden_states = attn.to_out[1](hidden_states) 374 | 375 | if input_ndim == 4: 
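            # the input arrived as (batch, channel, height, width) and was flattened to
            # (batch, height*width, channel) for attention; restore the original layout here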
376 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 377 | 378 | if attn.residual_connection: 379 | hidden_states = hidden_states + residual 380 | 381 | hidden_states = hidden_states / attn.rescale_output_factor 382 | 383 | return hidden_states 384 | 385 | 386 | class IPAdapterAttnProcessor2_0(torch.nn.Module): 387 | r""" 388 | Attention processor for IP-Adapter for PyTorch 2.0. 389 | 390 | Args: 391 | hidden_size (`int`): 392 | The hidden size of the attention layer. 393 | cross_attention_dim (`int`): 394 | The number of channels in the `encoder_hidden_states`. 395 | num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`): 396 | The context length of the image features. 397 | scale (`float` or `List[float]`, defaults to 1.0): 398 | the weight scale of image prompt. 399 | """ 400 | 401 | def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0): 402 | super().__init__() 403 | 404 | if not hasattr(F, "scaled_dot_product_attention"): 405 | raise ImportError( 406 | f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 407 | ) 408 | 409 | self.hidden_size = hidden_size 410 | self.cross_attention_dim = cross_attention_dim 411 | 412 | if not isinstance(num_tokens, (tuple, list)): 413 | num_tokens = [num_tokens] 414 | self.num_tokens = num_tokens 415 | 416 | if not isinstance(scale, list): 417 | scale = [scale] * len(num_tokens) 418 | if len(scale) != len(num_tokens): 419 | raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.") 420 | self.scale = scale 421 | 422 | self.to_k_ip = nn.ModuleList( 423 | [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))] 424 | ) 425 | self.to_v_ip = nn.ModuleList( 426 | [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))] 427 | ) 428 | 429 | def __call__( 430 | self, 431 | attn: Attention, 432 | hidden_states: torch.Tensor, 433 | encoder_hidden_states: Optional[torch.Tensor] = None, 434 | attention_mask: Optional[torch.Tensor] = None, 435 | temb: Optional[torch.Tensor] = None, 436 | scale: float = 1.0, 437 | ip_adapter_masks: Optional[torch.Tensor] = None, 438 | ): 439 | residual = hidden_states 440 | 441 | # separate ip_hidden_states from encoder_hidden_states 442 | if encoder_hidden_states is not None: 443 | if isinstance(encoder_hidden_states, tuple): 444 | encoder_hidden_states, ip_hidden_states = encoder_hidden_states 445 | 446 | else: 447 | deprecation_message = ( 448 | "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release." 449 | " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning." 
450 | ) 451 | deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False) 452 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0] 453 | encoder_hidden_states, ip_hidden_states = ( 454 | encoder_hidden_states[:, :end_pos, :], 455 | [encoder_hidden_states[:, end_pos:, :]], 456 | ) 457 | 458 | if attn.spatial_norm is not None: 459 | hidden_states = attn.spatial_norm(hidden_states, temb) 460 | 461 | input_ndim = hidden_states.ndim 462 | 463 | if input_ndim == 4: 464 | batch_size, channel, height, width = hidden_states.shape 465 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 466 | 467 | batch_size, sequence_length, _ = ( 468 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 469 | ) 470 | 471 | if attention_mask is not None: 472 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 473 | # scaled_dot_product_attention expects attention_mask shape to be 474 | # (batch, heads, source_length, target_length) 475 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) 476 | 477 | if attn.group_norm is not None: 478 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 479 | 480 | query = attn.to_q(hidden_states) 481 | 482 | if encoder_hidden_states is None: 483 | encoder_hidden_states = hidden_states 484 | elif attn.norm_cross: 485 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 486 | 487 | key = attn.to_k(encoder_hidden_states) 488 | value = attn.to_v(encoder_hidden_states) 489 | 490 | inner_dim = key.shape[-1] 491 | head_dim = inner_dim // attn.heads 492 | 493 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 494 | 495 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 496 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 497 | 498 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 499 | # TODO: add support for attn.scale when we move to Torch 2.1 500 | hidden_states = F.scaled_dot_product_attention( 501 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 502 | ) 503 | 504 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 505 | hidden_states = hidden_states.to(query.dtype) 506 | 507 | if ip_adapter_masks is not None: 508 | if not isinstance(ip_adapter_masks, List): 509 | # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width] 510 | ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1)) 511 | if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)): 512 | raise ValueError( 513 | f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match " 514 | f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states " 515 | f"({len(ip_hidden_states)})" 516 | ) 517 | else: 518 | for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)): 519 | ip_hidden_states[index] = ip_state = ip_state.unsqueeze(1) 520 | if not isinstance(mask, torch.Tensor) or mask.ndim != 4: 521 | raise ValueError( 522 | "Each element of the ip_adapter_masks array should be a tensor with shape " 523 | "[1, num_images_for_ip_adapter, height, width]." 
524 | " Please use `IPAdapterMaskProcessor` to preprocess your mask" 525 | ) 526 | if mask.shape[1] != ip_state.shape[1]: 527 | raise ValueError( 528 | f"Number of masks ({mask.shape[1]}) does not match " 529 | f"number of ip images ({ip_state.shape[1]}) at index {index}" 530 | ) 531 | if isinstance(scale, list) and not len(scale) == mask.shape[1]: 532 | raise ValueError( 533 | f"Number of masks ({mask.shape[1]}) does not match " 534 | f"number of scales ({len(scale)}) at index {index}" 535 | ) 536 | else: 537 | ip_adapter_masks = [None] * len(self.scale) 538 | 539 | # for ip-adapter 540 | for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( 541 | ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks 542 | ): 543 | skip = False 544 | if isinstance(scale, list): 545 | if all(s == 0 for s in scale): 546 | skip = True 547 | elif scale == 0: 548 | skip = True 549 | if not skip: 550 | if mask is not None: 551 | if not isinstance(scale, list): 552 | scale = [scale] * mask.shape[1] 553 | 554 | current_num_images = mask.shape[1] 555 | for i in range(current_num_images): 556 | ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :]) 557 | ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :]) 558 | 559 | ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 560 | ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 561 | 562 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 563 | # TODO: add support for attn.scale when we move to Torch 2.1 564 | _current_ip_hidden_states = F.scaled_dot_product_attention( 565 | query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False 566 | ) 567 | 568 | _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape( 569 | batch_size, -1, attn.heads * head_dim 570 | ) 571 | _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype) 572 | 573 | mask_downsample = IPAdapterMaskProcessor.downsample( 574 | mask[:, i, :, :], 575 | batch_size, 576 | _current_ip_hidden_states.shape[1], 577 | _current_ip_hidden_states.shape[2], 578 | ) 579 | mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) 580 | hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample) 581 | 582 | else: 583 | ip_key = to_k_ip(current_ip_hidden_states) 584 | ip_value = to_v_ip(current_ip_hidden_states) 585 | 586 | ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 587 | ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 588 | 589 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 590 | # TODO: add support for attn.scale when we move to Torch 2.1 591 | current_ip_hidden_states = F.scaled_dot_product_attention( 592 | query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False 593 | ) 594 | 595 | current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( 596 | batch_size, -1, attn.heads * head_dim 597 | ) 598 | current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) 599 | 600 | hidden_states = hidden_states + scale * current_ip_hidden_states 601 | 602 | 603 | # linear proj 604 | hidden_states = attn.to_out[0](hidden_states) 605 | # dropout 606 | hidden_states = attn.to_out[1](hidden_states) 607 | 608 | if input_ndim == 4: 609 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 610 | 611 | if attn.residual_connection: 612 | hidden_states = hidden_states + residual 613 | 614 | 
hidden_states = hidden_states / attn.rescale_output_factor 615 | 616 | return hidden_states 617 | -------------------------------------------------------------------------------- /src/pipelines/pipeline.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from dataclasses import dataclass 3 | from typing import Callable, Dict, List, Optional, Union 4 | import numpy as np 5 | import PIL.Image 6 | import torch 7 | from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection 8 | 9 | from diffusers.image_processor import VaeImageProcessor 10 | # from diffusers.models import UNetSpatioTemporalConditionModel 11 | from diffusers.utils import BaseOutput, logging 12 | from diffusers.utils.torch_utils import randn_tensor, is_compiled_module 13 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 14 | from diffusers import ( 15 | AutoencoderKLTemporalDecoder, 16 | EulerDiscreteScheduler, 17 | ) 18 | 19 | # from src.models.base.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel 20 | from ..models.svfr_adapter.unet_3d_svd_condition_ip import UNet3DConditionSVDModel 21 | 22 | 23 | logger = logging.get_logger(__name__) 24 | 25 | 26 | def _append_dims(x, target_dims): 27 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" 28 | dims_to_append = target_dims - x.ndim 29 | if dims_to_append < 0: 30 | raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") 31 | return x[(...,) + (None,) * dims_to_append] 32 | 33 | 34 | def tensor2vid(video: torch.Tensor, processor: VaeImageProcessor, output_type: str = "np"): 35 | batch_size, channels, num_frames, height, width = video.shape 36 | outputs = [] 37 | for batch_idx in range(batch_size): 38 | batch_vid = video[batch_idx].permute(1, 0, 2, 3) 39 | batch_output = processor.postprocess(batch_vid, output_type) 40 | 41 | outputs.append(batch_output) 42 | 43 | if output_type == "np": 44 | outputs = np.stack(outputs) 45 | 46 | elif output_type == "pt": 47 | outputs = torch.stack(outputs) 48 | 49 | elif not output_type == "pil": 50 | raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']") 51 | 52 | return outputs 53 | 54 | 55 | @dataclass 56 | class LQ2VideoSVDPipelineOutput(BaseOutput): 57 | r""" 58 | Output class for zero-shot text-to-video pipeline. 59 | 60 | Args: 61 | frames (`[List[PIL.Image.Image]`, `np.ndarray`]): 62 | List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, 63 | num_channels)`. 64 | """ 65 | 66 | frames: Union[List[PIL.Image.Image], np.ndarray] 67 | latents: Union[torch.Tensor, np.ndarray] 68 | 69 | 70 | class LQ2VideoLongSVDPipeline(DiffusionPipeline): 71 | r""" 72 | Pipeline to generate video from an input image using Stable Video Diffusion. 73 | 74 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods 75 | implemented for all pipelines (downloading, saving, running on a particular device, etc.). 76 | 77 | Args: 78 | vae ([`AutoencoderKL`]): 79 | Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. 80 | image_encoder ([`~transformers.CLIPVisionModelWithProjection`]): 81 | Frozen CLIP image-encoder ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)). 
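            Note that `__init__` below registers the pipeline without an image encoder (the
            corresponding lines are commented out), so `_clip_encode_image` can only run if that
            attribute is attached to the pipeline separately.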
82 | unet ([`UNetSpatioTemporalConditionModel`]): 83 | A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents. 84 | scheduler ([`EulerDiscreteScheduler`]): 85 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. 86 | feature_extractor ([`~transformers.CLIPImageProcessor`]): 87 | A `CLIPImageProcessor` to extract features from generated images. 88 | """ 89 | 90 | model_cpu_offload_seq = "unet->vae" 91 | _callback_tensor_inputs = ["latents"] 92 | 93 | def __init__( 94 | self, 95 | vae: AutoencoderKLTemporalDecoder, 96 | #image_encoder: CLIPVisionModelWithProjection, 97 | unet: UNet3DConditionSVDModel, 98 | scheduler: EulerDiscreteScheduler, 99 | feature_extractor: CLIPImageProcessor, 100 | vae_config=None, 101 | 102 | ): 103 | super().__init__() 104 | self.register_modules( 105 | vae=vae, 106 | #image_encoder=image_encoder, 107 | unet=unet, 108 | scheduler=scheduler, 109 | feature_extractor=feature_extractor, 110 | ) 111 | 112 | self.vae_config=vae_config 113 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 114 | 115 | # print("vae:", self.vae_scale_factor) 116 | 117 | self.image_processor = VaeImageProcessor( 118 | vae_scale_factor=self.vae_scale_factor, 119 | do_convert_rgb=True) 120 | 121 | 122 | def _clip_encode_image(self, image, num_frames, device, num_videos_per_prompt, do_classifier_free_guidance): 123 | dtype = next(self.image_encoder.parameters()).dtype 124 | 125 | if not isinstance(image, torch.Tensor): 126 | image = self.image_processor.pil_to_numpy(image) 127 | image = self.image_processor.numpy_to_pt(image) 128 | 129 | image = image * 2.0 - 1.0 130 | image = _resize_with_antialiasing(image, (224, 224)) 131 | image = (image + 1.0) / 2.0 132 | 133 | # Normalize the image with for CLIP input 134 | image = self.feature_extractor( 135 | images=image, 136 | do_normalize=True, 137 | do_center_crop=False, 138 | do_resize=False, 139 | do_rescale=False, 140 | return_tensors="pt", 141 | ).pixel_values 142 | 143 | image = image.to(device=device, dtype=dtype, non_blocking=True,).unsqueeze(0) # 3,224,224 144 | image_embeddings = self.image_encoder(image).image_embeds 145 | image_embeddings = image_embeddings.unsqueeze(1) 146 | 147 | # duplicate image embeddings for each generation per prompt, using mps friendly method 148 | bs_embed, seq_len, _ = image_embeddings.shape 149 | image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1) 150 | image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) 151 | 152 | if do_classifier_free_guidance: 153 | negative_image_embeddings = torch.zeros_like(image_embeddings) 154 | image_embeddings = torch.cat([negative_image_embeddings, image_embeddings]) 155 | # image_embeddings = torch.cat([image_embeddings, image_embeddings]) 156 | 157 | return image_embeddings 158 | 159 | def _encode_vae_image( 160 | self, 161 | image: torch.Tensor, 162 | device, 163 | num_videos_per_prompt, 164 | do_classifier_free_guidance, 165 | cf_vae, 166 | use_cf=True, 167 | ): 168 | if not use_cf: 169 | image = image.to(device=self.vae.device) 170 | image_latents = self.vae.encode(image).latent_dist.mode() 171 | image_latents = image_latents.to(device=device) 172 | else: 173 | image_latents = cf_vae.encode(image) 174 | image_latents = image_latents.to(device=device) 175 | #print("image_latents:", image_latents.shape) #torch.Size([20, 4, 64, 64]) 176 | #image_latents = image_latents * 0.18215 177 | image_latents = image_latents.unsqueeze(0) 178 | 179 | if 
do_classifier_free_guidance: 180 | negative_image_latents = torch.zeros_like(image_latents) 181 | 182 | # For classifier free guidance, we need to do two forward passes. 183 | # Here we concatenate the unconditional and text embeddings into a single batch 184 | # to avoid doing two forward passes 185 | # image_latents = torch.cat([negative_image_latents, image_latents]) 186 | image_latents = torch.cat([image_latents, image_latents]) 187 | 188 | # duplicate image_latents for each generation per prompt, using mps friendly method 189 | image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1, 1) 190 | 191 | return image_latents 192 | 193 | def _get_add_time_ids( 194 | self, 195 | task_id_input, 196 | dtype, 197 | batch_size, 198 | num_videos_per_prompt, 199 | do_classifier_free_guidance, 200 | ): 201 | 202 | passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(task_id_input) 203 | expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features 204 | 205 | if expected_add_embed_dim != passed_add_embed_dim: 206 | raise ValueError( 207 | f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 208 | ) 209 | 210 | # add_time_ids = torch.tensor([add_time_ids], dtype=dtype) 211 | # add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1) 212 | add_time_ids = task_id_input.to(dtype) 213 | add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1) 214 | 215 | if do_classifier_free_guidance: 216 | add_time_ids = torch.cat([add_time_ids, add_time_ids]) 217 | 218 | return add_time_ids 219 | 220 | def decode_latents(self, latents, num_frames, decode_chunk_size=14): 221 | # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width] 222 | latents = latents.flatten(0, 1) 223 | 224 | latents = 1 / self.vae.config.scaling_factor * latents 225 | 226 | forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward 227 | accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys()) 228 | 229 | # decode decode_chunk_size frames at a time to avoid OOM 230 | frames = [] 231 | for i in range(0, latents.shape[0], decode_chunk_size): 232 | num_frames_in = latents[i : i + decode_chunk_size].shape[0] 233 | decode_kwargs = {} 234 | if accepts_num_frames: 235 | # we only pass num_frames_in if it's expected 236 | decode_kwargs["num_frames"] = num_frames_in 237 | 238 | frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample 239 | frames.append(frame) 240 | frames = torch.cat(frames, dim=0) 241 | 242 | # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width] 243 | frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4) 244 | 245 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 246 | frames = frames.float() 247 | return frames 248 | 249 | def check_inputs(self, image, height, width): 250 | if ( 251 | not isinstance(image, torch.Tensor) 252 | and not isinstance(image, PIL.Image.Image) 253 | and not isinstance(image, list) 254 | ): 255 | raise ValueError( 256 | "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" 257 | f" {type(image)}" 258 | ) 259 | 260 | if height % 8 != 0 or width % 
8 != 0: 261 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 262 | 263 | def prepare_latents( 264 | self, 265 | batch_size, 266 | num_frames, 267 | num_channels_latents, 268 | height, 269 | width, 270 | dtype, 271 | device, 272 | generator, 273 | latents=None, 274 | ref_image_latents=None, 275 | timestep=None 276 | ): 277 | from ..utils.noise_util import random_noise 278 | shape = ( 279 | batch_size, 280 | num_frames, 281 | num_channels_latents // 3, 282 | height // self.vae_scale_factor, 283 | width // self.vae_scale_factor, 284 | ) 285 | if isinstance(generator, list) and len(generator) != batch_size: 286 | raise ValueError( 287 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 288 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 289 | ) 290 | 291 | if latents is None: 292 | # noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 293 | # noise = video_fusion_noise(shape=shape, generator=generator, device=device, dtype=dtype) 294 | # noise = video_fusion_noise_repeat(shape=shape, generator=generator, device=device, dtype=dtype) 295 | noise = random_noise(shape=shape, generator=generator, device=device, dtype=dtype) 296 | # noise = video_fusion_noise_repeat_0830(shape=shape, generator=generator, device=device, dtype=dtype) 297 | else: 298 | noise = latents.to(device) 299 | 300 | # scale the initial noise by the standard deviation required by the scheduler 301 | if timestep is not None: 302 | init_latents = ref_image_latents.unsqueeze(0) 303 | # init_latents = ref_image_latents.unsqueeze(1) 304 | #print(f"noise.shape: {noise.shape}, init_latents.shape: {init_latents.shape}") 305 | # print(init_latents.is_cuda,noise.is_cuda,timestep.is_cuda) 306 | latents = self.scheduler.add_noise(init_latents, noise, timestep) 307 | else: 308 | latents = noise * self.scheduler.init_noise_sigma 309 | 310 | return latents 311 | 312 | def get_timesteps(self, num_inference_steps, strength, device): 313 | # get the original timestep using init_timestep 314 | init_timestep = min(int(num_inference_steps * strength), num_inference_steps) 315 | 316 | t_start = max(num_inference_steps - init_timestep, 0) 317 | timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] 318 | 319 | return timesteps, num_inference_steps - t_start 320 | 321 | @property 322 | def guidance_scale1(self): 323 | return self._guidance_scale1 324 | 325 | @property 326 | def guidance_scale2(self): 327 | return self._guidance_scale2 328 | 329 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 330 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 331 | # corresponds to doing no classifier free guidance. 
332 | # @property 333 | # def do_classifier_free_guidance(self): 334 | # return True 335 | 336 | @property 337 | def num_timesteps(self): 338 | return self._num_timesteps 339 | 340 | @torch.no_grad() 341 | def __call__( 342 | self, 343 | ref_image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor], # lq 344 | ref_concat_image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor], # last concat ref img 345 | id_prompts: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor], # id encode_hidden_state 346 | # task_id: int = 0, 347 | task_id_input: torch.Tensor = None, 348 | height: int = 512, 349 | width: int = 512, 350 | num_frames: Optional[int] = None, 351 | num_inference_steps: int = 25, 352 | min_guidance_scale=1.0, # 1.0, 353 | max_guidance_scale=3.0, 354 | noise_aug_strength: int = 0.02, 355 | decode_chunk_size: Optional[int] = None, 356 | num_videos_per_prompt: Optional[int] = 1, 357 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 358 | latents: Optional[torch.FloatTensor] = None, 359 | output_type: Optional[str] = "pil", 360 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, 361 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 362 | return_dict: bool = True, 363 | do_classifier_free_guidance: bool = True, 364 | overlap=7, 365 | frames_per_batch=14, 366 | i2i_noise_strength=1.0, 367 | cf_vae=None, 368 | use_cf=False, 369 | ): 370 | r""" 371 | The call function to the pipeline for generation. 372 | 373 | Args: 374 | image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): 375 | Image or images to guide image generation. If you provide a tensor, it needs to be compatible with 376 | [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json). 377 | height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): 378 | The height in pixels of the generated image. 379 | width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): 380 | The width in pixels of the generated image. 381 | num_frames (`int`, *optional*): 382 | The number of video frames to generate. Defaults to 14 for `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt` 383 | num_inference_steps (`int`, *optional*, defaults to 25): 384 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 385 | expense of slower inference. This parameter is modulated by `strength`. 386 | min_guidance_scale (`float`, *optional*, defaults to 1.0): 387 | The minimum guidance scale. Used for the classifier free guidance with first frame. 388 | max_guidance_scale (`float`, *optional*, defaults to 3.0): 389 | The maximum guidance scale. Used for the classifier free guidance with last frame. 390 | noise_aug_strength (`int`, *optional*, defaults to 0.02): 391 | The amount of noise added to the init image, the higher it is the less the video will look like the init image. Increase it for more motion. 392 | decode_chunk_size (`int`, *optional*): 393 | The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency 394 | between frames, but also the higher the memory consumption. By default, the decoder will decode all frames at once 395 | for maximal quality. Reduce `decode_chunk_size` to reduce memory usage. 
396 | num_videos_per_prompt (`int`, *optional*, defaults to 1): 397 | The number of images to generate per prompt. 398 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 399 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make 400 | generation deterministic. 401 | latents (`torch.FloatTensor`, *optional*): 402 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image 403 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 404 | tensor is generated by sampling using the supplied random `generator`. 405 | output_type (`str`, *optional*, defaults to `"pil"`): 406 | The output format of the generated image. Choose between `PIL.Image` or `np.array`. 407 | callback_on_step_end (`Callable`, *optional*): 408 | A function that calls at the end of each denoising steps during the inference. The function is called 409 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, 410 | callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by 411 | `callback_on_step_end_tensor_inputs`. 412 | callback_on_step_end_tensor_inputs (`List`, *optional*): 413 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list 414 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the 415 | `._callback_tensor_inputs` attribute of your pipeline class. 416 | return_dict (`bool`, *optional*, defaults to `True`): 417 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a 418 | plain tuple. 419 | 420 | Returns: 421 | [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`: 422 | If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is returned, 423 | otherwise a `tuple` is returned where the first element is a list of list with the generated frames. 424 | 425 | Examples: 426 | 427 | ```py 428 | from diffusers import StableVideoDiffusionPipeline 429 | from diffusers.utils import load_image, export_to_video 430 | 431 | pipe = StableVideoDiffusionPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16") 432 | pipe.to("cuda") 433 | 434 | image = load_image("https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200") 435 | image = image.resize((1024, 576)) 436 | 437 | frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0] 438 | export_to_video(frames, "generated.mp4", fps=7) 439 | ``` 440 | """ 441 | # 0. Default height and width to unet 442 | height = height or self.unet.config.sample_size * self.vae_scale_factor 443 | width = width or self.unet.config.sample_size * self.vae_scale_factor 444 | 445 | # print(min_guidance_scale, max_guidance_scale) 446 | 447 | num_frames = num_frames if num_frames is not None else self.unet.config.num_frames 448 | decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames 449 | 450 | # 1. Check inputs. Raise error if not correct 451 | self.check_inputs(ref_image, height, width) 452 | 453 | # 2. 
Define call parameters 454 | if isinstance(ref_image, PIL.Image.Image): 455 | batch_size = 1 456 | elif isinstance(ref_image, list): 457 | batch_size = len(ref_image) 458 | else: 459 | if len(ref_image.shape)==4: 460 | batch_size = 1 461 | else: 462 | batch_size = ref_image.shape[0] 463 | 464 | # if not use_cf: 465 | # self.vae=cf_vae 466 | device = self._execution_device 467 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 468 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 469 | # corresponds to doing no classifier free guidance. 470 | # do_classifier_free_guidance = True #True 471 | 472 | # 3. Prepare clip image embeds 473 | # image_embeddings = torch.zeros([2,1,1024],dtype=self.vae.dtype).to(device) 474 | # image_embeddings = self._clip_encode_image( 475 | # clip_image, 476 | # num_frames, 477 | # device, 478 | # num_videos_per_prompt, 479 | # do_classifier_free_guidance,) 480 | # print(image_embeddings) 481 | image_embeddings = torch.cat([torch.zeros_like(id_prompts),id_prompts], dim=0) if do_classifier_free_guidance else id_prompts 482 | # image_embeddings = torch.cat([torch.zeros_like(id_prompts),id_prompts,id_prompts], dim=0) 483 | # image_embeddings = torch.cat([id_prompts,id_prompts,id_prompts], dim=0) 484 | # image_embeddings = torch.cat([torch.zeros_like(id_prompts),torch.zeros_like(id_prompts),torch.zeros_like(id_prompts)], dim=0) 485 | # image_embeddings = torch.cat([id_prompts_neg, id_prompts, id_prompts], dim=0) 486 | 487 | 488 | # NOTE: Stable Diffusion Video was conditioned on fps - 1, which 489 | # is why it is reduced here. 490 | # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188 491 | # fps = fps - 1 492 | 493 | # 4. 
Encode input image using VAE 494 | if not use_cf: 495 | needs_upcasting = (self.vae.dtype == torch.float16 or self.vae.dtype == torch.bfloat16) and self.vae.config.force_upcast 496 | vae_dtype = self.vae.dtype 497 | # if needs_upcasting: 498 | # self.vae.to(dtype=torch.float32) # vae_dtype = torch.float32 499 | else: 500 | needs_upcasting=False 501 | vae_dtype=cf_vae.vae_dtype 502 | 503 | 504 | if not use_cf: 505 | # Prepare ref image latents 506 | ref_image_tensor = ref_image.to(device=self.vae.device,dtype=self.vae.dtype, ) 507 | #print(ref_image.shape) 508 | else: 509 | ref_image_tensor = ref_image.to(dtype=vae_dtype) #torch.Size([16, 3, 512, 512]) 510 | 511 | # bsz = ref_image_tensor.shape[0] 512 | # ref_image_tensor = rearrange(ref_image_tensor,'b f c h w-> (b f) c h w') 513 | if not use_cf: 514 | chunk_size = 20 515 | ref_image_latents = [] 516 | for chunk_idx in range((ref_image_tensor.shape[0]//chunk_size)+1): 517 | if chunk_idx*chunk_size>=num_frames: break 518 | ref_image_latent = self.vae.encode(ref_image_tensor[chunk_idx*chunk_size:(chunk_idx+1)*chunk_size]).latent_dist.mean #TODO 519 | ref_image_latents.append(ref_image_latent) 520 | ref_image_latents = torch.cat(ref_image_latents,dim=0) 521 | # print(ref_image_tensor.shape,ref_image_latents.shape) 522 | ref_image_latents = ref_image_latents * 0.18215 # (f, 4, h, w) 523 | # ref_image_latents = rearrange(ref_image_latents, '(b f) c h w-> b f c h w', b=bsz) 524 | 525 | noise = randn_tensor( 526 | ref_image_tensor.shape, 527 | generator=generator, 528 | device=self.vae.device, 529 | dtype=self.vae.dtype) 530 | 531 | ref_image_tensor = ref_image_tensor + noise_aug_strength * noise 532 | 533 | image_latents = [] 534 | for chunk_idx in range((ref_image_tensor.shape[0]//chunk_size)+1): 535 | if chunk_idx*chunk_size>=num_frames: break 536 | image_latent = self._encode_vae_image( 537 | ref_image_tensor[chunk_idx*chunk_size:(chunk_idx+1)*chunk_size], 538 | device=device, 539 | num_videos_per_prompt=num_videos_per_prompt, 540 | do_classifier_free_guidance=do_classifier_free_guidance, 541 | cf_vae=cf_vae, 542 | use_cf=use_cf, 543 | ) 544 | image_latents.append(image_latent) 545 | image_latents = torch.cat(image_latents, dim=1) 546 | # print(ref_image_tensor.shape,image_latents.shape) 547 | # print(image_latents.shape) 548 | image_latents = image_latents.to(image_embeddings.dtype) 549 | ref_image_latents = ref_image_latents.to(image_embeddings.dtype) 550 | else: 551 | chunk_size = 20 552 | ref_image_latents = [] 553 | for chunk_idx in range((ref_image_tensor.shape[0]//chunk_size)+1): 554 | if chunk_idx*chunk_size>=num_frames: break 555 | ref_image_latent=cf_vae.encode(ref_image_tensor[chunk_idx*chunk_size:(chunk_idx+1)*chunk_size].permute(0,2,3,1)) 556 | ref_image_latent.to(device=device,dtype=vae_dtype) 557 | ref_image_latents.append(ref_image_latent) 558 | ref_image_latents = torch.cat(ref_image_latents,dim=0) 559 | ref_image_latents = ref_image_latents * 0.18215 # (f, 4, h, w) 560 | 561 | image_latents = [] 562 | for chunk_idx in range((ref_image_tensor.shape[0]//chunk_size)+1): 563 | if chunk_idx*chunk_size>=num_frames: break 564 | img_l=cf_vae.encode(ref_image_tensor[chunk_idx*chunk_size:(chunk_idx+1)*chunk_size].permute(0,2,3,1)) 565 | image_latents_ = torch.cat([img_l.unsqueeze(0), img_l.unsqueeze(0)]) 566 | image_latents.append(image_latents_) 567 | image_latents = torch.cat(image_latents, dim=1) #torch.Size([16, 3, 512, 512]) torch.Size([32, 4, 64, 64]) 568 | image_latents = image_latents.to(device,vae_dtype) 569 | 
ref_image_latents = ref_image_latents.to(device,vae_dtype) 570 | #print(ref_image_tensor.shape,image_latents.shape,1) 571 | 572 | # cast back to fp16 if needed 573 | if not use_cf: 574 | if needs_upcasting: 575 | self.vae.to(dtype=vae_dtype) 576 | 577 | # Repeat the image latents for each frame so we can concatenate them with the noise 578 | # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width] 579 | # image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1) 580 | 581 | if not use_cf: 582 | if ref_concat_image is not None: 583 | ref_concat_tensor = ref_concat_image.to( 584 | dtype=self.vae.dtype, device=self.vae.device 585 | ) 586 | ref_concat_tensor = self.vae.encode(ref_concat_tensor.unsqueeze(0)).latent_dist.mode() 587 | ref_concat_tensor = ref_concat_tensor.unsqueeze(0).repeat(1,num_frames,1,1,1) 588 | ref_concat_tensor = torch.cat([torch.zeros_like(ref_concat_tensor), ref_concat_tensor]) if do_classifier_free_guidance else ref_concat_tensor 589 | ref_concat_tensor = ref_concat_tensor.to(image_embeddings) 590 | else: 591 | ref_concat_tensor = torch.zeros_like(image_latents) 592 | else: 593 | if ref_concat_image is not None: 594 | ref_concat_tensor = ref_concat_image #torch.Size([3, 512, 512]) 595 | #print(ref_concat_tensor.shape) 596 | #ref_concat_tensor = self.vae.encode(ref_concat_tensor.unsqueeze(0)).latent_dist.mode() 597 | ref_concat_tensor = cf_vae.encode(ref_concat_tensor.unsqueeze(0).permute(0,2,3,1)) 598 | ref_concat_tensor = ref_concat_tensor.unsqueeze(0).repeat(1,num_frames,1,1,1) 599 | ref_concat_tensor = torch.cat([torch.zeros_like(ref_concat_tensor), ref_concat_tensor]) if do_classifier_free_guidance else ref_concat_tensor 600 | ref_concat_tensor = ref_concat_tensor.to(device=device,dtype=vae_dtype) 601 | else: 602 | ref_concat_tensor = torch.zeros_like(image_latents).to(device=device,dtype=vae_dtype) 603 | 604 | # 5. Get Added Time IDs 605 | added_time_ids = self._get_add_time_ids( 606 | task_id_input, 607 | image_embeddings.dtype, 608 | batch_size, 609 | num_videos_per_prompt, 610 | do_classifier_free_guidance, 611 | ) 612 | added_time_ids = added_time_ids.to(device, dtype=self.unet.dtype) 613 | 614 | # 4. Prepare timesteps 615 | self.scheduler.set_timesteps(num_inference_steps, device=device) 616 | timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, i2i_noise_strength, device) 617 | latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) 618 | 619 | 620 | # 5. Prepare latent variables 621 | num_channels_latents = self.unet.config.in_channels 622 | 623 | latents = self.prepare_latents( 624 | batch_size * num_videos_per_prompt, 625 | num_frames, 626 | num_channels_latents, 627 | height, 628 | width, 629 | image_embeddings.dtype, 630 | device, 631 | generator, 632 | latents, 633 | ref_image_latents, 634 | timestep=latent_timestep 635 | ) 636 | 637 | #print(latents.shape)#torch.Size([1, 54, 4, 32, 32]) 638 | 639 | # 7. Prepare guidance scale 640 | guidance_scale = torch.linspace( 641 | min_guidance_scale, 642 | max_guidance_scale, 643 | num_inference_steps) 644 | guidance_scale1 = guidance_scale.to(device, latents.dtype) 645 | guidance_scale2 = guidance_scale.to(device, latents.dtype) 646 | 647 | 648 | self._guidance_scale1 = guidance_scale1 649 | self._guidance_scale2 = guidance_scale2 650 | 651 | # 8. 
Denoising loop 652 | latents_all = latents # for any-frame generation 653 | 654 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 655 | self._num_timesteps = len(timesteps) 656 | shift = 0 657 | with self.progress_bar(total=num_inference_steps) as progress_bar: 658 | for i, t in enumerate(timesteps): 659 | 660 | # init 661 | pred_latents = torch.zeros_like( 662 | latents_all, 663 | dtype=self.unet.dtype, 664 | ) 665 | counter = torch.zeros( 666 | (latents_all.shape[0], num_frames, 1, 1, 1), 667 | dtype=self.unet.dtype, 668 | ).to(device=latents_all.device) 669 | 670 | for batch, index_start in enumerate(range(0, num_frames, frames_per_batch - overlap*(i<3))): 671 | self.scheduler._step_index = None 672 | index_start -= shift 673 | def indice_slice(tensor, idx_list): 674 | tensor_list = [] 675 | for idx in idx_list: 676 | idx = idx % tensor.shape[1] 677 | tensor_list.append(tensor[:,idx]) 678 | return torch.stack(tensor_list, 1) 679 | idx_list = list(range(index_start, index_start+frames_per_batch)) 680 | latents = indice_slice(latents_all, idx_list) 681 | image_latents_input = indice_slice(image_latents, idx_list) 682 | image_embeddings_input = indice_slice(image_embeddings, idx_list) 683 | ref_concat_tensor_input = indice_slice(ref_concat_tensor, idx_list) 684 | 685 | 686 | # if index_start + frames_per_batch >= num_frames: 687 | # index_start = num_frames - frames_per_batch 688 | 689 | # latents = latents_all[:, index_start:index_start + frames_per_batch] 690 | # image_latents_input = image_latents[:, index_start:index_start + frames_per_batch] 691 | 692 | # expand the latents if we are doing classifier free guidance 693 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 694 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 695 | 696 | # = torch.cat([torch.zeros_like(image_latents_input),image_latents_input]) if do_classifier_free_guidance else image_latents_input 697 | # image_latents_input = torch.zeros_like(image_latents_input) 698 | # image_latents_input = torch.cat([image_latents_input] * 2) if do_classifier_free_guidance else image_latents_input 699 | 700 | 701 | # Concatenate image_latents over channels dimention 702 | # print(latent_model_input.shape, image_latents_input.shape) 703 | latent_model_input = torch.cat([ 704 | latent_model_input, 705 | image_latents_input, 706 | ref_concat_tensor_input], dim=2) 707 | # predict the noise residual 708 | noise_pred = self.unet( 709 | latent_model_input, 710 | t, 711 | encoder_hidden_states=image_embeddings_input.flatten(0,1), 712 | added_time_ids=added_time_ids, 713 | return_dict=False, 714 | )[0] 715 | # perform guidance 716 | if do_classifier_free_guidance: 717 | noise_pred_uncond, noise_pred_cond = noise_pred.chunk(3) 718 | noise_pred = noise_pred_uncond + self.guidance_scale1[i] * (noise_pred_cond - noise_pred_uncond) #+ self.guidance_scale2[i] * (noise_pred_cond - noise_pred_drop_id) 719 | 720 | # compute the previous noisy sample x_t -> x_t-1 721 | latents = self.scheduler.step(noise_pred, t.to(self.unet.dtype), latents).prev_sample 722 | 723 | if callback_on_step_end is not None: 724 | callback_kwargs = {} 725 | for k in callback_on_step_end_tensor_inputs: 726 | callback_kwargs[k] = locals()[k] 727 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 728 | 729 | latents = callback_outputs.pop("latents", latents) 730 | 731 | # if batch == 0: 732 | for iii in range(frames_per_batch): 733 | # pred_latents[:, 
index_start + iii:index_start + iii + 1] += latents[:, iii:iii+1] * min(iii + 1, frames_per_batch-iii) 734 | # counter[:, index_start + iii:index_start + iii + 1] += min(iii + 1, frames_per_batch-iii) 735 | p = (index_start + iii) % pred_latents.shape[1] 736 | pred_latents[:, p] += latents[:, iii] * min(iii + 1, frames_per_batch-iii) 737 | counter[:, p] += 1 * min(iii + 1, frames_per_batch-iii) 738 | 739 | 740 | shift += overlap 741 | shift = shift % frames_per_batch 742 | 743 | pred_latents = pred_latents / counter 744 | latents_all = pred_latents 745 | 746 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 747 | progress_bar.update() 748 | 749 | latents = latents_all 750 | if not output_type == "latent": 751 | # cast back to fp16 if needed 752 | if needs_upcasting: 753 | self.vae.to(dtype=vae_dtype) 754 | 755 | #print(latents.shape,latents.dtype) #torch.Size([1, 16, 4, 32, 32]) torch.float16 756 | frames = self.decode_latents(latents, num_frames, decode_chunk_size) 757 | else: 758 | 759 | frames = latents 760 | 761 | self.maybe_free_model_hooks() 762 | 763 | if not return_dict: 764 | return frames 765 | return LQ2VideoSVDPipelineOutput(frames=frames,latents=latents) 766 | 767 | 768 | # resizing utils 769 | # TODO: clean up later 770 | def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True): 771 | h, w = input.shape[-2:] 772 | factors = (h / size[0], w / size[1]) 773 | 774 | # First, we have to determine sigma 775 | # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171 776 | sigmas = ( 777 | max((factors[0] - 1.0) / 2.0, 0.001), 778 | max((factors[1] - 1.0) / 2.0, 0.001), 779 | ) 780 | 781 | # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma 782 | # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206 783 | # But they do it in the 2 passes, which gives better results. 
Let's try 2 sigmas for now 784 | ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3)) 785 | 786 | # Make sure it is odd 787 | if (ks[0] % 2) == 0: 788 | ks = ks[0] + 1, ks[1] 789 | 790 | if (ks[1] % 2) == 0: 791 | ks = ks[0], ks[1] + 1 792 | 793 | input = _gaussian_blur2d(input, ks, sigmas) 794 | 795 | output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners) 796 | return output 797 | 798 | 799 | def _compute_padding(kernel_size): 800 | """Compute padding tuple.""" 801 | # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom) 802 | # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad 803 | if len(kernel_size) < 2: 804 | raise AssertionError(kernel_size) 805 | computed = [k - 1 for k in kernel_size] 806 | 807 | # for even kernels we need to do asymmetric padding :( 808 | out_padding = 2 * len(kernel_size) * [0] 809 | 810 | for i in range(len(kernel_size)): 811 | computed_tmp = computed[-(i + 1)] 812 | 813 | pad_front = computed_tmp // 2 814 | pad_rear = computed_tmp - pad_front 815 | 816 | out_padding[2 * i + 0] = pad_front 817 | out_padding[2 * i + 1] = pad_rear 818 | 819 | return out_padding 820 | 821 | 822 | def _filter2d(input, kernel): 823 | # prepare kernel 824 | b, c, h, w = input.shape 825 | tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype) 826 | 827 | tmp_kernel = tmp_kernel.expand(-1, c, -1, -1) 828 | 829 | height, width = tmp_kernel.shape[-2:] 830 | 831 | padding_shape: list[int] = _compute_padding([height, width]) 832 | input = torch.nn.functional.pad(input, padding_shape, mode="reflect") 833 | 834 | # kernel and input tensor reshape to align element-wise or batch-wise params 835 | tmp_kernel = tmp_kernel.reshape(-1, 1, height, width) 836 | input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1)) 837 | 838 | # convolve the tensor with the kernel. 839 | output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1) 840 | 841 | out = output.view(b, c, h, w) 842 | return out 843 | 844 | 845 | def _gaussian(window_size: int, sigma): 846 | if isinstance(sigma, float): 847 | sigma = torch.tensor([[sigma]]) 848 | 849 | batch_size = sigma.shape[0] 850 | 851 | x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1) 852 | 853 | if window_size % 2 == 0: 854 | x = x + 0.5 855 | 856 | gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0))) 857 | 858 | return gauss / gauss.sum(-1, keepdim=True) 859 | 860 | 861 | def _gaussian_blur2d(input, kernel_size, sigma): 862 | if isinstance(sigma, tuple): 863 | sigma = torch.tensor([sigma], dtype=input.dtype) 864 | else: 865 | sigma = sigma.to(dtype=input.dtype) 866 | 867 | ky, kx = int(kernel_size[0]), int(kernel_size[1]) 868 | bs = sigma.shape[0] 869 | kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1)) 870 | kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1)) 871 | out_x = _filter2d(input, kernel_x[..., None, :]) 872 | out = _filter2d(out_x, kernel_y[..., None]) 873 | 874 | return out 875 | --------------------------------------------------------------------------------
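Editor's note on the denoising loop in `LQ2VideoLongSVDPipeline.__call__` above: long clips are denoised by sliding a window of `frames_per_batch` frames (with `overlap` frames of overlap and a per-step `shift`) over the whole sequence, and the overlapping window outputs are averaged back into `latents_all` with triangular weights (`min(iii + 1, frames_per_batch - iii)`). The sketch below isolates just that accumulation logic on toy tensors so the blending is easier to follow; `blend_overlapping_windows`, `denoise_window`, and the shapes used are illustrative stand-ins, not code from this repository.

```py
import torch

def blend_overlapping_windows(latents_all, frames_per_batch=14, overlap=3, shift=0):
    # latents_all: (batch, num_frames, channels, height, width)
    num_frames = latents_all.shape[1]
    pred = torch.zeros_like(latents_all)
    counter = torch.zeros(latents_all.shape[0], num_frames, 1, 1, 1, dtype=latents_all.dtype)

    def denoise_window(window):
        # Stand-in for the UNet prediction + scheduler.step() applied to one window.
        return window

    # Same traversal as the inner loop: windows start every (frames_per_batch - overlap)
    # frames and wrap around the end of the clip via modulo indexing.
    for index_start in range(0, num_frames, frames_per_batch - overlap):
        index_start -= shift
        idx = [(index_start + i) % num_frames for i in range(frames_per_batch)]
        window = denoise_window(latents_all[:, idx])
        for i, p in enumerate(idx):
            w = min(i + 1, frames_per_batch - i)  # triangular weight, peaks mid-window
            pred[:, p] += window[:, i] * w
            counter[:, p] += w

    return pred / counter  # weighted average of every window that covers each frame

# Toy usage: 54 frames, windows of 14 frames with 3 frames of overlap.
out = blend_overlapping_windows(torch.randn(1, 54, 4, 32, 32))
assert out.shape == (1, 54, 4, 32, 32)
```

The triangular weights downweight frames near a window's edges, so where two windows overlap the blend favours whichever window sees that frame closer to its centre; this is what smooths the seams between batches of frames.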
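A second sketch, for how `i2i_noise_strength` interacts with `get_timesteps` and `prepare_latents` above: the strength decides how many of the scheduler's timesteps are kept, and the reference-video latents are then noised to the first kept timestep with `scheduler.add_noise` before denoising begins. The scheduler construction and tensor shapes below are assumptions for illustration (the node actually loads the scheduler config from `svd_repo/scheduler`); only the arithmetic mirrors the pipeline.

```py
import torch
from diffusers import EulerDiscreteScheduler

scheduler = EulerDiscreteScheduler(num_train_timesteps=1000)  # assumed config
num_inference_steps = 30
scheduler.set_timesteps(num_inference_steps)

strength = 1.0  # i2i_noise_strength; e.g. 0.6 would keep only the last 18 steps
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = scheduler.timesteps[t_start * scheduler.order:]

# As in prepare_latents(): noise the reference latents up to the first kept timestep.
ref_image_latents = torch.randn(1, 16, 4, 64, 64)  # assumed (batch, frames, ch, h, w)
noise = torch.randn_like(ref_image_latents)
latents = scheduler.add_noise(ref_image_latents, noise, timesteps[:1])
print(len(timesteps), latents.shape)  # 30 torch.Size([1, 16, 4, 64, 64])
```

With the default `i2i_noise_strength: 1.0` from `config/infer.yaml`, the full schedule is used and the reference latents start out essentially as pure noise; lowering the strength keeps more of the low-quality input's structure at the cost of fewer denoising steps.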