├── .gitignore
├── README.md
├── __init__.py
├── animate_anyone_nodes.py
├── assets
│   ├── animate_anyone_res.mp4
│   ├── animate_anyone_test_00003.mp4
│   ├── animate_anyone_wf.json
│   ├── dance.mp4
│   ├── show.mp4
│   └── test_12.png
├── controlnet_sdv.py
├── pipeline_stable_video_diffusion_controlnet_long.py
├── prepare.py
├── requirements.txt
├── run_inference_release.py
└── unet_spatio_temporal_condition_controlnet.py

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.py[cod]
3 | /output/
4 | /input/
5 | /checkpoint/
6 | /temp/
7 | /.vs
8 | .idea/
9 | venv/
10 | .git/
11 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ComfyUI-AnimateAnyone-reproduction
2 | 
3 | A ComfyUI custom node that simply integrates the [animate-anyone-reproduction](https://github.com/bendanzzc/AnimateAnyone-reproduction) functionality.
4 | 
5 | 一个简单接入 animate-anyone-reproduction 的 ComfyUI 节点。
6 | 
7 | 
8 | ## Instructions 指南
9 | 
10 | We use 'COMFYUI_PATH' to represent your ComfyUI directory.
11 | 我们用'COMFYUI_PATH'表示你comfyui的目录
12 | 
13 | ### Clone the repo 克隆仓库
14 | * Clone this repo into 'COMFYUI_PATH/custom_nodes'. 将这个仓库克隆到'COMFYUI_PATH/custom_nodes'目录下
15 | ```txt
16 | git clone https://github.com/AuroBit/ComfyUI-Animate-Anyone-reproduction.git custom_nodes/ComfyUI-Animate-Anyone-reproduction
17 | ```
18 | * Install dependencies: 安装依赖
19 | ```txt
20 | pip install -r custom_nodes/ComfyUI-Animate-Anyone-reproduction/requirements.txt
21 | ```
22 | 
23 | 
24 | ### Prepare checkpoint files 准备模型文件
25 | #### Method 1: clone the model repos from Hugging Face and ModelScope. 方法1:通过clone仓库下载模型文件
26 | 
27 | * Clone (or download all files of) the SVD model repo from https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt/tree/main to anywhere you like. 将[SVD模型](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt/tree/main)克隆到一个你喜欢的目录下。
28 | * Create a folder named 'animate_anyone' under the 'COMFYUI_PATH/models' folder. 在'COMFYUI_PATH/models'下创建一个'animate_anyone'文件夹
29 | * Copy all files and folders (except the 'unet' folder) to 'COMFYUI_PATH/models/animate_anyone'. 将所有文件和文件夹(除了'unet'文件夹)复制到'COMFYUI_PATH/models/animate_anyone'下面
30 | 
31 | * Clone (or download the model and files) from https://modelscope.cn/models/lightnessly/animate-anyone-v1/files 克隆原作者训练的[模型仓库](https://modelscope.cn/models/lightnessly/animate-anyone-v1/files)
32 | * Copy the 'controlnet' folder to the 'COMFYUI_PATH/models/animate_anyone' folder.
33 | * Copy the 'unet' folder to the 'COMFYUI_PATH/models/animate_anyone' folder.
34 | 
35 | * Your file structure should look like this: 下载完后你的模型文件夹的目录结构应该是这样的
36 | ```txt
37 | - ComfyUI
38 |     ...
39 |     - models
40 |         - animate_anyone
41 |             - controlnet
42 |                 config.json
43 |                 diffusion_pytorch_model.safetensors
44 |             - feature_extractor
45 |                 preprocessor_config.json
46 |             - image_encoder
47 |                 ...
48 |             - scheduler
49 |                 ...
50 |             - unet
51 |                 ...
52 |             - vae
53 |                 ...
54 |     ...
55 | ```
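If you would rather script Method 1 than copy the folders by hand, the sketch below shows one possible way to fetch the weights with `huggingface_hub` and `modelscope`. This is only an illustration and is not the bundled `prepare.py`; it assumes both packages are installed (`pip install huggingface_hub modelscope`) and that it is run from the ComfyUI root directory.

```python
# Illustrative sketch only -- not the repository's prepare.py.
import os
import shutil

from huggingface_hub import snapshot_download as hf_snapshot_download
from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot_download

target = "models/animate_anyone"
os.makedirs(target, exist_ok=True)

# SVD base model: fetch everything except the original 'unet' folder,
# which is replaced by the retrained AnimateAnyone unet below.
hf_snapshot_download(
    repo_id="stabilityai/stable-video-diffusion-img2vid-xt",
    local_dir=target,
    ignore_patterns=["unet/*"],
)

# AnimateAnyone weights from ModelScope: provide the 'controlnet' and 'unet' folders.
aa_dir = ms_snapshot_download("lightnessly/animate-anyone-v1")
for sub in ("controlnet", "unet"):
    shutil.copytree(os.path.join(aa_dir, sub), os.path.join(target, sub), dirs_exist_ok=True)
```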
56 | 
57 | 
58 | #### Method 2
59 | Download all of these files automatically with the Python script.
60 | 通过python脚本自动下载模型文件
61 | 
62 | * run:
63 | ```txt
64 | python custom_nodes/ComfyUI-Animate-Anyone-reproduction/prepare.py
65 | ```
66 | 
67 | 
68 | ## Examples
69 | An example workflow is provided in `assets/animate_anyone_wf.json`; sample results are in the mp4 files under `assets/`.
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | from .animate_anyone_nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
2 | 
3 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
4 | 
5 | 
--------------------------------------------------------------------------------
/animate_anyone_nodes.py:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | 
4 | import os
5 | import torch
6 | import numpy as np
7 | from PIL import Image
8 | from .pipeline_stable_video_diffusion_controlnet_long import StableVideoDiffusionPipelineControlNet
9 | from .controlnet_sdv import ControlNetSDVModel
10 | #from diffusers import T2IAdapter
11 | from .unet_spatio_temporal_condition_controlnet import UNetSpatioTemporalConditionControlNetModel
12 | import re
13 | import folder_paths
14 | 
15 | 
16 | def load_images_from_folder_to_pil(folder, target_size=(512, 512)):
17 |     images = []
18 |     valid_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"}  # Add or remove extensions as needed
19 | 
20 |     def frame_number(filename):
21 |         matches = re.findall(r'\d+', filename)  # Find all sequences of digits in the filename
22 |         if matches:
23 |             if matches[-1] == '0000' and len(matches) > 1:
24 |                 return int(matches[-2])  # Return the second-to-last sequence if the last is '0000'
25 |             return int(matches[-1])  # Otherwise, return the last sequence
26 |         return float('inf')  # Files without a frame number sort last
27 | 
28 | 
29 |     # Sort files by frame number rather than lexicographically
30 |     sorted_files = sorted(os.listdir(folder), key=frame_number)
31 | 
32 |     # Load and convert images (target_size is currently unused)
33 |     for filename in sorted_files:
34 |         ext = os.path.splitext(filename)[1].lower()
35 |         if ext in valid_extensions:
36 |             img = Image.open(os.path.join(folder, filename)).convert('RGB')
37 |             images.append(img)
38 | 
39 |     return images[::2]  # keep every second frame
40 | 
41 | 
42 | class AnimateAnyone:
43 |     @classmethod
44 |     def INPUT_TYPES(cls):
45 |         return {
46 |             "required": {
47 |                 "image": ('IMAGE',),
48 |                 "pose_images": ('IMAGE',)
49 |             },
50 |             "optional": {
51 |                 "width": ('INT', {'default': 512, 'min': 64, 'max': 10240, 'step': 64}),
52 |                 "height": ('INT', {'default': 768, 'min': 64, 'max': 10240, 'step': 64}),
53 |                 "frames_per_batch": ('INT', {'default': 14, 'min': 1, 'max': 1024}),
54 |                 "steps": ('INT', {'default': 25, 'min': 1, 'max': 1024}),
55 |                 "fps": ('INT', {'default': 7, 'min': 1, 'max': 1024})
56 |             }
57 |         }
58 | 
59 |     RETURN_TYPES = ("IMAGE",)
60 |     FUNCTION = 'run_inference'
61 |     CATEGORY = 'AnimateAnyone'
62 | 
63 |     def _tensor_to_pil(self, tensor_img):
64 |         new_tensor_img = tensor_img * 255
65 | 
66 |         pil_img_list = []
67 |         for img in new_tensor_img:
68 |             np_img = img.to(torch.uint8).cpu().numpy()
69 |             pil_img = Image.fromarray(np_img).convert('RGB')
70 |             pil_img_list.append(pil_img)
71 |         return pil_img_list
72 | 
73 | 
74 | 
75 | 
76 | 
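    # Note (added for clarity): ComfyUI passes IMAGE data as float32 tensors of
    # shape [batch, height, width, channels] with values in [0, 1]. The helper
    # above converts that layout to a list of PIL images (scaling to 0-255), and
    # the helper below performs the inverse conversion.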
77 |     def _pil_to_tensor(self, images: list):
78 |         """
79 |         Input:
80 |             images: list of PIL.Image
81 |         """
82 |         w, h = images[0].size
83 |         tensor_imgs = torch.zeros([len(images), h, w, 3], dtype=torch.float32)
84 |         for i in range(len(images)):
85 |             np_img = np.array(images[i])
86 |             tensor_img = torch.from_numpy(np_img.astype(np.float32) / 255.)
87 |             tensor_imgs[i] = tensor_img
88 |         return tensor_imgs
89 | 
90 | 
91 |     def run_inference(self, image, pose_images,
92 |                       width=512, height=768,
93 |                       frames_per_batch=14, steps=25, fps=7):
94 | 
95 |         args = {
96 |             "pretrained_model_name_or_path": "models/animate_anyone",
97 |         }
98 |         assert width % 64 == 0, "`height` and `width` have to be divisible by 64"
99 |         assert height % 64 == 0, "`height` and `width` have to be divisible by 64"
100 | 
101 |         # convert image from tensor to PIL.Image
102 |         ref_images = self._tensor_to_pil(image)
103 |         pose_imgs_pil = self._tensor_to_pil(pose_images)
104 | 
105 |         # load checkpoints
106 |         controlnet = ControlNetSDVModel.from_pretrained(args["pretrained_model_name_or_path"] + "/controlnet")
107 |         unet = UNetSpatioTemporalConditionControlNetModel.from_pretrained(args["pretrained_model_name_or_path"] + '/unet')
108 | 
109 |         pipeline = StableVideoDiffusionPipelineControlNet.from_pretrained(args["pretrained_model_name_or_path"], controlnet=controlnet, unet=unet)
110 |         pipeline.to(dtype=torch.float16)
111 |         pipeline.enable_model_cpu_offload()
112 | 
113 |         print('Model loading: DONE.')
114 | 
115 |         # val_save_dir = os.path.join(folder_paths.get_output_directory(), "animate_anyone")
116 |         # os.makedirs(val_save_dir, exist_ok=True)
117 | 
118 |         # Inference and saving loop
119 |         final_result = []
120 |         num_frames = len(pose_imgs_pil)
121 |         if num_frames <= frames_per_batch:  # pad with the last pose so there is at least one full batch of frames
122 |             nb_append = frames_per_batch + 1 - num_frames
123 |             pose_imgs_pil.extend([pose_imgs_pil[-1]] * nb_append)
124 | 
125 |         print('Image and frame data loading: DONE.')
126 | 
127 |         import time
128 |         start = time.time()
129 | 
130 |         for ref_image in ref_images:
131 |             video_frames = pipeline(ref_image, pose_imgs_pil, decode_chunk_size=2, num_frames=len(pose_imgs_pil), motion_bucket_id=127.0,
132 |                                     fps=fps, controlnet_cond_scale=1.0, width=width, height=height,
133 |                                     min_guidance_scale=3.5, max_guidance_scale=3.5, frames_per_batch=frames_per_batch, num_inference_steps=steps, overlap=4).frames[0]
134 |             # [video_frames[i].save(f"{val_save_dir}/{i}.jpg") for i in range(len(video_frames))]
135 |             final_result.extend(video_frames[:num_frames])
136 | 
137 |         end = time.time()
138 |         print(f"Elapsed time: {end - start}")
139 | 
140 |         tensor_results = self._pil_to_tensor(final_result)
141 |         return (tensor_results, )
142 | 
143 | 
144 | 
145 | 
146 | 
147 | 
148 | NODE_CLASS_MAPPINGS = {
149 |     "AnimateAnyone": AnimateAnyone,
150 | }
151 | 
152 | NODE_DISPLAY_NAME_MAPPINGS = {
153 |     "AnimateAnyone": "AnimateAnyone"
154 | }
155 | 
--------------------------------------------------------------------------------
/assets/animate_anyone_res.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AuroBit/ComfyUI-AnimateAnyone-reproduction/4ff8e90391cf32bb1dc40f7e4888911fdf896e84/assets/animate_anyone_res.mp4
--------------------------------------------------------------------------------
/assets/animate_anyone_test_00003.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AuroBit/ComfyUI-AnimateAnyone-reproduction/4ff8e90391cf32bb1dc40f7e4888911fdf896e84/assets/animate_anyone_test_00003.mp4
--------------------------------------------------------------------------------
/assets/animate_anyone_wf.json:
--------------------------------------------------------------------------------
1 | {
2 |   "last_node_id": 131,
3 |   "last_link_id": 271,
4 |   "nodes": [
5 |     {
6 |       "id": 115,
7 |       "type": "PreviewImage",
8 |       "pos": [
9 | -850, 10 | -855 11 | ], 12 | "size": { 13 | "0": 367.8274230957031, 14 | "1": 449.884521484375 15 | }, 16 | "flags": {}, 17 | "order": 9, 18 | "mode": 0, 19 | "inputs": [ 20 | { 21 | "name": "images", 22 | "type": "IMAGE", 23 | "link": 235 24 | } 25 | ], 26 | "properties": { 27 | "Node name for S&R": "PreviewImage" 28 | } 29 | }, 30 | { 31 | "id": 123, 32 | "type": "Reroute", 33 | "pos": [ 34 | -432, 35 | -1216 36 | ], 37 | "size": [ 38 | 75, 39 | 26 40 | ], 41 | "flags": {}, 42 | "order": 10, 43 | "mode": 0, 44 | "inputs": [ 45 | { 46 | "name": "", 47 | "type": "*", 48 | "link": 246 49 | } 50 | ], 51 | "outputs": [ 52 | { 53 | "name": "", 54 | "type": "IMAGE", 55 | "links": [ 56 | 248 57 | ], 58 | "slot_index": 0 59 | } 60 | ], 61 | "properties": { 62 | "showOutputText": false, 63 | "horizontal": false 64 | } 65 | }, 66 | { 67 | "id": 114, 68 | "type": "DWPreprocessor", 69 | "pos": [ 70 | -847, 71 | -1129 72 | ], 73 | "size": { 74 | "0": 315, 75 | "1": 198 76 | }, 77 | "flags": {}, 78 | "order": 8, 79 | "mode": 0, 80 | "inputs": [ 81 | { 82 | "name": "image", 83 | "type": "IMAGE", 84 | "link": 245 85 | } 86 | ], 87 | "outputs": [ 88 | { 89 | "name": "IMAGE", 90 | "type": "IMAGE", 91 | "links": [ 92 | 235, 93 | 246 94 | ], 95 | "shape": 3, 96 | "slot_index": 0 97 | }, 98 | { 99 | "name": "POSE_KEYPOINT", 100 | "type": "POSE_KEYPOINT", 101 | "links": null, 102 | "shape": 3 103 | } 104 | ], 105 | "properties": { 106 | "Node name for S&R": "DWPreprocessor" 107 | }, 108 | "widgets_values": [ 109 | "enable", 110 | "enable", 111 | "enable", 112 | 512, 113 | "yolox_l.onnx", 114 | "dw-ll_ucoco_384_bs5.torchscript.pt" 115 | ] 116 | }, 117 | { 118 | "id": 124, 119 | "type": "Reroute", 120 | "pos": [ 121 | -246, 122 | -1213 123 | ], 124 | "size": [ 125 | 75, 126 | 26 127 | ], 128 | "flags": {}, 129 | "order": 11, 130 | "mode": 0, 131 | "inputs": [ 132 | { 133 | "name": "", 134 | "type": "*", 135 | "link": 248 136 | } 137 | ], 138 | "outputs": [ 139 | { 140 | "name": "", 141 | "type": "IMAGE", 142 | "links": [ 143 | 249 144 | ], 145 | "slot_index": 0 146 | } 147 | ], 148 | "properties": { 149 | "showOutputText": false, 150 | "horizontal": false 151 | } 152 | }, 153 | { 154 | "id": 125, 155 | "type": "Reroute", 156 | "pos": [ 157 | -249.00068450134245, 158 | -1267.0123651764297 159 | ], 160 | "size": [ 161 | 75, 162 | 26 163 | ], 164 | "flags": {}, 165 | "order": 5, 166 | "mode": 0, 167 | "inputs": [ 168 | { 169 | "name": "", 170 | "type": "*", 171 | "link": 250 172 | } 173 | ], 174 | "outputs": [ 175 | { 176 | "name": "", 177 | "type": "IMAGE", 178 | "links": [ 179 | 251 180 | ], 181 | "slot_index": 0 182 | } 183 | ], 184 | "properties": { 185 | "showOutputText": false, 186 | "horizontal": false 187 | } 188 | }, 189 | { 190 | "id": 120, 191 | "type": "Reroute", 192 | "pos": [ 193 | -1136, 194 | -1274 195 | ], 196 | "size": [ 197 | 75, 198 | 26 199 | ], 200 | "flags": {}, 201 | "order": 2, 202 | "mode": 0, 203 | "inputs": [ 204 | { 205 | "name": "", 206 | "type": "*", 207 | "link": 241 208 | } 209 | ], 210 | "outputs": [ 211 | { 212 | "name": "", 213 | "type": "IMAGE", 214 | "links": [ 215 | 250, 216 | 255 217 | ], 218 | "slot_index": 0 219 | } 220 | ], 221 | "properties": { 222 | "showOutputText": false, 223 | "horizontal": false 224 | } 225 | }, 226 | { 227 | "id": 122, 228 | "type": "Reroute", 229 | "pos": [ 230 | -983, 231 | -1230 232 | ], 233 | "size": [ 234 | 75, 235 | 26 236 | ], 237 | "flags": {}, 238 | "order": 7, 239 | "mode": 0, 240 | "inputs": [ 241 | { 242 | "name": "", 243 | "type": 
"*", 244 | "link": 247 245 | } 246 | ], 247 | "outputs": [ 248 | { 249 | "name": "", 250 | "type": "IMAGE", 251 | "links": [ 252 | 245 253 | ], 254 | "slot_index": 0 255 | } 256 | ], 257 | "properties": { 258 | "showOutputText": false, 259 | "horizontal": false 260 | } 261 | }, 262 | { 263 | "id": 128, 264 | "type": "Reroute", 265 | "pos": [ 266 | 714.9541711678453, 267 | -1267.1455356413667 268 | ], 269 | "size": [ 270 | 75, 271 | 26 272 | ], 273 | "flags": {}, 274 | "order": 6, 275 | "mode": 0, 276 | "inputs": [ 277 | { 278 | "name": "", 279 | "type": "*", 280 | "link": 255 281 | } 282 | ], 283 | "outputs": [ 284 | { 285 | "name": "", 286 | "type": "IMAGE", 287 | "links": [ 288 | 263 289 | ], 290 | "slot_index": 0 291 | } 292 | ], 293 | "properties": { 294 | "showOutputText": false, 295 | "horizontal": false 296 | } 297 | }, 298 | { 299 | "id": 127, 300 | "type": "Reroute", 301 | "pos": [ 302 | 720, 303 | -1222 304 | ], 305 | "size": [ 306 | 75, 307 | 26 308 | ], 309 | "flags": {}, 310 | "order": 16, 311 | "mode": 0, 312 | "inputs": [ 313 | { 314 | "name": "", 315 | "type": "*", 316 | "link": 271 317 | } 318 | ], 319 | "outputs": [ 320 | { 321 | "name": "", 322 | "type": "IMAGE", 323 | "links": [ 324 | 264 325 | ], 326 | "slot_index": 0 327 | } 328 | ], 329 | "properties": { 330 | "showOutputText": false, 331 | "horizontal": false 332 | } 333 | }, 334 | { 335 | "id": 111, 336 | "type": "LoadImage", 337 | "pos": [ 338 | -1526, 339 | -1174 340 | ], 341 | "size": { 342 | "0": 315, 343 | "1": 314 344 | }, 345 | "flags": {}, 346 | "order": 0, 347 | "mode": 0, 348 | "outputs": [ 349 | { 350 | "name": "IMAGE", 351 | "type": "IMAGE", 352 | "links": [ 353 | 241 354 | ], 355 | "shape": 3, 356 | "slot_index": 0 357 | }, 358 | { 359 | "name": "MASK", 360 | "type": "MASK", 361 | "links": null, 362 | "shape": 3 363 | } 364 | ], 365 | "properties": { 366 | "Node name for S&R": "LoadImage" 367 | }, 368 | "widgets_values": [ 369 | "test_12 (1).png", 370 | "image" 371 | ] 372 | }, 373 | { 374 | "id": 110, 375 | "type": "AnimateAnyone", 376 | "pos": [ 377 | -144, 378 | -1117 379 | ], 380 | "size": { 381 | "0": 315, 382 | "1": 174 383 | }, 384 | "flags": {}, 385 | "order": 12, 386 | "mode": 0, 387 | "inputs": [ 388 | { 389 | "name": "image", 390 | "type": "IMAGE", 391 | "link": 251 392 | }, 393 | { 394 | "name": "pose_images", 395 | "type": "IMAGE", 396 | "link": 249 397 | } 398 | ], 399 | "outputs": [ 400 | { 401 | "name": "IMAGE", 402 | "type": "IMAGE", 403 | "links": [ 404 | 231, 405 | 232, 406 | 252 407 | ], 408 | "shape": 3, 409 | "slot_index": 0 410 | } 411 | ], 412 | "properties": { 413 | "Node name for S&R": "AnimateAnyone" 414 | }, 415 | "widgets_values": [ 416 | 704, 417 | 896, 418 | 14, 419 | 25, 420 | 7 421 | ] 422 | }, 423 | { 424 | "id": 130, 425 | "type": "ReActorFaceSwap", 426 | "pos": [ 427 | 823, 428 | -1151 429 | ], 430 | "size": { 431 | "0": 315, 432 | "1": 338 433 | }, 434 | "flags": {}, 435 | "order": 17, 436 | "mode": 0, 437 | "inputs": [ 438 | { 439 | "name": "input_image", 440 | "type": "IMAGE", 441 | "link": 264 442 | }, 443 | { 444 | "name": "source_image", 445 | "type": "IMAGE", 446 | "link": 263 447 | }, 448 | { 449 | "name": "face_model", 450 | "type": "FACE_MODEL", 451 | "link": null 452 | } 453 | ], 454 | "outputs": [ 455 | { 456 | "name": "IMAGE", 457 | "type": "IMAGE", 458 | "links": [ 459 | 260, 460 | 269 461 | ], 462 | "shape": 3, 463 | "slot_index": 0 464 | }, 465 | { 466 | "name": "FACE_MODEL", 467 | "type": "FACE_MODEL", 468 | "links": null, 469 | "shape": 3 470 | } 
471 | ], 472 | "properties": { 473 | "Node name for S&R": "ReActorFaceSwap" 474 | }, 475 | "widgets_values": [ 476 | true, 477 | "inswapper_128.onnx", 478 | "retinaface_resnet50", 479 | "codeformer-v0.1.0.pth", 480 | 1, 481 | 0.5, 482 | "no", 483 | "no", 484 | "0", 485 | "0", 486 | 1 487 | ] 488 | }, 489 | { 490 | "id": 121, 491 | "type": "Reroute", 492 | "pos": [ 493 | -1143, 494 | -1231 495 | ], 496 | "size": [ 497 | 75, 498 | 26 499 | ], 500 | "flags": {}, 501 | "order": 3, 502 | "mode": 0, 503 | "inputs": [ 504 | { 505 | "name": "", 506 | "type": "*", 507 | "link": 242 508 | } 509 | ], 510 | "outputs": [ 511 | { 512 | "name": "", 513 | "type": "IMAGE", 514 | "links": [ 515 | 247 516 | ], 517 | "slot_index": 0 518 | } 519 | ], 520 | "properties": { 521 | "showOutputText": false, 522 | "horizontal": false 523 | } 524 | }, 525 | { 526 | "id": 126, 527 | "type": "Reroute", 528 | "pos": [ 529 | 511, 530 | -1222 531 | ], 532 | "size": [ 533 | 75, 534 | 26 535 | ], 536 | "flags": {}, 537 | "order": 15, 538 | "mode": 0, 539 | "inputs": [ 540 | { 541 | "name": "", 542 | "type": "*", 543 | "link": 252 544 | } 545 | ], 546 | "outputs": [ 547 | { 548 | "name": "", 549 | "type": "IMAGE", 550 | "links": [ 551 | 271 552 | ], 553 | "slot_index": 0 554 | } 555 | ], 556 | "properties": { 557 | "showOutputText": false, 558 | "horizontal": false 559 | } 560 | }, 561 | { 562 | "id": 27, 563 | "type": "PreviewImage", 564 | "pos": [ 565 | -1526, 566 | -483 567 | ], 568 | "size": { 569 | "0": 353.612060546875, 570 | "1": 479.5135498046875 571 | }, 572 | "flags": {}, 573 | "order": 4, 574 | "mode": 0, 575 | "inputs": [ 576 | { 577 | "name": "images", 578 | "type": "IMAGE", 579 | "link": 243, 580 | "label": "图像" 581 | } 582 | ], 583 | "properties": { 584 | "Node name for S&R": "PreviewImage" 585 | } 586 | }, 587 | { 588 | "id": 112, 589 | "type": "PreviewImage", 590 | "pos": [ 591 | 193, 592 | -881 593 | ], 594 | "size": { 595 | "0": 355.12384033203125, 596 | "1": 589.513427734375 597 | }, 598 | "flags": {}, 599 | "order": 13, 600 | "mode": 0, 601 | "inputs": [ 602 | { 603 | "name": "images", 604 | "type": "IMAGE", 605 | "link": 231 606 | } 607 | ], 608 | "properties": { 609 | "Node name for S&R": "PreviewImage" 610 | } 611 | }, 612 | { 613 | "id": 118, 614 | "type": "PreviewImage", 615 | "pos": [ 616 | 777, 617 | -720 618 | ], 619 | "size": { 620 | "0": 431.2261962890625, 621 | "1": 697.3739624023438 622 | }, 623 | "flags": {}, 624 | "order": 18, 625 | "mode": 0, 626 | "inputs": [ 627 | { 628 | "name": "images", 629 | "type": "IMAGE", 630 | "link": 260 631 | } 632 | ], 633 | "properties": { 634 | "Node name for S&R": "PreviewImage" 635 | } 636 | }, 637 | { 638 | "id": 5, 639 | "type": "VHS_LoadVideo", 640 | "pos": [ 641 | -1532, 642 | -797 643 | ], 644 | "size": [ 645 | 334.2194519042969, 646 | 242 647 | ], 648 | "flags": {}, 649 | "order": 1, 650 | "mode": 0, 651 | "inputs": [ 652 | { 653 | "name": "batch_manager", 654 | "type": "VHS_BatchManager", 655 | "link": null 656 | } 657 | ], 658 | "outputs": [ 659 | { 660 | "name": "IMAGE", 661 | "type": "IMAGE", 662 | "links": [ 663 | 242, 664 | 243 665 | ], 666 | "shape": 3, 667 | "slot_index": 0, 668 | "label": "图像" 669 | }, 670 | { 671 | "name": "frame_count", 672 | "type": "INT", 673 | "links": null, 674 | "shape": 3, 675 | "slot_index": 1, 676 | "label": "帧计数" 677 | }, 678 | { 679 | "name": "audio", 680 | "type": "VHS_AUDIO", 681 | "links": null, 682 | "shape": 3 683 | } 684 | ], 685 | "properties": { 686 | "Node name for S&R": "VHS_LoadVideo" 687 | }, 688 | 
"widgets_values": { 689 | "video": "dance.mp4", 690 | "force_rate": 20, 691 | "force_size": "512x?", 692 | "custom_width": 512, 693 | "custom_height": 512, 694 | "frame_load_cap": 256, 695 | "skip_first_frames": 0, 696 | "select_every_nth": 1, 697 | "choose video to upload": "image", 698 | "videopreview": { 699 | "hidden": false, 700 | "paused": true, 701 | "params": { 702 | "filename": "dance.mp4", 703 | "type": "input", 704 | "format": "video/mp4", 705 | "frame_load_cap": 256, 706 | "skip_first_frames": 0, 707 | "force_rate": 20, 708 | "select_every_nth": 1, 709 | "force_size": "512x?" 710 | } 711 | } 712 | } 713 | }, 714 | { 715 | "id": 113, 716 | "type": "VHS_VideoCombine", 717 | "pos": [ 718 | -138, 719 | -881 720 | ], 721 | "size": [ 722 | 315, 723 | 290 724 | ], 725 | "flags": {}, 726 | "order": 14, 727 | "mode": 0, 728 | "inputs": [ 729 | { 730 | "name": "images", 731 | "type": "IMAGE", 732 | "link": 232 733 | }, 734 | { 735 | "name": "audio", 736 | "type": "VHS_AUDIO", 737 | "link": null 738 | }, 739 | { 740 | "name": "batch_manager", 741 | "type": "VHS_BatchManager", 742 | "link": null 743 | } 744 | ], 745 | "outputs": [ 746 | { 747 | "name": "Filenames", 748 | "type": "VHS_FILENAMES", 749 | "links": null, 750 | "shape": 3 751 | } 752 | ], 753 | "properties": { 754 | "Node name for S&R": "VHS_VideoCombine" 755 | }, 756 | "widgets_values": { 757 | "frame_rate": 8, 758 | "loop_count": 0, 759 | "filename_prefix": "animate_anyone", 760 | "format": "video/h265-mp4", 761 | "pix_fmt": "yuv420p10le", 762 | "crf": 22, 763 | "save_metadata": true, 764 | "pingpong": false, 765 | "save_output": true, 766 | "videopreview": { 767 | "hidden": false, 768 | "paused": false, 769 | "params": { 770 | "filename": "animate_anyone_00014.mp4", 771 | "subfolder": "", 772 | "type": "output", 773 | "format": "video/h265-mp4" 774 | } 775 | } 776 | } 777 | }, 778 | { 779 | "id": 131, 780 | "type": "VHS_VideoCombine", 781 | "pos": [ 782 | 1297, 783 | -1169 784 | ], 785 | "size": [ 786 | 745.5550133571146, 787 | 848.5143508280603 788 | ], 789 | "flags": {}, 790 | "order": 19, 791 | "mode": 0, 792 | "inputs": [ 793 | { 794 | "name": "images", 795 | "type": "IMAGE", 796 | "link": 269 797 | }, 798 | { 799 | "name": "audio", 800 | "type": "VHS_AUDIO", 801 | "link": null 802 | }, 803 | { 804 | "name": "batch_manager", 805 | "type": "VHS_BatchManager", 806 | "link": null 807 | } 808 | ], 809 | "outputs": [ 810 | { 811 | "name": "Filenames", 812 | "type": "VHS_FILENAMES", 813 | "links": null, 814 | "shape": 3 815 | } 816 | ], 817 | "properties": { 818 | "Node name for S&R": "VHS_VideoCombine" 819 | }, 820 | "widgets_values": { 821 | "frame_rate": 20, 822 | "loop_count": 0, 823 | "filename_prefix": "animate_anyone", 824 | "format": "video/h265-mp4", 825 | "pix_fmt": "yuv420p10le", 826 | "crf": 22, 827 | "save_metadata": true, 828 | "pingpong": false, 829 | "save_output": true, 830 | "videopreview": { 831 | "hidden": false, 832 | "paused": false, 833 | "params": { 834 | "filename": "animate_anyone_00013.mp4", 835 | "subfolder": "", 836 | "type": "output", 837 | "format": "video/h265-mp4" 838 | } 839 | } 840 | } 841 | } 842 | ], 843 | "links": [ 844 | [ 845 | 231, 846 | 110, 847 | 0, 848 | 112, 849 | 0, 850 | "IMAGE" 851 | ], 852 | [ 853 | 232, 854 | 110, 855 | 0, 856 | 113, 857 | 0, 858 | "IMAGE" 859 | ], 860 | [ 861 | 235, 862 | 114, 863 | 0, 864 | 115, 865 | 0, 866 | "IMAGE" 867 | ], 868 | [ 869 | 241, 870 | 111, 871 | 0, 872 | 120, 873 | 0, 874 | "*" 875 | ], 876 | [ 877 | 242, 878 | 5, 879 | 0, 880 | 121, 881 | 0, 
882 | "*" 883 | ], 884 | [ 885 | 243, 886 | 5, 887 | 0, 888 | 27, 889 | 0, 890 | "IMAGE" 891 | ], 892 | [ 893 | 245, 894 | 122, 895 | 0, 896 | 114, 897 | 0, 898 | "IMAGE" 899 | ], 900 | [ 901 | 246, 902 | 114, 903 | 0, 904 | 123, 905 | 0, 906 | "*" 907 | ], 908 | [ 909 | 247, 910 | 121, 911 | 0, 912 | 122, 913 | 0, 914 | "*" 915 | ], 916 | [ 917 | 248, 918 | 123, 919 | 0, 920 | 124, 921 | 0, 922 | "*" 923 | ], 924 | [ 925 | 249, 926 | 124, 927 | 0, 928 | 110, 929 | 1, 930 | "IMAGE" 931 | ], 932 | [ 933 | 250, 934 | 120, 935 | 0, 936 | 125, 937 | 0, 938 | "*" 939 | ], 940 | [ 941 | 251, 942 | 125, 943 | 0, 944 | 110, 945 | 0, 946 | "IMAGE" 947 | ], 948 | [ 949 | 252, 950 | 110, 951 | 0, 952 | 126, 953 | 0, 954 | "*" 955 | ], 956 | [ 957 | 255, 958 | 120, 959 | 0, 960 | 128, 961 | 0, 962 | "*" 963 | ], 964 | [ 965 | 260, 966 | 130, 967 | 0, 968 | 118, 969 | 0, 970 | "IMAGE" 971 | ], 972 | [ 973 | 263, 974 | 128, 975 | 0, 976 | 130, 977 | 1, 978 | "IMAGE" 979 | ], 980 | [ 981 | 264, 982 | 127, 983 | 0, 984 | 130, 985 | 0, 986 | "IMAGE" 987 | ], 988 | [ 989 | 269, 990 | 130, 991 | 0, 992 | 131, 993 | 0, 994 | "IMAGE" 995 | ], 996 | [ 997 | 271, 998 | 126, 999 | 0, 1000 | 127, 1001 | 0, 1002 | "*" 1003 | ] 1004 | ], 1005 | "groups": [ 1006 | { 1007 | "title": "AnimateAnyone", 1008 | "bounding": [ 1009 | -255, 1010 | -1325, 1011 | 866, 1012 | 1013 1013 | ], 1014 | "color": "#8A8", 1015 | "font_size": 24 1016 | }, 1017 | { 1018 | "title": "Loader", 1019 | "bounding": [ 1020 | -1595, 1021 | -1338, 1022 | 533, 1023 | 948 1024 | ], 1025 | "color": "#A88", 1026 | "font_size": 24 1027 | }, 1028 | { 1029 | "title": "FaceSwap", 1030 | "bounding": [ 1031 | 709, 1032 | -1326, 1033 | 875, 1034 | 1021 1035 | ], 1036 | "color": "#3f789e", 1037 | "font_size": 24 1038 | }, 1039 | { 1040 | "title": "Pose", 1041 | "bounding": [ 1042 | -988, 1043 | -1332, 1044 | 648, 1045 | 981 1046 | ], 1047 | "color": "#8AA", 1048 | "font_size": 24 1049 | } 1050 | ], 1051 | "config": {}, 1052 | "extra": {}, 1053 | "version": 0.4 1054 | } -------------------------------------------------------------------------------- /assets/dance.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AuroBit/ComfyUI-AnimateAnyone-reproduction/4ff8e90391cf32bb1dc40f7e4888911fdf896e84/assets/dance.mp4 -------------------------------------------------------------------------------- /assets/show.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AuroBit/ComfyUI-AnimateAnyone-reproduction/4ff8e90391cf32bb1dc40f7e4888911fdf896e84/assets/show.mp4 -------------------------------------------------------------------------------- /assets/test_12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AuroBit/ComfyUI-AnimateAnyone-reproduction/4ff8e90391cf32bb1dc40f7e4888911fdf896e84/assets/test_12.png -------------------------------------------------------------------------------- /controlnet_sdv.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from dataclasses import dataclass 15 | from typing import Any, Dict, List, Optional, Tuple, Union 16 | 17 | import torch 18 | from torch import nn 19 | from torch.nn import functional as F 20 | 21 | from diffusers.configuration_utils import ConfigMixin, register_to_config 22 | from diffusers.loaders import FromOriginalControlnetMixin 23 | from diffusers.utils import BaseOutput, logging 24 | from diffusers.models.attention_processor import ( 25 | ADDED_KV_ATTENTION_PROCESSORS, 26 | CROSS_ATTENTION_PROCESSORS, 27 | AttentionProcessor, 28 | AttnAddedKVProcessor, 29 | AttnProcessor, 30 | ) 31 | from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps 32 | from diffusers.models.modeling_utils import ModelMixin 33 | from diffusers.models.unet_3d_blocks import ( 34 | get_down_block, get_up_block,UNetMidBlockSpatioTemporal, 35 | ) 36 | from diffusers.models import UNetSpatioTemporalConditionModel 37 | 38 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 39 | 40 | 41 | @dataclass 42 | class ControlNetOutput(BaseOutput): 43 | """ 44 | The output of [`ControlNetModel`]. 45 | 46 | Args: 47 | down_block_res_samples (`tuple[torch.Tensor]`): 48 | A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should 49 | be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be 50 | used to condition the original UNet's downsampling activations. 51 | mid_down_block_re_sample (`torch.Tensor`): 52 | The activation of the midde block (the lowest sample resolution). Each tensor should be of shape 53 | `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`. 54 | Output can be used to condition the original UNet's middle block activation. 55 | """ 56 | 57 | down_block_res_samples: Tuple[torch.Tensor] 58 | mid_block_res_sample: torch.Tensor 59 | 60 | 61 | class ControlNetConditioningEmbeddingSVD(nn.Module): 62 | """ 63 | Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN 64 | [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized 65 | training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the 66 | convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides 67 | (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full 68 | model) to encode image-space conditions ... into feature maps ..." 69 | """ 70 | 71 | def __init__( 72 | self, 73 | conditioning_embedding_channels: int, 74 | conditioning_channels: int = 3, 75 | block_out_channels: Tuple[int, ...] 
= (16, 32, 96, 256), 76 | ): 77 | super().__init__() 78 | 79 | 80 | 81 | self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) 82 | 83 | self.blocks = nn.ModuleList([]) 84 | 85 | for i in range(len(block_out_channels) - 1): 86 | channel_in = block_out_channels[i] 87 | channel_out = block_out_channels[i + 1] 88 | self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) 89 | self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) 90 | 91 | self.conv_out = zero_module( 92 | nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) 93 | ) 94 | 95 | def forward(self, conditioning): 96 | #this seeems appropriate? idk if i should be applying a more complex setup to handle the frames 97 | #combine batch and frames dimensions 98 | batch_size, frames, channels, height, width = conditioning.size() 99 | conditioning = conditioning.view(batch_size * frames, channels, height, width) 100 | 101 | embedding = self.conv_in(conditioning) 102 | embedding = F.silu(embedding) 103 | 104 | for block in self.blocks: 105 | embedding = block(embedding) 106 | embedding = F.silu(embedding) 107 | 108 | embedding = self.conv_out(embedding) 109 | 110 | #split them apart again 111 | #actually not needed 112 | #new_channels, new_height, new_width = embedding.shape[1], embedding.shape[2], embedding.shape[3] 113 | #embedding = embedding.view(batch_size, frames, new_channels, new_height, new_width) 114 | 115 | 116 | return embedding 117 | 118 | 119 | class ControlNetSDVModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): 120 | r""" 121 | A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, and a timestep and returns a sample 122 | shaped output. 123 | 124 | This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented 125 | for all models (such as downloading or saving). 126 | 127 | Parameters: 128 | sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): 129 | Height and width of input/output sample. 130 | in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample. 131 | out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. 132 | down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`): 133 | The tuple of downsample blocks to use. 134 | up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`): 135 | The tuple of upsample blocks to use. 136 | block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): 137 | The tuple of output channels for each block. 138 | addition_time_embed_dim: (`int`, defaults to 256): 139 | Dimension to to encode the additional time ids. 140 | projection_class_embeddings_input_dim (`int`, defaults to 768): 141 | The dimension of the projection of encoded `added_time_ids`. 142 | layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. 143 | cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): 144 | The dimension of the cross attention features. 
145 | transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): 146 | The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for 147 | [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`], 148 | [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`]. 149 | num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`): 150 | The number of attention heads. 151 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 152 | """ 153 | 154 | _supports_gradient_checkpointing = True 155 | 156 | @register_to_config 157 | def __init__( 158 | self, 159 | sample_size: Optional[int] = None, 160 | in_channels: int = 8, 161 | out_channels: int = 4, 162 | down_block_types: Tuple[str] = ( 163 | "CrossAttnDownBlockSpatioTemporal", 164 | "CrossAttnDownBlockSpatioTemporal", 165 | "CrossAttnDownBlockSpatioTemporal", 166 | "DownBlockSpatioTemporal", 167 | ), 168 | up_block_types: Tuple[str] = ( 169 | "UpBlockSpatioTemporal", 170 | "CrossAttnUpBlockSpatioTemporal", 171 | "CrossAttnUpBlockSpatioTemporal", 172 | "CrossAttnUpBlockSpatioTemporal", 173 | ), 174 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280), 175 | addition_time_embed_dim: int = 256, 176 | projection_class_embeddings_input_dim: int = 768, 177 | layers_per_block: Union[int, Tuple[int]] = 2, 178 | cross_attention_dim: Union[int, Tuple[int]] = 1024, 179 | transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, 180 | num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20), 181 | num_frames: int = 25, 182 | conditioning_channels: int = 3, 183 | conditioning_embedding_out_channels : Optional[Tuple[int, ...]] = (16, 32, 96, 256), 184 | ): 185 | super().__init__() 186 | self.sample_size = sample_size 187 | 188 | print("layers per block is", layers_per_block) 189 | 190 | # Check inputs 191 | if len(down_block_types) != len(up_block_types): 192 | raise ValueError( 193 | f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." 194 | ) 195 | 196 | if len(block_out_channels) != len(down_block_types): 197 | raise ValueError( 198 | f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." 199 | ) 200 | 201 | if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): 202 | raise ValueError( 203 | f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 204 | ) 205 | 206 | if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): 207 | raise ValueError( 208 | f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." 209 | ) 210 | 211 | if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): 212 | raise ValueError( 213 | f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." 
214 | ) 215 | 216 | # input 217 | self.conv_in = nn.Conv2d( 218 | in_channels, 219 | block_out_channels[0], 220 | kernel_size=3, 221 | padding=1, 222 | ) 223 | 224 | # time 225 | time_embed_dim = block_out_channels[0] * 4 226 | 227 | self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0) 228 | timestep_input_dim = block_out_channels[0] 229 | 230 | self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 231 | 232 | self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0) 233 | self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) 234 | 235 | self.down_blocks = nn.ModuleList([]) 236 | self.controlnet_down_blocks = nn.ModuleList([]) 237 | 238 | if isinstance(num_attention_heads, int): 239 | num_attention_heads = (num_attention_heads,) * len(down_block_types) 240 | 241 | if isinstance(cross_attention_dim, int): 242 | cross_attention_dim = (cross_attention_dim,) * len(down_block_types) 243 | 244 | if isinstance(layers_per_block, int): 245 | layers_per_block = [layers_per_block] * len(down_block_types) 246 | 247 | if isinstance(transformer_layers_per_block, int): 248 | transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) 249 | 250 | blocks_time_embed_dim = time_embed_dim 251 | self.controlnet_cond_embedding = ControlNetConditioningEmbeddingSVD( 252 | conditioning_embedding_channels=block_out_channels[0], 253 | block_out_channels=conditioning_embedding_out_channels, 254 | conditioning_channels=conditioning_channels, 255 | ) 256 | 257 | # down 258 | output_channel = block_out_channels[0] 259 | controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) 260 | controlnet_block = zero_module(controlnet_block) 261 | self.controlnet_down_blocks.append(controlnet_block) 262 | 263 | 264 | 265 | for i, down_block_type in enumerate(down_block_types): 266 | input_channel = output_channel 267 | output_channel = block_out_channels[i] 268 | is_final_block = i == len(block_out_channels) - 1 269 | 270 | down_block = get_down_block( 271 | down_block_type, 272 | num_layers=layers_per_block[i], 273 | transformer_layers_per_block=transformer_layers_per_block[i], 274 | in_channels=input_channel, 275 | out_channels=output_channel, 276 | temb_channels=blocks_time_embed_dim, 277 | add_downsample=not is_final_block, 278 | resnet_eps=1e-5, 279 | cross_attention_dim=cross_attention_dim[i], 280 | num_attention_heads=num_attention_heads[i], 281 | resnet_act_fn="silu", 282 | ) 283 | self.down_blocks.append(down_block) 284 | 285 | for _ in range(layers_per_block[i]): 286 | controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) 287 | controlnet_block = zero_module(controlnet_block) 288 | self.controlnet_down_blocks.append(controlnet_block) 289 | 290 | if not is_final_block: 291 | controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) 292 | controlnet_block = zero_module(controlnet_block) 293 | self.controlnet_down_blocks.append(controlnet_block) 294 | 295 | 296 | # mid 297 | mid_block_channel = block_out_channels[-1] 298 | controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1) 299 | controlnet_block = zero_module(controlnet_block) 300 | self.controlnet_mid_block = controlnet_block 301 | 302 | 303 | self.mid_block = UNetMidBlockSpatioTemporal( 304 | block_out_channels[-1], 305 | temb_channels=blocks_time_embed_dim, 306 | transformer_layers_per_block=transformer_layers_per_block[-1], 307 | 
cross_attention_dim=cross_attention_dim[-1], 308 | num_attention_heads=num_attention_heads[-1], 309 | ) 310 | 311 | 312 | 313 | 314 | # out 315 | #self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5) 316 | #self.conv_act = nn.SiLU() 317 | 318 | #self.conv_out = nn.Conv2d( 319 | # block_out_channels[0], 320 | # out_channels, 321 | # kernel_size=3, 322 | # padding=1, 323 | #) 324 | 325 | @property 326 | def attn_processors(self) -> Dict[str, AttentionProcessor]: 327 | r""" 328 | Returns: 329 | `dict` of attention processors: A dictionary containing all attention processors used in the model with 330 | indexed by its weight name. 331 | """ 332 | # set recursively 333 | processors = {} 334 | 335 | def fn_recursive_add_processors( 336 | name: str, 337 | module: torch.nn.Module, 338 | processors: Dict[str, AttentionProcessor], 339 | ): 340 | if hasattr(module, "get_processor"): 341 | processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) 342 | 343 | for sub_name, child in module.named_children(): 344 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 345 | 346 | return processors 347 | 348 | for name, module in self.named_children(): 349 | fn_recursive_add_processors(name, module, processors) 350 | 351 | return processors 352 | 353 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): 354 | r""" 355 | Sets the attention processor to use to compute attention. 356 | 357 | Parameters: 358 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 359 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 360 | for **all** `Attention` layers. 361 | 362 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 363 | processor. This is strongly recommended when setting trainable attention processors. 364 | 365 | """ 366 | count = len(self.attn_processors.keys()) 367 | 368 | if isinstance(processor, dict) and len(processor) != count: 369 | raise ValueError( 370 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 371 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 372 | ) 373 | 374 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 375 | if hasattr(module, "set_processor"): 376 | if not isinstance(processor, dict): 377 | module.set_processor(processor) 378 | else: 379 | module.set_processor(processor.pop(f"{name}.processor")) 380 | 381 | for sub_name, child in module.named_children(): 382 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 383 | 384 | for name, module in self.named_children(): 385 | fn_recursive_attn_processor(name, module, processor) 386 | 387 | def set_default_attn_processor(self): 388 | """ 389 | Disables custom attention processors and sets the default attention implementation. 
390 | """ 391 | if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): 392 | processor = AttnProcessor() 393 | else: 394 | raise ValueError( 395 | f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" 396 | ) 397 | 398 | self.set_attn_processor(processor) 399 | 400 | def _set_gradient_checkpointing(self, module, value=False): 401 | if hasattr(module, "gradient_checkpointing"): 402 | module.gradient_checkpointing = value 403 | 404 | # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking 405 | def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: 406 | """ 407 | Sets the attention processor to use [feed forward 408 | chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). 409 | 410 | Parameters: 411 | chunk_size (`int`, *optional*): 412 | The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually 413 | over each tensor of dim=`dim`. 414 | dim (`int`, *optional*, defaults to `0`): 415 | The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) 416 | or dim=1 (sequence length). 417 | """ 418 | if dim not in [0, 1]: 419 | raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") 420 | 421 | # By default chunk size is 1 422 | chunk_size = chunk_size or 1 423 | 424 | def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): 425 | if hasattr(module, "set_chunk_feed_forward"): 426 | module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) 427 | 428 | for child in module.children(): 429 | fn_recursive_feed_forward(child, chunk_size, dim) 430 | 431 | for module in self.children(): 432 | fn_recursive_feed_forward(module, chunk_size, dim) 433 | 434 | def forward( 435 | self, 436 | sample: torch.FloatTensor, 437 | timestep: Union[torch.Tensor, float, int], 438 | encoder_hidden_states: torch.Tensor, 439 | added_time_ids: torch.Tensor, 440 | controlnet_cond: torch.FloatTensor = None, 441 | image_only_indicator: Optional[torch.Tensor] = None, 442 | return_dict: bool = True, 443 | guess_mode: bool = False, 444 | conditioning_scale: float = 1.0, 445 | 446 | 447 | ) -> Union[ControlNetOutput, Tuple]: 448 | r""" 449 | The [`UNetSpatioTemporalConditionModel`] forward method. 450 | 451 | Args: 452 | sample (`torch.FloatTensor`): 453 | The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`. 454 | timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. 455 | encoder_hidden_states (`torch.FloatTensor`): 456 | The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`. 457 | added_time_ids: (`torch.FloatTensor`): 458 | The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal 459 | embeddings and added to the time embeddings. 460 | return_dict (`bool`, *optional*, defaults to `True`): 461 | Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain 462 | tuple. 463 | Returns: 464 | [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`: 465 | If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise 466 | a `tuple` is returned where the first element is the sample tensor. 467 | """ 468 | # 1. 
time 469 | timesteps = timestep 470 | if not torch.is_tensor(timesteps): 471 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 472 | # This would be a good case for the `match` statement (Python 3.10+) 473 | is_mps = sample.device.type == "mps" 474 | if isinstance(timestep, float): 475 | dtype = torch.float32 if is_mps else torch.float64 476 | else: 477 | dtype = torch.int32 if is_mps else torch.int64 478 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 479 | elif len(timesteps.shape) == 0: 480 | timesteps = timesteps[None].to(sample.device) 481 | 482 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 483 | batch_size, num_frames = sample.shape[:2] 484 | timesteps = timesteps.expand(batch_size) 485 | 486 | t_emb = self.time_proj(timesteps) 487 | 488 | # `Timesteps` does not contain any weights and will always return f32 tensors 489 | # but time_embedding might actually be running in fp16. so we need to cast here. 490 | # there might be better ways to encapsulate this. 491 | t_emb = t_emb.to(dtype=sample.dtype) 492 | 493 | emb = self.time_embedding(t_emb) 494 | 495 | time_embeds = self.add_time_proj(added_time_ids.flatten()) 496 | time_embeds = time_embeds.reshape((batch_size, -1)) 497 | time_embeds = time_embeds.to(emb.dtype) 498 | aug_emb = self.add_embedding(time_embeds) 499 | emb = emb + aug_emb 500 | 501 | # Flatten the batch and frames dimensions 502 | # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width] 503 | sample = sample.flatten(0, 1) 504 | # Repeat the embeddings num_video_frames times 505 | # emb: [batch, channels] -> [batch * frames, channels] 506 | emb = emb.repeat_interleave(num_frames, dim=0) 507 | # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels] 508 | encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0) 509 | 510 | # 2. pre-process 511 | sample = self.conv_in(sample) 512 | 513 | #controlnet cond 514 | if controlnet_cond != None: 515 | controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) 516 | sample = sample + controlnet_cond 517 | 518 | 519 | image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device) 520 | 521 | down_block_res_samples = (sample,) 522 | for downsample_block in self.down_blocks: 523 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 524 | sample, res_samples = downsample_block( 525 | hidden_states=sample, 526 | temb=emb, 527 | encoder_hidden_states=encoder_hidden_states, 528 | image_only_indicator=image_only_indicator, 529 | ) 530 | else: 531 | sample, res_samples = downsample_block( 532 | hidden_states=sample, 533 | temb=emb, 534 | image_only_indicator=image_only_indicator, 535 | ) 536 | 537 | down_block_res_samples += res_samples 538 | 539 | # 4. 
mid 540 | sample = self.mid_block( 541 | hidden_states=sample, 542 | temb=emb, 543 | encoder_hidden_states=encoder_hidden_states, 544 | image_only_indicator=image_only_indicator, 545 | ) 546 | 547 | controlnet_down_block_res_samples = () 548 | 549 | for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): 550 | down_block_res_sample = controlnet_block(down_block_res_sample) 551 | controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) 552 | 553 | down_block_res_samples = controlnet_down_block_res_samples 554 | 555 | mid_block_res_sample = self.controlnet_mid_block(sample) 556 | 557 | # 6. scaling 558 | 559 | down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] 560 | mid_block_res_sample = mid_block_res_sample * conditioning_scale 561 | 562 | if not return_dict: 563 | return (down_block_res_samples, mid_block_res_sample) 564 | 565 | return ControlNetOutput( 566 | down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample 567 | ) 568 | 569 | 570 | @classmethod 571 | def from_unet( 572 | cls, 573 | unet: UNetSpatioTemporalConditionModel, 574 | controlnet_conditioning_channel_order: str = "rgb", 575 | conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), 576 | load_weights_from_unet: bool = True, 577 | conditioning_channels: int = 3, 578 | ): 579 | r""" 580 | Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`]. 581 | 582 | Parameters: 583 | unet (`UNet2DConditionModel`): 584 | The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied 585 | where applicable. 586 | """ 587 | 588 | transformer_layers_per_block = ( 589 | unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1 590 | ) 591 | encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None 592 | encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None 593 | addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None 594 | addition_time_embed_dim = ( 595 | unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None 596 | ) 597 | print(unet.config) 598 | controlnet = cls( 599 | in_channels=unet.config.in_channels, 600 | down_block_types=unet.config.down_block_types, 601 | block_out_channels=unet.config.block_out_channels, 602 | addition_time_embed_dim=unet.config.addition_time_embed_dim, 603 | transformer_layers_per_block=unet.config.transformer_layers_per_block, 604 | cross_attention_dim=unet.config.cross_attention_dim, 605 | num_attention_heads=unet.config.num_attention_heads, 606 | num_frames=unet.config.num_frames, 607 | sample_size=unet.config.sample_size, # Added based on the dict 608 | layers_per_block=unet.config.layers_per_block, 609 | projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, 610 | conditioning_channels = conditioning_channels, 611 | conditioning_embedding_out_channels = conditioning_embedding_out_channels, 612 | ) 613 | #controlnet rgb channel order ignored, set to not makea difference by default 614 | 615 | if load_weights_from_unet: 616 | controlnet.conv_in.load_state_dict(unet.conv_in.state_dict()) 617 | controlnet.time_proj.load_state_dict(unet.time_proj.state_dict()) 618 | controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict()) 
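            # Note (descriptive comment, added for clarity): `from_unet` copies the
            # SVD UNet's conv_in, time embeddings, down_blocks and mid_block into the
            # ControlNet (the remaining copies follow below), while the 1x1
            # `controlnet_down_blocks` / `controlnet_mid_block` projections created in
            # __init__ are zero-initialized via `zero_module`. A freshly converted
            # ControlNet therefore starts out contributing nothing to the UNet's
            # residuals, which is the standard ControlNet initialization.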
619 | 620 | # if controlnet.class_embedding: 621 | # controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict()) 622 | 623 | controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict()) 624 | controlnet.mid_block.load_state_dict(unet.mid_block.state_dict()) 625 | 626 | return controlnet 627 | 628 | @property 629 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors 630 | def attn_processors(self) -> Dict[str, AttentionProcessor]: 631 | r""" 632 | Returns: 633 | `dict` of attention processors: A dictionary containing all attention processors used in the model with 634 | indexed by its weight name. 635 | """ 636 | # set recursively 637 | processors = {} 638 | 639 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): 640 | if hasattr(module, "get_processor"): 641 | processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) 642 | 643 | for sub_name, child in module.named_children(): 644 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 645 | 646 | return processors 647 | 648 | for name, module in self.named_children(): 649 | fn_recursive_add_processors(name, module, processors) 650 | 651 | return processors 652 | 653 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor 654 | def set_attn_processor( 655 | self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False 656 | ): 657 | r""" 658 | Sets the attention processor to use to compute attention. 659 | 660 | Parameters: 661 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 662 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 663 | for **all** `Attention` layers. 664 | 665 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 666 | processor. This is strongly recommended when setting trainable attention processors. 667 | 668 | """ 669 | count = len(self.attn_processors.keys()) 670 | 671 | if isinstance(processor, dict) and len(processor) != count: 672 | raise ValueError( 673 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 674 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 675 | ) 676 | 677 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 678 | if hasattr(module, "set_processor"): 679 | if not isinstance(processor, dict): 680 | module.set_processor(processor, _remove_lora=_remove_lora) 681 | else: 682 | module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) 683 | 684 | for sub_name, child in module.named_children(): 685 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 686 | 687 | for name, module in self.named_children(): 688 | fn_recursive_attn_processor(name, module, processor) 689 | 690 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor 691 | def set_default_attn_processor(self): 692 | """ 693 | Disables custom attention processors and sets the default attention implementation. 
694 | """ 695 | if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): 696 | processor = AttnAddedKVProcessor() 697 | elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): 698 | processor = AttnProcessor() 699 | else: 700 | raise ValueError( 701 | f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" 702 | ) 703 | 704 | self.set_attn_processor(processor, _remove_lora=True) 705 | 706 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice 707 | def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None: 708 | r""" 709 | Enable sliced attention computation. 710 | 711 | When this option is enabled, the attention module splits the input tensor in slices to compute attention in 712 | several steps. This is useful for saving some memory in exchange for a small decrease in speed. 713 | 714 | Args: 715 | slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): 716 | When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If 717 | `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is 718 | provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` 719 | must be a multiple of `slice_size`. 720 | """ 721 | sliceable_head_dims = [] 722 | 723 | def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): 724 | if hasattr(module, "set_attention_slice"): 725 | sliceable_head_dims.append(module.sliceable_head_dim) 726 | 727 | for child in module.children(): 728 | fn_recursive_retrieve_sliceable_dims(child) 729 | 730 | # retrieve number of attention layers 731 | for module in self.children(): 732 | fn_recursive_retrieve_sliceable_dims(module) 733 | 734 | num_sliceable_layers = len(sliceable_head_dims) 735 | 736 | if slice_size == "auto": 737 | # half the attention head size is usually a good trade-off between 738 | # speed and memory 739 | slice_size = [dim // 2 for dim in sliceable_head_dims] 740 | elif slice_size == "max": 741 | # make smallest slice possible 742 | slice_size = num_sliceable_layers * [1] 743 | 744 | slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size 745 | 746 | if len(slice_size) != len(sliceable_head_dims): 747 | raise ValueError( 748 | f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" 749 | f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." 750 | ) 751 | 752 | for i in range(len(slice_size)): 753 | size = slice_size[i] 754 | dim = sliceable_head_dims[i] 755 | if size is not None and size > dim: 756 | raise ValueError(f"size {size} has to be smaller or equal to {dim}.") 757 | 758 | # Recursively walk through all the children. 
759 | # Any children which exposes the set_attention_slice method 760 | # gets the message 761 | def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): 762 | if hasattr(module, "set_attention_slice"): 763 | module.set_attention_slice(slice_size.pop()) 764 | 765 | for child in module.children(): 766 | fn_recursive_set_attention_slice(child, slice_size) 767 | 768 | reversed_slice_size = list(reversed(slice_size)) 769 | for module in self.children(): 770 | fn_recursive_set_attention_slice(module, reversed_slice_size) 771 | 772 | # def _set_gradient_checkpointing(self, module, value: bool = False) -> None: 773 | # if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): 774 | # module.gradient_checkpointing = value 775 | 776 | 777 | def zero_module(module): 778 | for p in module.parameters(): 779 | nn.init.zeros_(p) 780 | return module 781 | -------------------------------------------------------------------------------- /pipeline_stable_video_diffusion_controlnet_long.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import inspect 16 | from dataclasses import dataclass 17 | from typing import Callable, Dict, List, Optional, Union 18 | 19 | import numpy as np 20 | import PIL.Image 21 | import torch 22 | from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection 23 | from .controlnet_sdv import ControlNetSDVModel 24 | 25 | from diffusers.image_processor import VaeImageProcessor 26 | from diffusers.models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel 27 | from diffusers.utils import BaseOutput, logging 28 | from diffusers.utils.torch_utils import randn_tensor 29 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 30 | from .unet_spatio_temporal_condition_controlnet import UNetSpatioTemporalConditionControlNetModel 31 | from diffusers.schedulers import EulerDiscreteScheduler 32 | #from diffusers.pipelines.utils import PIL_INTERPOLATION, BaseOutput, logging 33 | 34 | 35 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 36 | 37 | def _get_add_time_ids( 38 | noise_aug_strength, 39 | dtype, 40 | batch_size, 41 | fps=4, 42 | motion_bucket_id=128, 43 | unet=None, 44 | ): 45 | add_time_ids = [fps, motion_bucket_id, noise_aug_strength] 46 | 47 | passed_add_embed_dim = unet.config.addition_time_embed_dim * len(add_time_ids) 48 | expected_add_embed_dim = unet.add_embedding.linear_1.in_features 49 | 50 | if expected_add_embed_dim != passed_add_embed_dim: 51 | raise ValueError( 52 | f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
53 | ) 54 | 55 | add_time_ids = torch.tensor([add_time_ids], dtype=dtype) 56 | # add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1) 57 | 58 | 59 | return add_time_ids 60 | 61 | 62 | def _append_dims(x, target_dims): 63 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" 64 | dims_to_append = target_dims - x.ndim 65 | if dims_to_append < 0: 66 | raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") 67 | return x[(...,) + (None,) * dims_to_append] 68 | 69 | 70 | def tensor2vid(video: torch.Tensor, processor, output_type="np"): 71 | # Based on: 72 | # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 73 | 74 | batch_size, channels, num_frames, height, width = video.shape 75 | outputs = [] 76 | for batch_idx in range(batch_size): 77 | batch_vid = video[batch_idx].permute(1, 0, 2, 3) 78 | batch_output = processor.postprocess(batch_vid, output_type) 79 | 80 | outputs.append(batch_output) 81 | 82 | return outputs 83 | 84 | 85 | @dataclass 86 | class StableVideoDiffusionPipelineOutput(BaseOutput): 87 | r""" 88 | Output class for zero-shot text-to-video pipeline. 89 | 90 | Args: 91 | frames (`[List[PIL.Image.Image]`, `np.ndarray`]): 92 | List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, 93 | num_channels)`. 94 | """ 95 | 96 | frames: Union[List[PIL.Image.Image], np.ndarray] 97 | 98 | 99 | class StableVideoDiffusionPipelineControlNet(DiffusionPipeline): 100 | r""" 101 | Pipeline to generate video from an input image using Stable Video Diffusion. 102 | 103 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods 104 | implemented for all pipelines (downloading, saving, running on a particular device, etc.). 105 | 106 | Args: 107 | vae ([`AutoencoderKL`]): 108 | Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. 109 | image_encoder ([`~transformers.CLIPVisionModelWithProjection`]): 110 | Frozen CLIP image-encoder ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)). 111 | unet ([`UNetSpatioTemporalConditionModel`]): 112 | A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents. 113 | scheduler ([`EulerDiscreteScheduler`]): 114 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. 115 | feature_extractor ([`~transformers.CLIPImageProcessor`]): 116 | A `CLIPImageProcessor` to extract features from generated images. 
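        controlnet ([`ControlNetSDVModel`]):
            A ControlNet conditioned on the per-frame control images (pose frames in this repository);
            its down-block and mid-block residuals are added to the UNet features during denoising.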
117 | """ 118 | 119 | model_cpu_offload_seq = "image_encoder->unet->vae" 120 | _callback_tensor_inputs = ["latents"] 121 | 122 | def __init__( 123 | self, 124 | vae: AutoencoderKLTemporalDecoder, 125 | image_encoder: CLIPVisionModelWithProjection, 126 | unet: UNetSpatioTemporalConditionControlNetModel, 127 | controlnet: ControlNetSDVModel, 128 | scheduler: EulerDiscreteScheduler, 129 | feature_extractor: CLIPImageProcessor, 130 | ): 131 | super().__init__() 132 | 133 | self.register_modules( 134 | vae=vae, 135 | image_encoder=image_encoder, 136 | controlnet=controlnet, 137 | unet=unet, 138 | scheduler=scheduler, 139 | feature_extractor=feature_extractor, 140 | ) 141 | 142 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 143 | self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) 144 | 145 | def _encode_image(self, image, device, num_videos_per_prompt, do_classifier_free_guidance): 146 | dtype = next(self.image_encoder.parameters()).dtype 147 | 148 | if not isinstance(image, torch.Tensor): 149 | image = self.image_processor.pil_to_numpy(image) 150 | image = self.image_processor.numpy_to_pt(image) 151 | 152 | #image = image.unsqueeze(0) 153 | image = _resize_with_antialiasing(image, (224, 224)) 154 | 155 | image = image.to(device=device, dtype=dtype) 156 | image_embeddings = self.image_encoder(image).image_embeds 157 | image_embeddings = image_embeddings.unsqueeze(1) 158 | 159 | # duplicate image embeddings for each generation per prompt, using mps friendly method 160 | bs_embed, seq_len, _ = image_embeddings.shape 161 | image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1) 162 | image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) 163 | 164 | if do_classifier_free_guidance: 165 | negative_image_embeddings = torch.zeros_like(image_embeddings) 166 | 167 | # For classifier free guidance, we need to do two forward passes. 168 | # Here we concatenate the unconditional and text embeddings into a single batch 169 | # to avoid doing two forward passes 170 | image_embeddings = torch.cat([negative_image_embeddings, image_embeddings]) 171 | 172 | return image_embeddings 173 | 174 | def _encode_vae_image( 175 | self, 176 | image: torch.Tensor, 177 | device, 178 | num_videos_per_prompt, 179 | do_classifier_free_guidance, 180 | ): 181 | image = image.to(device=device) 182 | image_latents = self.vae.encode(image).latent_dist.mode() 183 | 184 | if do_classifier_free_guidance: 185 | negative_image_latents = torch.zeros_like(image_latents) 186 | 187 | # For classifier free guidance, we need to do two forward passes. 
188 | # Here we concatenate the unconditional and text embeddings into a single batch 189 | # to avoid doing two forward passes 190 | image_latents = torch.cat([negative_image_latents, image_latents]) 191 | 192 | # duplicate image_latents for each generation per prompt, using mps friendly method 193 | image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1) 194 | 195 | return image_latents 196 | 197 | def _get_add_time_ids( 198 | self, 199 | fps, 200 | motion_bucket_id, 201 | noise_aug_strength, 202 | dtype, 203 | batch_size, 204 | num_videos_per_prompt, 205 | do_classifier_free_guidance, 206 | ): 207 | add_time_ids = [fps, motion_bucket_id, noise_aug_strength] 208 | 209 | passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids) 210 | expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features 211 | 212 | if expected_add_embed_dim != passed_add_embed_dim: 213 | raise ValueError( 214 | f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 215 | ) 216 | 217 | add_time_ids = torch.tensor([add_time_ids], dtype=dtype) 218 | add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1) 219 | 220 | if do_classifier_free_guidance: 221 | add_time_ids = torch.cat([add_time_ids, add_time_ids]) 222 | 223 | return add_time_ids 224 | 225 | def decode_latents(self, latents, num_frames, decode_chunk_size=14): 226 | # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width] 227 | latents = latents.flatten(0, 1) 228 | 229 | latents = 1 / self.vae.config.scaling_factor * latents 230 | 231 | accepts_num_frames = "num_frames" in set(inspect.signature(self.vae.forward).parameters.keys()) 232 | 233 | # decode decode_chunk_size frames at a time to avoid OOM 234 | frames = [] 235 | for i in range(0, latents.shape[0], decode_chunk_size): 236 | num_frames_in = latents[i : i + decode_chunk_size].shape[0] 237 | decode_kwargs = {} 238 | #if accepts_num_frames: 239 | # # we only pass num_frames_in if it's expected 240 | # decode_kwargs["num_frames"] = num_frames_in 241 | decode_kwargs["num_frames"] = num_frames_in 242 | frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample 243 | frames.append(frame) 244 | frames = torch.cat(frames, dim=0) 245 | 246 | # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width] 247 | frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4) 248 | 249 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 250 | frames = frames.float() 251 | return frames 252 | 253 | def check_inputs(self, image, height, width): 254 | if ( 255 | not isinstance(image, torch.Tensor) 256 | and not isinstance(image, PIL.Image.Image) 257 | and not isinstance(image, list) 258 | ): 259 | raise ValueError( 260 | "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" 261 | f" {type(image)}" 262 | ) 263 | 264 | if height % 8 != 0 or width % 8 != 0: 265 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 266 | 267 | def prepare_latents( 268 | self, 269 | batch_size, 270 | num_frames, 271 | num_channels_latents, 272 | height, 273 | width, 274 | dtype, 275 | device, 276 | generator, 277 | latents=None, 278 
| ): 279 | shape = ( 280 | batch_size, 281 | num_frames, 282 | num_channels_latents // 2, 283 | height // self.vae_scale_factor, 284 | width // self.vae_scale_factor, 285 | ) 286 | if isinstance(generator, list) and len(generator) != batch_size: 287 | raise ValueError( 288 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 289 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 290 | ) 291 | 292 | if latents is None: 293 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 294 | else: 295 | latents = latents.to(device) 296 | 297 | # scale the initial noise by the standard deviation required by the scheduler 298 | latents = latents * self.scheduler.init_noise_sigma 299 | return latents 300 | 301 | @property 302 | def guidance_scale(self): 303 | return self._guidance_scale 304 | 305 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 306 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 307 | # corresponds to doing no classifier free guidance. 308 | @property 309 | def do_classifier_free_guidance(self): 310 | return self._guidance_scale >= 1 and self.unet.config.time_cond_proj_dim is None 311 | 312 | @property 313 | def num_timesteps(self): 314 | return self._num_timesteps 315 | 316 | @torch.no_grad() 317 | def __call__( 318 | self, 319 | image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor], 320 | controlnet_condition:[torch.FloatTensor] = None, 321 | height: int = 576, 322 | width: int = 1024, 323 | num_frames: Optional[int] = None, 324 | num_inference_steps: int = 25, 325 | min_guidance_scale: float = 1.0, 326 | max_guidance_scale: float = 3.0, 327 | fps: int = 7, 328 | motion_bucket_id: int = 127, 329 | noise_aug_strength: int = 0.02, 330 | decode_chunk_size: Optional[int] = None, 331 | num_videos_per_prompt: Optional[int] = 1, 332 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 333 | latents: Optional[torch.FloatTensor] = None, 334 | output_type: Optional[str] = "pil", 335 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, 336 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 337 | return_dict: bool = True, 338 | controlnet_cond_scale=1.0, 339 | batch_size=1, 340 | overlap=5, 341 | frames_per_batch = 14, 342 | ): 343 | r""" 344 | The call function to the pipeline for generation. 345 | 346 | Args: 347 | image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): 348 | Image or images to guide image generation. If you provide a tensor, it needs to be compatible with 349 | [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json). 350 | height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): 351 | The height in pixels of the generated image. 352 | width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): 353 | The width in pixels of the generated image. 354 | num_frames (`int`, *optional*): 355 | The number of video frames to generate. Defaults to 14 for `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt` 356 | num_inference_steps (`int`, *optional*, defaults to 25): 357 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 358 | expense of slower inference. 
This parameter is modulated by `strength`. 359 | min_guidance_scale (`float`, *optional*, defaults to 1.0): 360 | The minimum guidance scale. Used for the classifier free guidance with first frame. 361 | max_guidance_scale (`float`, *optional*, defaults to 3.0): 362 | The maximum guidance scale. Used for the classifier free guidance with last frame. 363 | fps (`int`, *optional*, defaults to 7): 364 | Frames per second. The rate at which the generated images shall be exported to a video after generation. 365 | Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training. 366 | motion_bucket_id (`int`, *optional*, defaults to 127): 367 | The motion bucket ID. Used as conditioning for the generation. The higher the number the more motion will be in the video. 368 | noise_aug_strength (`int`, *optional*, defaults to 0.02): 369 | The amount of noise added to the init image, the higher it is the less the video will look like the init image. Increase it for more motion. 370 | decode_chunk_size (`int`, *optional*): 371 | The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency 372 | between frames, but also the higher the memory consumption. By default, the decoder will decode all frames at once 373 | for maximal quality. Reduce `decode_chunk_size` to reduce memory usage. 374 | num_videos_per_prompt (`int`, *optional*, defaults to 1): 375 | The number of images to generate per prompt. 376 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 377 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make 378 | generation deterministic. 379 | latents (`torch.FloatTensor`, *optional*): 380 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image 381 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 382 | tensor is generated by sampling using the supplied random `generator`. 383 | output_type (`str`, *optional*, defaults to `"pil"`): 384 | The output format of the generated image. Choose between `PIL.Image` or `np.array`. 385 | callback_on_step_end (`Callable`, *optional*): 386 | A function that calls at the end of each denoising steps during the inference. The function is called 387 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, 388 | callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by 389 | `callback_on_step_end_tensor_inputs`. 390 | callback_on_step_end_tensor_inputs (`List`, *optional*): 391 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list 392 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the 393 | `._callback_tensor_inputs` attribute of your pipeline class. 394 | return_dict (`bool`, *optional*, defaults to `True`): 395 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a 396 | plain tuple. 397 | 398 | Returns: 399 | [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`: 400 | If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is returned, 401 | otherwise a `tuple` is returned where the first element is a list of list with the generated frames. 
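        Note:
            The example below is the upstream Stable Video Diffusion snippet. This ControlNet variant is
            additionally driven through `controlnet_condition` (the list of pose/control frames),
            `controlnet_cond_scale`, `frames_per_batch` and `overlap`; see `run_inference_release.py`
            in this repository for a complete invocation.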
402 | 403 | Examples: 404 | 405 | ```py 406 | from diffusers import StableVideoDiffusionPipeline 407 | from diffusers.utils import load_image, export_to_video 408 | 409 | pipe = StableVideoDiffusionPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16") 410 | pipe.to("cuda") 411 | 412 | image = load_image("https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200") 413 | image = image.resize((1024, 576)) 414 | 415 | frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0] 416 | export_to_video(frames, "generated.mp4", fps=7) 417 | ``` 418 | """ 419 | # 0. Default height and width to unet 420 | height = height or self.unet.config.sample_size * self.vae_scale_factor 421 | width = width or self.unet.config.sample_size * self.vae_scale_factor 422 | 423 | num_frames = num_frames if num_frames is not None else self.unet.config.num_frames 424 | decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames 425 | 426 | # 1. Check inputs. Raise error if not correct 427 | self.check_inputs(image, height, width) 428 | 429 | # 2. Define call parameters 430 | #if isinstance(image, PIL.Image.Image): 431 | # batch_size = 1 432 | #elif isinstance(image, list): 433 | # batch_size = len(image) 434 | #else: 435 | # batch_size = image.shape[0] 436 | device = self._execution_device 437 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 438 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 439 | # corresponds to doing no classifier free guidance. 440 | do_classifier_free_guidance = max_guidance_scale >= 1.0 441 | 442 | # 3. Encode input image 443 | image_embeddings = self._encode_image(image, device, num_videos_per_prompt, do_classifier_free_guidance) 444 | 445 | # NOTE: Stable Diffusion Video was conditioned on fps - 1, which 446 | # is why it is reduced here. 447 | # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188 448 | fps = fps - 1 449 | 450 | # 4. Encode input image using VAE 451 | image = self.image_processor.preprocess(image, height=height, width=width) 452 | noise = randn_tensor(image.shape, generator=generator, device=image.device, dtype=image.dtype) 453 | image = image + noise_aug_strength * noise 454 | 455 | needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast 456 | if needs_upcasting: 457 | self.vae.to(dtype=torch.float32) 458 | 459 | image_latents = self._encode_vae_image(image, device, num_videos_per_prompt, do_classifier_free_guidance) 460 | image_latents = image_latents.to(image_embeddings.dtype) 461 | 462 | # cast back to fp16 if needed 463 | if needs_upcasting: 464 | self.vae.to(dtype=torch.float16) 465 | 466 | # Repeat the image latents for each frame so we can concatenate them with the noise 467 | # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width] 468 | image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1) 469 | #image_latents = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents 470 | 471 | # 5. 
Get Added Time IDs 472 | added_time_ids = self._get_add_time_ids( 473 | fps, 474 | motion_bucket_id, 475 | noise_aug_strength, 476 | image_embeddings.dtype, 477 | batch_size, 478 | num_videos_per_prompt, 479 | do_classifier_free_guidance, 480 | ) 481 | added_time_ids = added_time_ids.to(device) 482 | 483 | # 4. Prepare timesteps 484 | self.scheduler.set_timesteps(num_inference_steps, device=device) 485 | timesteps = self.scheduler.timesteps 486 | 487 | # 5. Prepare latent variables 488 | 489 | num_channels_latents = self.unet.config.in_channels 490 | latents = self.prepare_latents( 491 | batch_size * num_videos_per_prompt, 492 | num_frames, 493 | num_channels_latents, 494 | height, 495 | width, 496 | image_embeddings.dtype, 497 | device, 498 | generator, 499 | latents, 500 | ) 501 | #prepare controlnet condition 502 | controlnet_condition = self.image_processor.preprocess(controlnet_condition, height=height, width=width) 503 | controlnet_condition = (controlnet_condition + 1.0) / 2 504 | controlnet_condition = controlnet_condition.unsqueeze(0) 505 | controlnet_condition = torch.cat([controlnet_condition] * 2) 506 | controlnet_condition = controlnet_condition.to(device, latents.dtype) 507 | controlnet_condition_all = controlnet_condition 508 | latents_all = latents 509 | 510 | # 7. Prepare guidance scale 511 | guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, frames_per_batch).unsqueeze(0) 512 | guidance_scale = guidance_scale.to(device, latents.dtype) 513 | guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1) 514 | guidance_scale = _append_dims(guidance_scale, latents.ndim) 515 | 516 | self._guidance_scale = guidance_scale 517 | 518 | noise_aug_strength = 0.02 #"¯\_(ツ)_/¯ 519 | added_time_ids = _get_add_time_ids( 520 | noise_aug_strength, 521 | image_embeddings.dtype, 522 | batch_size, 523 | 6, 524 | 128, 525 | unet=self.unet, 526 | ) 527 | added_time_ids = torch.cat([added_time_ids] * 2) 528 | added_time_ids = added_time_ids.to(latents.device) 529 | 530 | 531 | # 8. 
Denoising loop 532 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 533 | self._num_timesteps = len(timesteps) 534 | with self.progress_bar(total=num_inference_steps) as progress_bar: 535 | for i, t in enumerate(timesteps): 536 | pred_tmp = torch.zeros_like(latents_all) 537 | counter = torch.zeros((latents.shape[0], num_frames, 1, 1, 1 )).to(device=latents.device) 538 | for batch, ind_start in enumerate(range(0, num_frames, frames_per_batch-overlap)): 539 | self.scheduler._step_index = None 540 | if ind_start + frames_per_batch >= num_frames: 541 | ind_start = num_frames - 1 - frames_per_batch 542 | latents = latents_all[:,ind_start:ind_start+frames_per_batch].contiguous() 543 | controlnet_condition = controlnet_condition_all[:,ind_start:ind_start+frames_per_batch].contiguous() 544 | 545 | 546 | # expand the latents if we are doing classifier free guidance 547 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 548 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 549 | 550 | # Concatenate image_latents over channels dimention 551 | 552 | latent_model_input = torch.cat([latent_model_input, image_latents[:,ind_start:ind_start+frames_per_batch].contiguous()], dim=2) 553 | down_block_res_samples, mid_block_res_sample = self.controlnet( 554 | latent_model_input, 555 | t, 556 | encoder_hidden_states=image_embeddings, 557 | controlnet_cond=controlnet_condition, 558 | added_time_ids=added_time_ids, 559 | conditioning_scale=controlnet_cond_scale, 560 | guess_mode=False, 561 | return_dict=False, 562 | ) 563 | 564 | 565 | # predict the noise residual 566 | noise_pred = self.unet( 567 | latent_model_input, 568 | t, 569 | encoder_hidden_states=image_embeddings, 570 | down_block_additional_residuals=down_block_res_samples, 571 | mid_block_additional_residual=mid_block_res_sample, 572 | added_time_ids=added_time_ids, 573 | return_dict=False, 574 | )[0] 575 | 576 | 577 | 578 | 579 | 580 | # perform guidance 581 | if do_classifier_free_guidance: 582 | noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) 583 | noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) 584 | 585 | # compute the previous noisy sample x_t -> x_t-1 586 | latents = self.scheduler.step(noise_pred, t, latents).prev_sample 587 | 588 | if callback_on_step_end is not None: 589 | callback_kwargs = {} 590 | for k in callback_on_step_end_tensor_inputs: 591 | callback_kwargs[k] = locals()[k] 592 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 593 | 594 | latents = callback_outputs.pop("latents", latents) 595 | if batch == 0: 596 | pred_tmp[:,ind_start:ind_start+frames_per_batch] += latents 597 | counter[:,ind_start:ind_start+frames_per_batch] += 1 598 | else: 599 | pred_tmp[:, ind_start+1:ind_start+frames_per_batch] += latents[:,1:] 600 | counter[:,ind_start+1:ind_start+frames_per_batch] += 1 601 | pred_tmp /= counter 602 | latents_all = pred_tmp 603 | 604 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 605 | progress_bar.update() 606 | latents = latents_all 607 | 608 | if not output_type == "latent": 609 | # cast back to fp16 if needed 610 | if needs_upcasting: 611 | self.vae.to(dtype=torch.float16) 612 | frames = self.decode_latents(latents, num_frames, decode_chunk_size) 613 | frames = tensor2vid(frames, self.image_processor, output_type=output_type) 614 | else: 615 | frames = latents 616 | 617 | 
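        # maybe_free_model_hooks() removes and re-applies the hooks created by
        # enable_model_cpu_offload() (it should be a no-op when offloading is not
        # enabled), so repeated pipeline calls keep working after this one returns.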
self.maybe_free_model_hooks() 618 | 619 | if not return_dict: 620 | return frames 621 | 622 | return StableVideoDiffusionPipelineOutput(frames=frames) 623 | 624 | 625 | # resizing utils 626 | # TODO: clean up later 627 | def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True): 628 | 629 | if input.ndim == 3: 630 | input = input.unsqueeze(0) # Add a batch dimension 631 | 632 | h, w = input.shape[-2:] 633 | factors = (h / size[0], w / size[1]) 634 | 635 | # First, we have to determine sigma 636 | # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171 637 | sigmas = ( 638 | max((factors[0] - 1.0) / 2.0, 0.001), 639 | max((factors[1] - 1.0) / 2.0, 0.001), 640 | ) 641 | 642 | # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma 643 | # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206 644 | # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now 645 | ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3)) 646 | 647 | # Make sure it is odd 648 | if (ks[0] % 2) == 0: 649 | ks = ks[0] + 1, ks[1] 650 | 651 | if (ks[1] % 2) == 0: 652 | ks = ks[0], ks[1] + 1 653 | 654 | input = _gaussian_blur2d(input, ks, sigmas) 655 | 656 | output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners) 657 | return output 658 | 659 | 660 | def _compute_padding(kernel_size): 661 | """Compute padding tuple.""" 662 | # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom) 663 | # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad 664 | if len(kernel_size) < 2: 665 | raise AssertionError(kernel_size) 666 | computed = [k - 1 for k in kernel_size] 667 | 668 | # for even kernels we need to do asymmetric padding :( 669 | out_padding = 2 * len(kernel_size) * [0] 670 | 671 | for i in range(len(kernel_size)): 672 | computed_tmp = computed[-(i + 1)] 673 | 674 | pad_front = computed_tmp // 2 675 | pad_rear = computed_tmp - pad_front 676 | 677 | out_padding[2 * i + 0] = pad_front 678 | out_padding[2 * i + 1] = pad_rear 679 | 680 | return out_padding 681 | 682 | 683 | def _filter2d(input, kernel): 684 | # prepare kernel 685 | b, c, h, w = input.shape 686 | tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype) 687 | 688 | tmp_kernel = tmp_kernel.expand(-1, c, -1, -1) 689 | 690 | height, width = tmp_kernel.shape[-2:] 691 | 692 | padding_shape: list[int] = _compute_padding([height, width]) 693 | input = torch.nn.functional.pad(input, padding_shape, mode="reflect") 694 | 695 | # kernel and input tensor reshape to align element-wise or batch-wise params 696 | tmp_kernel = tmp_kernel.reshape(-1, 1, height, width) 697 | input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1)) 698 | 699 | # convolve the tensor with the kernel. 
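    # Grouped convolution: with groups equal to the number of kernel copies, each
    # channel is filtered independently by the 1-D Gaussian kernel (applied once
    # per axis from _gaussian_blur2d), so channels are never mixed.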
700 | output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1) 701 | 702 | out = output.view(b, c, h, w) 703 | return out 704 | 705 | 706 | def _gaussian(window_size: int, sigma): 707 | if isinstance(sigma, float): 708 | sigma = torch.tensor([[sigma]]) 709 | 710 | batch_size = sigma.shape[0] 711 | 712 | x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1) 713 | 714 | if window_size % 2 == 0: 715 | x = x + 0.5 716 | 717 | gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0))) 718 | 719 | return gauss / gauss.sum(-1, keepdim=True) 720 | 721 | 722 | def _gaussian_blur2d(input, kernel_size, sigma): 723 | if isinstance(sigma, tuple): 724 | sigma = torch.tensor([sigma], dtype=input.dtype) 725 | else: 726 | sigma = sigma.to(dtype=input.dtype) 727 | 728 | ky, kx = int(kernel_size[0]), int(kernel_size[1]) 729 | bs = sigma.shape[0] 730 | kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1)) 731 | kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1)) 732 | out_x = _filter2d(input, kernel_x[..., None, :]) 733 | out = _filter2d(out_x, kernel_y[..., None]) 734 | 735 | return out 736 | -------------------------------------------------------------------------------- /prepare.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import requests 4 | 5 | file_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + '/models/animate_anyone' 6 | print(file_dir) 7 | 8 | os.makedirs(file_dir, exist_ok=True) 9 | 10 | 11 | model_files = [ 12 | ('model_index.json', 13 | 'https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt/resolve/main/model_index.json?download=true'), 14 | ('feature_extractor/preprocessor_config.json', 15 | 'https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt/resolve/main/feature_extractor/preprocessor_config.json?download=true'), 16 | ('image_encoder/config.json', 17 | 'https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt/resolve/main/image_encoder/config.json?download=true'), 18 | ('image_encoder/model.safetensors', 19 | 'https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt/resolve/main/image_encoder/model.fp16.safetensors?download=true'), 20 | ('scheduler/scheduler_config.json', 21 | 'https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt/resolve/main/scheduler/scheduler_config.json?download=true'), 22 | ('vae/config.json', 23 | 'https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt/resolve/main/vae/config.json?download=true'), 24 | ('vae/diffusion_pytorch_model.safetensors', 25 | 'https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt/resolve/main/vae/diffusion_pytorch_model.fp16.safetensors?download=true'), 26 | 27 | 28 | ('unet/config.json', 29 | 'https://modelscope.cn/api/v1/models/lightnessly/animate-anyone-v1/repo?Revision=master&FilePath=unet/config.json'), 30 | ('unet/diffusion_pytorch_model.safetensors', 31 | 'https://modelscope.cn/api/v1/models/lightnessly/animate-anyone-v1/repo?Revision=master&FilePath=unet/diffusion_pytorch_model.safetensors'), 32 | ('controlnet/config.json', 33 | 'https://modelscope.cn/api/v1/models/lightnessly/animate-anyone-v1/repo?Revision=master&FilePath=controlnet/config.json'), 34 | ('controlnet/diffusion_pytorch_model.safetensors', 35 | 'https://modelscope.cn/api/v1/models/lightnessly/animate-anyone-v1/repo?Revision=master&FilePath=controlnet/diffusion_pytorch_model.safetensors'), 36 | 37 | ] 
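# The fp16 variants of the SVD image encoder and VAE weights are fetched above but
# written to the default file names (model.safetensors / diffusion_pytorch_model.safetensors)
# so they load without passing a variant. Note that the loop below buffers each
# response fully in memory (response.content); for the large safetensors files a
# streamed download (requests.get(url, stream=True) with chunked writes) would keep
# peak memory usage low.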
38 | 39 | 40 | 41 | for file_info in model_files: 42 | file_name, file_url = file_info 43 | file_path = os.path.join(file_dir, file_name) 44 | print(f"Start download file: {file_url}") 45 | print(f" to: {file_path}") 46 | file_dirname = os.path.dirname(file_path) 47 | os.makedirs(file_dirname, exist_ok=True) 48 | 49 | if not os.path.exists(file_path): 50 | response = requests.get(file_url) 51 | with open(file_path, "wb") as f: 52 | f.write(response.content) 53 | print(f"Done: {file_path}") 54 | else: 55 | print(f"File exist: {file_path}") 56 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | diffusers==0.25 3 | numpy 4 | transformers 5 | opencv-python 6 | accelerate -------------------------------------------------------------------------------- /run_inference_release.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import datetime 4 | import numpy as np 5 | from PIL import Image 6 | from pipeline_stable_video_diffusion_controlnet_long import StableVideoDiffusionPipelineControlNet 7 | from controlnet_sdv import ControlNetSDVModel 8 | #from diffusers import T2IAdapter 9 | from unet_spatio_temporal_condition_controlnet import UNetSpatioTemporalConditionControlNetModel 10 | import cv2 11 | import re 12 | 13 | def write_mp4(video_path, samples): 14 | writer = cv2.VideoWriter( 15 | video_path, 16 | cv2.VideoWriter_fourcc(*"MP4V"), 17 | 15, 18 | (samples[0].shape[1], samples[0].shape[0]), 19 | ) 20 | 21 | for frame in samples: 22 | writer.write(frame) 23 | writer.release() 24 | 25 | def save_gifs_side_by_side(batch_output, validation_images, validation_control_images, output_folder): 26 | # Helper function to convert tensors to PIL images and save as GIF 27 | flattened_batch_output = [img for sublist in batch_output for img in sublist] 28 | video_path = output_folder+'/test_1.mp4' 29 | final_images = [] 30 | # Helper function to concatenate images horizontally 31 | def get_concat_h(im1, im2): 32 | dst = Image.new('RGB', (im1.width + im2.width, max(im1.height, im2.height))) 33 | dst.paste(im1, (0, 0)) 34 | dst.paste(im2, (im1.width, 0)) 35 | return dst 36 | for idx, image_list in enumerate(zip(validation_images, validation_control_images, flattened_batch_output)): 37 | result = get_concat_h(image_list[0], image_list[1]) 38 | result = get_concat_h(result, image_list[2]) 39 | final_images.append(np.array(result)[:,:,::-1]) 40 | write_mp4(video_path, final_images) 41 | 42 | # Define functions 43 | def validate_and_convert_image(image, target_size=(256, 256)): 44 | if image is None: 45 | print("Encountered a None image") 46 | return None 47 | 48 | if isinstance(image, torch.Tensor): 49 | # Convert PyTorch tensor to PIL Image 50 | if image.ndim == 3 and image.shape[0] in [1, 3]: # Check for CxHxW format 51 | if image.shape[0] == 1: # Convert single-channel grayscale to RGB 52 | image = image.repeat(3, 1, 1) 53 | image = image.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy() 54 | image = Image.fromarray(image) 55 | else: 56 | print(f"Invalid image tensor shape: {image.shape}") 57 | return None 58 | elif isinstance(image, Image.Image): 59 | # Resize PIL Image 60 | image = image.resize(target_size) 61 | else: 62 | print("Image is not a PIL Image or a PyTorch tensor") 63 | return None 64 | 65 | return image 66 | 67 | def create_image_grid(images, rows, cols, target_size=(448, 
768)): 68 | valid_images = [validate_and_convert_image(img, target_size) for img in images] 69 | valid_images = [img for img in valid_images if img is not None] 70 | 71 | if not valid_images: 72 | print("No valid images to create a grid") 73 | return None 74 | 75 | w, h = target_size 76 | grid = Image.new('RGB', size=(cols * w, rows * h)) 77 | 78 | for i, image in enumerate(valid_images): 79 | grid.paste(image, box=((i % cols) * w, (i // cols) * h)) 80 | 81 | return grid 82 | 83 | def tensor_to_pil(tensor): 84 | """ Convert a PyTorch tensor to a PIL Image. """ 85 | # Convert tensor to numpy array 86 | if len(tensor.shape) == 4: # batch of images 87 | images = [Image.fromarray(img.numpy().transpose(1, 2, 0)) for img in tensor] 88 | else: # single image 89 | images = Image.fromarray(tensor.numpy().transpose(1, 2, 0)) 90 | return images 91 | 92 | def save_combined_frames(batch_output, validation_images, validation_control_images, output_folder): 93 | # Flatten batch_output to a list of PIL Images 94 | flattened_batch_output = [img for sublist in batch_output for img in sublist] 95 | 96 | # Convert tensors in lists to PIL Images 97 | validation_images = [tensor_to_pil(img) if torch.is_tensor(img) else img for img in validation_images] 98 | validation_control_images = [tensor_to_pil(img) if torch.is_tensor(img) else img for img in validation_control_images] 99 | flattened_batch_output = [tensor_to_pil(img) if torch.is_tensor(img) else img for img in batch_output] 100 | 101 | # Flatten lists if they contain sublists (for tensors converted to multiple images) 102 | validation_images = [img for sublist in validation_images for img in (sublist if isinstance(sublist, list) else [sublist])] 103 | validation_control_images = [img for sublist in validation_control_images for img in (sublist if isinstance(sublist, list) else [sublist])] 104 | flattened_batch_output = [img for sublist in flattened_batch_output for img in (sublist if isinstance(sublist, list) else [sublist])] 105 | 106 | # Combine frames into a list 107 | combined_frames = validation_images + validation_control_images + flattened_batch_output 108 | 109 | # Calculate rows and columns for the grid 110 | num_images = len(combined_frames) 111 | cols = 3 112 | rows = (num_images + cols - 1) // cols 113 | 114 | # Create and save the grid image 115 | grid = create_image_grid(combined_frames, rows, cols, target_size=(256, 256)) 116 | if grid is not None: 117 | timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 118 | filename = f"combined_frames_{timestamp}.png" 119 | output_path = os.path.join(output_folder, filename) 120 | grid.save(output_path) 121 | else: 122 | print("Failed to create image grid") 123 | 124 | def load_images_from_folder(folder): 125 | images = [] 126 | valid_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"} # Add or remove extensions as needed 127 | 128 | # Function to extract frame number from the filename 129 | def frame_number(filename): 130 | matches = re.findall(r'\d+', filename) # Find all sequences of digits in the filename 131 | if matches: 132 | if matches[-1] == '0000' and len(matches) > 1: 133 | return int(matches[-2]) # Return the second-to-last sequence if the last is '0000' 134 | return int(matches[-1]) # Otherwise, return the last sequence 135 | return float('inf') # Return 'inf' 136 | 137 | 138 | # Sorting files based on frame number 139 | sorted_files = sorted(os.listdir(folder)) 140 | 141 | # Load images in sorted order 142 | for filename in sorted_files: 143 | ext = 
os.path.splitext(filename)[1].lower() 144 | if ext in valid_extensions: 145 | img = Image.open(os.path.join(folder, filename)).convert('RGB') 146 | images.append(img) 147 | 148 | return images 149 | 150 | def load_images_from_folder_to_pil(folder, target_size=(512, 512)): 151 | images = [] 152 | valid_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"} # Add or remove extensions as needed 153 | 154 | def frame_number(filename): 155 | matches = re.findall(r'\d+', filename) # Find all sequences of digits in the filename 156 | if matches: 157 | if matches[-1] == '0000' and len(matches) > 1: 158 | return int(matches[-2]) # Return the second-to-last sequence if the last is '0000' 159 | return int(matches[-1]) # Otherwise, return the last sequence 160 | return float('inf') # Return 'inf' 161 | 162 | 163 | # Sorting files based on frame number 164 | sorted_files = sorted(os.listdir(folder)) 165 | 166 | # Load, resize, and convert images 167 | for filename in sorted_files: 168 | ext = os.path.splitext(filename)[1].lower() 169 | if ext in valid_extensions: 170 | img = Image.open(os.path.join(folder, filename)).convert('RGB') 171 | images.append(img) 172 | 173 | return images[::2] 174 | 175 | # Usage example 176 | def convert_list_bgra_to_rgba(image_list): 177 | """ 178 | Convert a list of PIL Image objects from BGRA to RGBA format. 179 | 180 | Parameters: 181 | image_list (list of PIL.Image.Image): A list of images in BGRA format. 182 | 183 | Returns: 184 | list of PIL.Image.Image: The list of images converted to RGBA format. 185 | """ 186 | rgba_images = [] 187 | for image in image_list: 188 | if image.mode == 'RGBA' or image.mode == 'BGRA': 189 | # Split the image into its components 190 | b, g, r, a = image.split() 191 | # Re-merge in RGBA order 192 | converted_image = Image.merge("RGBA", (r, g, b, a)) 193 | else: 194 | # For non-alpha images, assume they are BGR and convert to RGB 195 | b, g, r = image.split() 196 | converted_image = Image.merge("RGB", (r, g, b)) 197 | 198 | rgba_images.append(converted_image) 199 | 200 | return rgba_images 201 | 202 | # Main script 203 | if __name__ == "__main__": 204 | from tqdm import tqdm 205 | args = { 206 | # "pretrained_model_name_or_path": "checkpoint/SVD/svd_14", 207 | "pretrained_model_name_or_path": "checkpoint", 208 | "validation_image_folder": "./testcase/81FyMPk-WIS/images", 209 | "validation_control_folder": "./testcase/81FyMPk-WIS/dwpose_woface", 210 | "output_dir": "./output", 211 | "height": 896, 212 | "width":704, 213 | # cant be bothered to add the args in myself, just use notepad 214 | } 215 | 216 | # Load validation images and control images 217 | validation_images = load_images_from_folder_to_pil(args["validation_image_folder"], (args['width'], args['height'])) 218 | #validation_images = convert_list_bgra_to_rgba(validation_images) 219 | validation_control_images = load_images_from_folder_to_pil(args["validation_control_folder"], (args['width'], args['height'])) 220 | 221 | 222 | controlnet = ControlNetSDVModel.from_pretrained("checkpoint/controlnet") 223 | unet = UNetSpatioTemporalConditionControlNetModel.from_pretrained(args["pretrained_model_name_or_path"]+'/unet') 224 | 225 | pipeline = StableVideoDiffusionPipelineControlNet.from_pretrained(args["pretrained_model_name_or_path"], controlnet=controlnet, unet=unet) 226 | pipeline.to(dtype=torch.float16) 227 | pipeline.enable_model_cpu_offload() 228 | 229 | print('Model loading: DONE.') 230 | 231 | val_save_dir = os.path.join(args["output_dir"], "validation_images") 232 | 
os.makedirs(val_save_dir, exist_ok=True) 233 | 234 | # Inference and saving loop 235 | final_result = [] 236 | #ref_image = validation_images[0] 237 | ref_image = Image.open('./testcase/test_12.png').convert('RGB') 238 | frames = 14 239 | num_frames = len(validation_images) 240 | 241 | print('Image and frame data loading: DONE.') 242 | 243 | video_frames = pipeline(ref_image, validation_control_images[:num_frames], decode_chunk_size=2,num_frames=num_frames,motion_bucket_id=127.0, fps=7,controlnet_cond_scale=1.0, width=args['width'], height=args["height"], min_guidance_scale=3.5, max_guidance_scale=3.5, frames_per_batch=frames, num_inference_steps=25, overlap=4).frames[0] 244 | final_result.append(video_frames) 245 | 246 | save_gifs_side_by_side(final_result,validation_images[:num_frames], validation_control_images[:num_frames],val_save_dir) 247 | -------------------------------------------------------------------------------- /unet_spatio_temporal_condition_controlnet.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict, Optional, Tuple, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from diffusers.configuration_utils import ConfigMixin, register_to_config 8 | from diffusers.loaders import UNet2DConditionLoadersMixin 9 | from diffusers.utils import BaseOutput, logging 10 | from diffusers.models.attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor 11 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps 12 | from diffusers.models.modeling_utils import ModelMixin 13 | from diffusers.models.unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block 14 | 15 | 16 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 17 | 18 | 19 | @dataclass 20 | class UNetSpatioTemporalConditionOutput(BaseOutput): 21 | """ 22 | The output of [`UNetSpatioTemporalConditionModel`]. 23 | 24 | Args: 25 | sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`): 26 | The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. 27 | """ 28 | 29 | sample: torch.FloatTensor = None 30 | 31 | 32 | class UNetSpatioTemporalConditionControlNetModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): 33 | r""" 34 | A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, and a timestep and returns a sample 35 | shaped output. 36 | 37 | This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented 38 | for all models (such as downloading or saving). 39 | 40 | Parameters: 41 | sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): 42 | Height and width of input/output sample. 43 | in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample. 44 | out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. 45 | down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`): 46 | The tuple of downsample blocks to use. 47 | up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`): 48 | The tuple of upsample blocks to use. 
49 | block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): 50 | The tuple of output channels for each block. 51 | addition_time_embed_dim: (`int`, defaults to 256): 52 | Dimension to to encode the additional time ids. 53 | projection_class_embeddings_input_dim (`int`, defaults to 768): 54 | The dimension of the projection of encoded `added_time_ids`. 55 | layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. 56 | cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): 57 | The dimension of the cross attention features. 58 | transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): 59 | The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for 60 | [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`], 61 | [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`]. 62 | num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`): 63 | The number of attention heads. 64 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 65 | """ 66 | 67 | _supports_gradient_checkpointing = True 68 | 69 | @register_to_config 70 | def __init__( 71 | self, 72 | sample_size: Optional[int] = None, 73 | in_channels: int = 8, 74 | out_channels: int = 4, 75 | down_block_types: Tuple[str] = ( 76 | "CrossAttnDownBlockSpatioTemporal", 77 | "CrossAttnDownBlockSpatioTemporal", 78 | "CrossAttnDownBlockSpatioTemporal", 79 | "DownBlockSpatioTemporal", 80 | ), 81 | up_block_types: Tuple[str] = ( 82 | "UpBlockSpatioTemporal", 83 | "CrossAttnUpBlockSpatioTemporal", 84 | "CrossAttnUpBlockSpatioTemporal", 85 | "CrossAttnUpBlockSpatioTemporal", 86 | ), 87 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280), 88 | addition_time_embed_dim: int = 256, 89 | projection_class_embeddings_input_dim: int = 768, 90 | layers_per_block: Union[int, Tuple[int]] = 2, 91 | cross_attention_dim: Union[int, Tuple[int]] = 1024, 92 | transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, 93 | num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20), 94 | num_frames: int = 25, 95 | upcast_attention: bool = False, 96 | 97 | ): 98 | super().__init__() 99 | 100 | self.sample_size = sample_size 101 | 102 | # Check inputs 103 | if len(down_block_types) != len(up_block_types): 104 | raise ValueError( 105 | f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." 106 | ) 107 | 108 | if len(block_out_channels) != len(down_block_types): 109 | raise ValueError( 110 | f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." 111 | ) 112 | 113 | if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): 114 | raise ValueError( 115 | f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 116 | ) 117 | 118 | if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): 119 | raise ValueError( 120 | f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." 
121 | ) 122 | 123 | if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): 124 | raise ValueError( 125 | f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." 126 | ) 127 | 128 | # input 129 | self.conv_in = nn.Conv2d( 130 | in_channels, 131 | block_out_channels[0], 132 | kernel_size=3, 133 | padding=1, 134 | ) 135 | 136 | # time 137 | time_embed_dim = block_out_channels[0] * 4 138 | 139 | self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0) 140 | timestep_input_dim = block_out_channels[0] 141 | 142 | self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 143 | 144 | self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0) 145 | self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) 146 | 147 | self.down_blocks = nn.ModuleList([]) 148 | self.up_blocks = nn.ModuleList([]) 149 | 150 | if isinstance(num_attention_heads, int): 151 | num_attention_heads = (num_attention_heads,) * len(down_block_types) 152 | 153 | if isinstance(cross_attention_dim, int): 154 | cross_attention_dim = (cross_attention_dim,) * len(down_block_types) 155 | 156 | if isinstance(layers_per_block, int): 157 | layers_per_block = [layers_per_block] * len(down_block_types) 158 | 159 | if isinstance(transformer_layers_per_block, int): 160 | transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) 161 | 162 | blocks_time_embed_dim = time_embed_dim 163 | 164 | # down 165 | output_channel = block_out_channels[0] 166 | for i, down_block_type in enumerate(down_block_types): 167 | input_channel = output_channel 168 | output_channel = block_out_channels[i] 169 | is_final_block = i == len(block_out_channels) - 1 170 | 171 | down_block = get_down_block( 172 | down_block_type, 173 | num_layers=layers_per_block[i], 174 | transformer_layers_per_block=transformer_layers_per_block[i], 175 | in_channels=input_channel, 176 | out_channels=output_channel, 177 | temb_channels=blocks_time_embed_dim, 178 | add_downsample=not is_final_block, 179 | resnet_eps=1e-5, 180 | cross_attention_dim=cross_attention_dim[i], 181 | num_attention_heads=num_attention_heads[i], 182 | resnet_act_fn="silu", 183 | upcast_attention=upcast_attention, 184 | ) 185 | self.down_blocks.append(down_block) 186 | 187 | # mid 188 | self.mid_block = UNetMidBlockSpatioTemporal( 189 | block_out_channels[-1], 190 | temb_channels=blocks_time_embed_dim, 191 | transformer_layers_per_block=transformer_layers_per_block[-1], 192 | cross_attention_dim=cross_attention_dim[-1], 193 | num_attention_heads=num_attention_heads[-1], 194 | #upcast_attention=upcast_attention, 195 | ) 196 | 197 | # count how many layers upsample the images 198 | self.num_upsamplers = 0 199 | 200 | # up 201 | reversed_block_out_channels = list(reversed(block_out_channels)) 202 | reversed_num_attention_heads = list(reversed(num_attention_heads)) 203 | reversed_layers_per_block = list(reversed(layers_per_block)) 204 | reversed_cross_attention_dim = list(reversed(cross_attention_dim)) 205 | reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) 206 | 207 | output_channel = reversed_block_out_channels[0] 208 | for i, up_block_type in enumerate(up_block_types): 209 | is_final_block = i == len(block_out_channels) - 1 210 | 211 | prev_output_channel = output_channel 212 | output_channel = 
213 |             input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
214 | 
215 |             # add upsample block for all BUT final layer
216 |             if not is_final_block:
217 |                 add_upsample = True
218 |                 self.num_upsamplers += 1
219 |             else:
220 |                 add_upsample = False
221 | 
222 |             up_block = get_up_block(
223 |                 up_block_type,
224 |                 num_layers=reversed_layers_per_block[i] + 1,
225 |                 transformer_layers_per_block=reversed_transformer_layers_per_block[i],
226 |                 in_channels=input_channel,
227 |                 out_channels=output_channel,
228 |                 prev_output_channel=prev_output_channel,
229 |                 temb_channels=blocks_time_embed_dim,
230 |                 add_upsample=add_upsample,
231 |                 resnet_eps=1e-5,
232 |                 resolution_idx=i,
233 |                 cross_attention_dim=reversed_cross_attention_dim[i],
234 |                 num_attention_heads=reversed_num_attention_heads[i],
235 |                 resnet_act_fn="silu",
236 |                 upcast_attention=upcast_attention,
237 |             )
238 |             self.up_blocks.append(up_block)
239 |             prev_output_channel = output_channel
240 | 
241 |         # out
242 |         self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5)
243 |         self.conv_act = nn.SiLU()
244 | 
245 |         self.conv_out = nn.Conv2d(
246 |             block_out_channels[0],
247 |             out_channels,
248 |             kernel_size=3,
249 |             padding=1,
250 |         )
251 | 
252 |     @property
253 |     def attn_processors(self) -> Dict[str, AttentionProcessor]:
254 |         r"""
255 |         Returns:
256 |             `dict` of attention processors: A dictionary containing all attention processors used in the model,
257 |             indexed by weight name.
258 |         """
259 |         # set recursively
260 |         processors = {}
261 | 
262 |         def fn_recursive_add_processors(
263 |             name: str,
264 |             module: torch.nn.Module,
265 |             processors: Dict[str, AttentionProcessor],
266 |         ):
267 |             if hasattr(module, "get_processor"):
268 |                 processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
269 | 
270 |             for sub_name, child in module.named_children():
271 |                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
272 | 
273 |             return processors
274 | 
275 |         for name, module in self.named_children():
276 |             fn_recursive_add_processors(name, module, processors)
277 | 
278 |         return processors
279 | 
280 |     def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
281 |         r"""
282 |         Sets the attention processor to use to compute attention.
283 | 
284 |         Parameters:
285 |             processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
286 |                 The instantiated processor class or a dictionary of processor classes that will be set as the processor
287 |                 for **all** `Attention` layers.
288 | 
289 |                 If `processor` is a dict, the key needs to define the path to the corresponding cross attention
290 |                 processor. This is strongly recommended when setting trainable attention processors.
291 | 
292 |         """
293 |         count = len(self.attn_processors.keys())
294 | 
295 |         if isinstance(processor, dict) and len(processor) != count:
296 |             raise ValueError(
297 |                 f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
298 |                 f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
299 |             )
300 | 
301 |         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
302 |             if hasattr(module, "set_processor"):
303 |                 if not isinstance(processor, dict):
304 |                     module.set_processor(processor)
305 |                 else:
306 |                     module.set_processor(processor.pop(f"{name}.processor"))
307 | 
308 |             for sub_name, child in module.named_children():
309 |                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
310 | 
311 |         for name, module in self.named_children():
312 |             fn_recursive_attn_processor(name, module, processor)
313 | 
314 |     def set_default_attn_processor(self):
315 |         """
316 |         Disables custom attention processors and sets the default attention implementation.
317 |         """
318 |         if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
319 |             processor = AttnProcessor()
320 |         else:
321 |             raise ValueError(
322 |                 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
323 |             )
324 | 
325 |         self.set_attn_processor(processor)
326 | 
327 |     def _set_gradient_checkpointing(self, module, value=False):
328 |         if hasattr(module, "gradient_checkpointing"):
329 |             module.gradient_checkpointing = value
330 | 
331 |     # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
332 |     def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
333 |         """
334 |         Enables [feed forward
335 |         chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
336 | 
337 |         Parameters:
338 |             chunk_size (`int`, *optional*):
339 |                 The chunk size of the feed-forward layers. If not specified, the feed-forward layer is run individually
340 |                 over each tensor of dim=`dim`.
341 |             dim (`int`, *optional*, defaults to `0`):
342 |                 The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
343 |                 or dim=1 (sequence length).
344 |         """
345 |         if dim not in [0, 1]:
346 |             raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
347 | 
348 |         # By default chunk size is 1
349 |         chunk_size = chunk_size or 1
350 | 
351 |         def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
352 |             if hasattr(module, "set_chunk_feed_forward"):
353 |                 module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
354 | 
355 |             for child in module.children():
356 |                 fn_recursive_feed_forward(child, chunk_size, dim)
357 | 
358 |         for module in self.children():
359 |             fn_recursive_feed_forward(module, chunk_size, dim)
360 | 
361 |     def forward(
362 |         self,
363 |         sample: torch.FloatTensor,
364 |         timestep: Union[torch.Tensor, float, int],
365 |         encoder_hidden_states: torch.Tensor,
366 |         down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
367 |         mid_block_additional_residual: Optional[torch.Tensor] = None,
368 |         return_dict: bool = True,
369 |         added_time_ids: Optional[torch.Tensor] = None,
370 |     ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
371 |         r"""
372 |         The [`UNetSpatioTemporalConditionControlNetModel`] forward method.
373 | 
374 |         Args:
375 |             sample (`torch.FloatTensor`):
376 |                 The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
377 |             timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
378 |             encoder_hidden_states (`torch.FloatTensor`):
379 |                 The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
380 |             added_time_ids (`torch.FloatTensor`):
381 |                 The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
382 |                 embeddings and added to the time embeddings.
383 |             return_dict (`bool`, *optional*, defaults to `True`):
384 |                 Whether or not to return a [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain
385 |                 tuple.
386 |         Returns:
387 |             [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:
388 |                 If `return_dict` is True, an [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise
389 |                 a `tuple` is returned where the first element is the sample tensor.
390 |         """
391 |         # 1. time
392 |         timesteps = timestep
393 |         if not torch.is_tensor(timesteps):
394 |             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
395 |             # This would be a good case for the `match` statement (Python 3.10+)
396 |             is_mps = sample.device.type == "mps"
397 |             if isinstance(timestep, float):
398 |                 dtype = torch.float32 if is_mps else torch.float64
399 |             else:
400 |                 dtype = torch.int32 if is_mps else torch.int64
401 |             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
402 |         elif len(timesteps.shape) == 0:
403 |             timesteps = timesteps[None].to(sample.device)
404 | 
405 |         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
406 |         batch_size, num_frames = sample.shape[:2]
407 |         timesteps = timesteps.expand(batch_size)
408 | 
409 |         t_emb = self.time_proj(timesteps)
410 | 
411 |         # `Timesteps` does not contain any weights and will always return f32 tensors
412 |         # but time_embedding might actually be running in fp16. so we need to cast here.
413 |         # there might be better ways to encapsulate this.
414 |         t_emb = t_emb.to(dtype=sample.dtype)
415 | 
416 |         emb = self.time_embedding(t_emb)
417 | 
418 |         time_embeds = self.add_time_proj(added_time_ids.flatten())
419 |         time_embeds = time_embeds.reshape((batch_size, -1))
420 |         time_embeds = time_embeds.to(emb.dtype)
421 |         aug_emb = self.add_embedding(time_embeds)
422 |         emb = emb + aug_emb
423 | 
424 |         # Flatten the batch and frames dimensions
425 |         # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
426 |         sample = sample.flatten(0, 1)
427 |         # Repeat the embeddings num_video_frames times
428 |         # emb: [batch, channels] -> [batch * frames, channels]
429 |         emb = emb.repeat_interleave(num_frames, dim=0)
430 |         # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
431 |         encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
432 | 
433 |         # 2. pre-process
434 |         sample = self.conv_in(sample)
435 | 
436 |         image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)
437 | 
438 |         # 3. down
439 |         down_block_res_samples = (sample,)
440 |         for downsample_block in self.down_blocks:
441 |             if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
442 |                 sample, res_samples = downsample_block(
443 |                     hidden_states=sample,
444 |                     temb=emb,
445 |                     encoder_hidden_states=encoder_hidden_states,
446 |                     image_only_indicator=image_only_indicator,
447 |                 )
448 |             else:
449 |                 sample, res_samples = downsample_block(
450 |                     hidden_states=sample,
451 |                     temb=emb,
452 |                     image_only_indicator=image_only_indicator,
453 |                 )
454 | 
455 |             down_block_res_samples += res_samples
456 | 
457 |         # add the ControlNet down-block residuals to the skip connections
458 |         new_down_block_res_samples = ()
459 | 
460 |         for down_block_res_sample, down_block_additional_residual in zip(
461 |             down_block_res_samples, down_block_additional_residuals
462 |         ):
463 |             down_block_res_sample = down_block_res_sample + down_block_additional_residual
464 |             new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
465 | 
466 |         down_block_res_samples = new_down_block_res_samples
467 | 
468 |         # 4. mid
469 |         sample = self.mid_block(
470 |             hidden_states=sample,
471 |             temb=emb,
472 |             encoder_hidden_states=encoder_hidden_states,
473 |             image_only_indicator=image_only_indicator,
474 |         )
475 |         sample = sample + mid_block_additional_residual  # inject the ControlNet mid-block residual
476 | 
477 |         # 5. up
478 |         for i, upsample_block in enumerate(self.up_blocks):
479 |             res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
480 |             down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
481 | 
482 |             if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
483 |                 sample = upsample_block(
484 |                     hidden_states=sample,
485 |                     temb=emb,
486 |                     res_hidden_states_tuple=res_samples,
487 |                     encoder_hidden_states=encoder_hidden_states,
488 |                     image_only_indicator=image_only_indicator,
489 |                 )
490 |             else:
491 |                 sample = upsample_block(
492 |                     hidden_states=sample,
493 |                     temb=emb,
494 |                     res_hidden_states_tuple=res_samples,
495 |                     image_only_indicator=image_only_indicator,
496 |                 )
497 | 
498 |         # 6. post-process
499 |         sample = self.conv_norm_out(sample)
500 |         sample = self.conv_act(sample)
501 |         sample = self.conv_out(sample)
502 | 
503 |         # 7. Reshape back to original shape
504 |         sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
505 | 
506 |         if not return_dict:
507 |             return (sample,)
508 | 
509 |         return UNetSpatioTemporalConditionOutput(sample=sample)
510 | 
--------------------------------------------------------------------------------
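The UNet above only consumes ControlNet residuals; it does not produce them. The sketch below is not part of the repository: it is a shape-level smoke test of the `forward` contract, run from the repository root, that feeds dummy tensors and zero residuals in place of `ControlNetSDVModel` outputs. The batch size, frame count, latent resolution, and the residual shape list are assumptions derived by hand from the default config (`block_out_channels=(320, 640, 1280, 1280)`, `layers_per_block=2`, downsampling in the first three blocks); in real use the pipeline supplies all of these.

```python
# Illustrative smoke test for UNetSpatioTemporalConditionControlNetModel.forward.
# Zero residuals stand in for ControlNetSDVModel outputs; shapes below are assumed
# to match the default four-block layout (see __init__ above).
import torch

from unet_spatio_temporal_condition_controlnet import UNetSpatioTemporalConditionControlNetModel

batch, frames, h, w = 1, 2, 32, 32                    # latent-space height/width (pixel size / 8)
unet = UNetSpatioTemporalConditionControlNetModel()   # default (randomly initialized) config
unet.eval()

sample = torch.randn(batch, frames, 8, h, w)          # in_channels = 8 (noisy latents + image latents)
encoder_hidden_states = torch.randn(batch, 1, 1024)   # cross_attention_dim = 1024 (CLIP image embedding)
added_time_ids = torch.zeros(batch, 3)                # 3 ids * addition_time_embed_dim (256) = 768
timestep = torch.tensor(10)

# One residual per entry of down_block_res_samples: the conv_in output plus every
# resnet / downsampler output of the four down blocks (12 tensors for the default config).
residual_shapes = [
    (320, h, w), (320, h, w), (320, h, w), (320, h // 2, w // 2),
    (640, h // 2, w // 2), (640, h // 2, w // 2), (640, h // 4, w // 4),
    (1280, h // 4, w // 4), (1280, h // 4, w // 4), (1280, h // 8, w // 8),
    (1280, h // 8, w // 8), (1280, h // 8, w // 8),
]
down_residuals = tuple(torch.zeros(batch * frames, *s) for s in residual_shapes)
mid_residual = torch.zeros(batch * frames, 1280, h // 8, w // 8)

with torch.no_grad():
    out = unet(
        sample,
        timestep,
        encoder_hidden_states=encoder_hidden_states,
        added_time_ids=added_time_ids,
        down_block_additional_residuals=down_residuals,
        mid_block_additional_residual=mid_residual,
    )

print(out.sample.shape)  # torch.Size([1, 2, 4, 32, 32]) -- out_channels = 4
```

In practice the trained weights under `COMFYUI_PATH/models/animate_anyone/unet` would be loaded with the standard diffusers `from_pretrained` call rather than random initialization, and `unet.enable_forward_chunking()` (defined above) can be used to trade speed for memory by chunking the feed-forward layers.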