├── ID_Animator_node.py ├── LICENSE ├── README.md ├── __init__.py ├── animatediff ├── models │ ├── attention.py │ ├── attention_bkp.py │ ├── file │ ├── motion_module.py │ ├── motion_module_bkp.py │ ├── resnet.py │ ├── sparse_controlnet.py │ ├── unet.py │ └── unet_blocks.py ├── pipelines │ └── pipeline_animation.py └── utils │ ├── convert_from_ckpt.py │ ├── convert_lora_safetensor_to_diffusers.py │ ├── convert_original_stable_diffusion_to_diffusers.py │ └── util.py ├── demo ├── ComfyUI_ID_Animator.gif ├── example.json └── lecun.png ├── faceadapter ├── attention_processor.py ├── face_adapter.py ├── init.py ├── resampler.py └── utils.py ├── if miss module check this requirements.txt ├── inference-v2.yaml ├── models ├── adapter │ └── put adapter file here ├── animatediff_models │ └── put animatediff_models here ├── image_encoder │ ├── config.json │ └── put image_encoder model here └── text_encoder │ ├── config.json │ └── put text_encoder model here └── pyproject.toml /ID_Animator_node.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | 4 | 5 | import os 6 | import random 7 | import sys 8 | from einops import rearrange 9 | import torchvision 10 | import yaml 11 | import torch 12 | import cv2 13 | from PIL import Image 14 | import numpy as np 15 | from huggingface_hub import hf_hub_download 16 | 17 | from transformers import CLIPTextModel, CLIPTokenizer 18 | from omegaconf import OmegaConf 19 | from safetensors import safe_open 20 | from insightface.app import FaceAnalysis 21 | from insightface.utils import face_align 22 | from diffusers import (AutoencoderKL, DDIMScheduler, ControlNetModel, 23 | KDPM2AncestralDiscreteScheduler, LMSDiscreteScheduler, 24 | AutoPipelineForInpainting, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, 25 | EulerDiscreteScheduler, HeunDiscreteScheduler, UNet2DConditionModel, 26 | KDPM2DiscreteScheduler, 27 | EulerAncestralDiscreteScheduler, UniPCMultistepScheduler, 28 | StableDiffusionXLControlNetPipeline, DDPMScheduler, TCDScheduler, LCMScheduler) 29 | 30 | from .faceadapter.face_adapter import FaceAdapterPlusForVideoLora 31 | from .animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, \ 32 | convert_ldm_vae_checkpoint 33 | from .animatediff.utils.util import load_weights 34 | from .animatediff.pipelines.pipeline_animation import AnimationPipeline 35 | from .animatediff.models.unet import UNet3DConditionModel 36 | 37 | import folder_paths 38 | 39 | dir_path = os.path.dirname(os.path.abspath(__file__)) 40 | path_dir = os.path.dirname(dir_path) 41 | file_path = os.path.dirname(path_dir) 42 | 43 | motion_path = os.path.join(dir_path, "models","animatediff_models",) 44 | motion_model_list = os.listdir(motion_path) 45 | #print(motion_model_list) 46 | 47 | adapter_lora_path = os.path.join(dir_path, "models","adapter") 48 | fonts_lists = os.listdir(adapter_lora_path) 49 | 50 | 51 | 52 | paths = [] 53 | for search_path in folder_paths.get_folder_paths("diffusers"): 54 | if os.path.exists(search_path): 55 | for root, subdir, files in os.walk(search_path, followlinks=True): 56 | if "model_index.json" in files: 57 | paths.append(os.path.relpath(root, start=search_path)) 58 | 59 | if paths != []: 60 | paths = ["none"] + [x for x in paths if x] 61 | else: 62 | paths = ["none", ] 63 | 64 | def tensor_to_image(tensor): 65 | # tensor = tensor.cpu() 66 | image_np = tensor.squeeze().mul(255).clamp(0, 255).byte().numpy() 67 | image = Image.fromarray(image_np, mode='RGB') 68 | return image 69 | def phi2narry(img): 70 | img = torch.from_numpy(np.array(img).astype(np.float32) / 255.0).unsqueeze(0) 71 | return img 72 | 73 | def narry_list(list_in): 74 | for i in range(len(list_in)): 75 | value = list_in[i] 76 | modified_value = phi2narry(value) 77 | list_in[i] = modified_value 78 | return list_in 79 | 80 | 81 | scheduler_list = [ 82 | "DDIM", 83 | "DDPM", 84 | "DPM++ 2M", 85 | "DPM++ 2M Karras", 86 | "DPM++ 2M SDE", 87 | "DPM++ 2M SDE Karras", 88 | "DPM++ SDE", 89 | "DPM++ SDE Karras", 90 | "DPM2", 91 | "DPM2 Karras", 92 | "DPM2 a", 93 | "DPM2 a Karras", 94 | "Heun", 95 | "LCM", 96 | "LMS", 97 | "LMS Karras", 98 | "UniPC", 99 | "UniPC_Bh2", 100 | ] 101 | 102 | 103 | def get_sheduler(name): 104 | scheduler = False 105 | if name == "DDIM": 106 | scheduler = DDIMScheduler 107 | elif name == "DDPM": 108 | scheduler = DDPMScheduler 109 | elif name == "DPM++ 2M": 110 | scheduler = DPMSolverMultistepScheduler 111 | elif name == "DPM++ 2M Karras": 112 | scheduler = DPMSolverMultistepScheduler(use_karras_sigmas=True) 113 | elif name == "DPM++ 2M SDE": 114 | scheduler = DPMSolverMultistepScheduler(algorithm_type="sde-dpmsolver++") 115 | elif name == "DPM++ 2M SDE Karras": 116 | scheduler = DPMSolverMultistepScheduler(use_karras_sigmas=True, algorithm_type="sde-dpmsolver++") 117 | elif name == "DPM++ SDE": 118 | scheduler = DPMSolverSinglestepScheduler 119 | elif name == "DPM++ SDE Karras": 120 | scheduler = DPMSolverSinglestepScheduler(use_karras_sigmas=True) 121 | elif name == "DPM2": 122 | scheduler = KDPM2DiscreteScheduler 123 | elif name == "DPM2 Karras": 124 | scheduler = KDPM2DiscreteScheduler(use_karras_sigmas=True) 125 | elif name == "DPM2 a": 126 | scheduler = KDPM2AncestralDiscreteScheduler 127 | elif name == "DPM2 a Karras": 128 | scheduler = KDPM2AncestralDiscreteScheduler(use_karras_sigmas=True) 129 | elif name == "Heun": 130 | scheduler = HeunDiscreteScheduler 131 | elif name == "LCD": 132 | scheduler = LCMScheduler 133 | elif name == "LMS": 134 | scheduler = LMSDiscreteScheduler 135 | elif name == "LMS Karras": 136 | scheduler = LMSDiscreteScheduler(use_karras_sigmas=True) 137 | elif name == "UniPC_Bh1": 138 | scheduler = UniPCMultistepScheduler(solver_type="bh1") 139 | elif name == "UniPC_Bh2": 140 | scheduler = UniPCMultistepScheduler(solver_type="bh2") 141 | return scheduler 142 | 143 | 144 | def get_local_path(file_path, model_path): 145 | path = os.path.join(file_path, "models", "diffusers", model_path) 146 | model_path = os.path.normpath(path) 147 | if sys.platform=='win32': 148 | model_path = model_path.replace('\\', "/") 149 | return model_path 150 | 151 | 152 | def get_instance_path(path): 153 | os_path = os.path.normpath(path) 154 | if sys.platform=='win32': 155 | os_path = os_path.replace('\\', "/") 156 | return os_path 157 | 158 | 159 | class ID_Animator: 160 | 161 | def __init__(self): 162 | pass 163 | @classmethod 164 | def INPUT_TYPES(cls): 165 | return { 166 | "required": { 167 | "image": ("IMAGE",), 168 | "repo_id": ("STRING", {"forceInput": True}), 169 | "prompt": ("STRING", {"multiline": True, 170 | "default": "A girl smiling,8k,best quality."}), 171 | "negative_prompt": ("STRING", {"multiline": True, 172 | "default": "semi-realistic, cgi, 3d, render, sketch, cartoon," 173 | " drawing, anime, text, close up, cropped, out of frame," 174 | " worst quality, low quality, jpeg artifacts, ugly, duplicate," 175 | " morbid, mutilated, extra fingers, mutated hands, poorly drawn hands," 176 | " poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, " 177 | "bad proportions, extra limbs, cloned face, disfigured, gross proportions," 178 | " malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers," 179 | " too many fingers, long neck"}), 180 | "scheduler": (scheduler_list,), 181 | "adapter_lora": (["none"]+fonts_lists,), 182 | "adapter_lora_scale": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 20.0, "step": 0.1,}), 183 | "face_lora": (folder_paths.get_filename_list("loras"),), 184 | "lora_alpha": ("FLOAT", {"default": 0.8, "min": 0.1, "max": 20.0, "step": 0.1,}), 185 | "steps": ("INT", {"default": 30, "min": 1, "max": 2048, "step": 1, "display": "number"}), 186 | "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), 187 | "cfg": ("FLOAT", {"default": 8, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}), 188 | "height": ("INT", {"default": 512, "min": 64, "max": 8192, "step": 64, "display": "number"}), 189 | "width": ("INT", {"default": 512, "min": 64, "max": 8192, "step": 64, "display": "number"}), 190 | "video_length": ("INT", {"default": 16, "min": 1, "max": 32}), 191 | "scale": ("FLOAT", {"default": 0.8, "min": 0.1, "max": 10.0, "step": 0.1, "round": 0.01}), 192 | } 193 | } 194 | 195 | RETURN_TYPES = ("IMAGE",) 196 | RETURN_NAMES = ("image",) 197 | FUNCTION = "id_animator" 198 | CATEGORY = "ID_Animator" 199 | 200 | def load_model(self, inference_config, sd_version, scheduler, id_ckpt, image_encoder_path, dreambooth_model_path, 201 | motion_module_path, adapter_lora,adapter_lora_scale,face_lora,lora_alpha): 202 | face_lora = get_instance_path(folder_paths.get_full_path("loras", face_lora)) 203 | if adapter_lora=="none": 204 | adapter_lora= hf_hub_download( 205 | repo_id="guoyww/animatediff", 206 | filename="v3_sd15_adapter.ckpt", 207 | local_dir=get_instance_path(adapter_lora_path), 208 | ) 209 | else: 210 | adapter_lora=get_instance_path(os.path.join(adapter_lora_path,adapter_lora)) 211 | inference_config = OmegaConf.load(inference_config) 212 | 213 | tokenizer = CLIPTokenizer.from_pretrained(sd_version, subfolder="tokenizer", torch_dtype=torch.float16, 214 | ) 215 | text_encoder = CLIPTextModel.from_pretrained(sd_version, subfolder="text_encoder", torch_dtype=torch.float16, 216 | ).cuda() 217 | vae = AutoencoderKL.from_pretrained(sd_version, subfolder="vae", torch_dtype=torch.float16, 218 | ).cuda() 219 | unet = UNet3DConditionModel.from_pretrained_2d(sd_version, subfolder="unet", 220 | unet_additional_kwargs=OmegaConf.to_container( 221 | inference_config.unet_additional_kwargs) 222 | ).cuda() 223 | scheduler_used = get_sheduler(scheduler) 224 | pipeline = AnimationPipeline( 225 | vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, 226 | controlnet=None, 227 | # beta_start=0.00085, beta_end=0.012, beta_schedule="linear",steps_offset=1 228 | scheduler=scheduler_used(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs) 229 | # scheduler=EulerAncestralDiscreteScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs) 230 | # scheduler=EulerAncestralDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="linear",steps_offset=1 231 | 232 | ), torch_dtype=torch.float16, 233 | ).to("cuda") 234 | 235 | pipeline = load_weights( 236 | pipeline, 237 | # motion module 238 | motion_module_path=motion_module_path, 239 | motion_module_lora_configs=[], 240 | # domain adapter 241 | adapter_lora_path=adapter_lora, 242 | adapter_lora_scale=adapter_lora_scale, 243 | # image layers 244 | dreambooth_model_path=None, 245 | lora_model_path=face_lora, 246 | lora_alpha=lora_alpha 247 | ).to("cuda") 248 | if dreambooth_model_path != "": 249 | print(f"load dreambooth model from {dreambooth_model_path}") 250 | dreambooth_state_dict = {} 251 | with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f: 252 | for key in f.keys(): 253 | dreambooth_state_dict[key] = f.get_tensor(key) 254 | 255 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, pipeline.vae.config) 256 | # print(vae) 257 | # vae ->to_q,to_k,to_v 258 | # print(converted_vae_checkpoint) 259 | convert_vae_keys = list(converted_vae_checkpoint.keys()) 260 | for key in convert_vae_keys: 261 | if "encoder.mid_block.attentions" in key or "decoder.mid_block.attentions" in key: 262 | new_key = None 263 | if "key" in key: 264 | new_key = key.replace("key", "to_k") 265 | elif "query" in key: 266 | new_key = key.replace("query", "to_q") 267 | elif "value" in key: 268 | new_key = key.replace("value", "to_v") 269 | elif "proj_attn" in key: 270 | new_key = key.replace("proj_attn", "to_out.0") 271 | if new_key: 272 | converted_vae_checkpoint[new_key] = converted_vae_checkpoint.pop(key) 273 | 274 | pipeline.vae.load_state_dict(converted_vae_checkpoint) 275 | 276 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, pipeline.unet.config) 277 | pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False) 278 | 279 | pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict).to("cuda") 280 | del dreambooth_state_dict 281 | pipeline = pipeline.to(torch.float16) 282 | id_animator = FaceAdapterPlusForVideoLora(pipeline, image_encoder_path, id_ckpt, num_tokens=16, 283 | device=torch.device("cuda"), torch_type=torch.float16) 284 | return id_animator 285 | 286 | def get_video_img(self,videos,rescale=False, n_rows=6): 287 | videos = rearrange(videos, "b c t h w -> t b c h w") 288 | outputs = [] 289 | for x in videos: 290 | x = torchvision.utils.make_grid(x, nrow=n_rows) 291 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 292 | if rescale: 293 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 294 | x=tensor_to_image(x) 295 | outputs.append(x) 296 | return outputs 297 | 298 | 299 | def id_animator(self, image, repo_id,prompt, negative_prompt, scheduler,adapter_lora,adapter_lora_scale,face_lora,lora_alpha 300 | ,steps,seed, cfg, height, width, video_length,scale): 301 | repo_id,dreambooth_model_path,motion_models=repo_id.split(",",2) 302 | inference_config = get_instance_path(os.path.join(dir_path,"inference-v2.yaml")) 303 | id_ckpt = get_instance_path(os.path.join(dir_path, "models", "animator.ckpt")) 304 | image_encoder_path = get_instance_path(os.path.join(dir_path, "models","image_encoder")) 305 | app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) 306 | app.prepare(ctx_id=0, det_size=(320, 320)) 307 | 308 | animator = self.load_model(inference_config, repo_id, scheduler, id_ckpt, image_encoder_path, 309 | dreambooth_model_path, motion_models,adapter_lora,adapter_lora_scale,face_lora,lora_alpha) 310 | 311 | Pil_img = tensor_to_image(image) 312 | img = cv2.cvtColor(np.asarray(Pil_img), cv2.COLOR_RGB2BGR) 313 | faces = app.get(img) 314 | face_roi = face_align.norm_crop(img, faces[0]['kps'], 112) 315 | face_roi = cv2.cvtColor(face_roi, cv2.COLOR_BGR2RGB) 316 | pil_image = [Image.fromarray(face_roi).resize((224, 224))] 317 | sample = animator.generate(pil_image, negative_prompt=negative_prompt, prompt=prompt, num_inference_steps=steps, 318 | seed=seed, 319 | guidance_scale=cfg, 320 | width=width, 321 | height=height, 322 | video_length=video_length, 323 | scale=scale, 324 | ) 325 | 326 | gen =self.get_video_img(sample) # 获取生成动画单帧的pli列表 327 | gen =narry_list(gen) # 列表排序 328 | images = torch.from_numpy(np.fromiter(gen, np.dtype((np.float32, (height, width, 3))))) # numpy 329 | return (images,) 330 | 331 | class ID_Repo_Choice: 332 | def __init__(self): 333 | pass 334 | 335 | @classmethod 336 | def INPUT_TYPES(cls): 337 | return { 338 | "required": { 339 | "local_model_path": (paths,), 340 | "repo_id": ("STRING", {"default": "runwayml/stable-diffusion-v1-5"}), 341 | "ckpt_name": (folder_paths.get_filename_list("checkpoints"),), 342 | "motion_model": (["none"]+motion_model_list,) 343 | } 344 | } 345 | 346 | RETURN_TYPES = ("STRING",) 347 | RETURN_NAMES = ("repo_id",) 348 | FUNCTION = "repo_choice" 349 | CATEGORY = "ID_Animator" 350 | 351 | def repo_choice(self, local_model_path, repo_id,ckpt_name,motion_model): 352 | motion_model_path = os.path.join(dir_path, "models", "animatediff_models") 353 | if repo_id == "": 354 | if local_model_path == "none": 355 | raise "you need fill repo_id or download model in diffusers directory " 356 | elif local_model_path != "none": 357 | model_path = get_local_path(file_path, local_model_path) 358 | repo_id = get_instance_path(model_path) 359 | elif repo_id != "" and repo_id.find("/") == -1: 360 | raise "Incorrect repo_id format" 361 | elif repo_id != "" and repo_id.find("\\") != -1: 362 | repo_id = get_instance_path(repo_id) 363 | if motion_model =="none": 364 | motion_path = hf_hub_download( 365 | repo_id="guoyww/animatediff", 366 | filename="mm_sd_v15_v2.ckpt", 367 | local_dir=get_instance_path(motion_model_path), 368 | ) 369 | else: 370 | motion_path = get_instance_path(os.path.join(motion_model_path, motion_model)) 371 | ckpt_path = get_instance_path(folder_paths.get_full_path("checkpoints", ckpt_name)) 372 | 373 | repo_id=repo_id +","+ ckpt_path +","+ motion_path 374 | return (repo_id,) 375 | 376 | 377 | NODE_CLASS_MAPPINGS = { 378 | "ID_Animator": ID_Animator, 379 | "ID_Repo_Choice":ID_Repo_Choice 380 | } 381 | 382 | NODE_DISPLAY_NAME_MAPPINGS = { 383 | "ID_Animator": "ID_Animator", 384 | "ID_Repo_Choice":"ID_Repo_Choice" 385 | } 386 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A node using ID_Animator in comfyUI 2 | 3 | ## NOTICE 4 | You can find ID_Animator in this link [ID_Animator](https://github.com/ID-Animator/ID-Animator) 5 | 6 | My ComfyUI node list: 7 | ----- 8 | 9 | 1、ParlerTTS node:[ComfyUI_ParlerTTS](https://github.com/smthemex/ComfyUI_ParlerTTS) 10 | 11 | 2、Llama3_8B node:[ComfyUI_Llama3_8B](https://github.com/smthemex/ComfyUI_Llama3_8B) 12 | 13 | 3、HiDiffusion node:[ComfyUI_HiDiffusion_Pro](https://github.com/smthemex/ComfyUI_HiDiffusion_Pro) 14 | 15 | 4、ID_Animator node: [ComfyUI_ID_Animator](https://github.com/smthemex/ComfyUI_ID_Animator) 16 | 17 | 5、StoryDiffusion node:[ComfyUI_StoryDiffusion](https://github.com/smthemex/ComfyUI_StoryDiffusion) 18 | 19 | 6、Pops node:[ComfyUI_Pops](https://github.com/smthemex/ComfyUI_Pops) 20 | 21 | 7、stable-audio-open-1.0 node :[ComfyUI_StableAudio_Open](https://github.com/smthemex/ComfyUI_StableAudio_Open) 22 | 23 | 8、GLM4 node:[ComfyUI_ChatGLM_API](https://github.com/smthemex/ComfyUI_ChatGLM_API) 24 | 25 | 9、CustomNet node:[ComfyUI_CustomNet](https://github.com/smthemex/ComfyUI_CustomNet) 26 | 27 | 10、Pipeline_Tool node :[ComfyUI_Pipeline_Tool](https://github.com/smthemex/ComfyUI_Pipeline_Tool) 28 | 29 | 11、Pic2Story node :[ComfyUI_Pic2Story](https://github.com/smthemex/ComfyUI_Pic2Story) 30 | 31 | 12、PBR_Maker node:[ComfyUI_PBR_Maker](https://github.com/smthemex/ComfyUI_PBR_Maker) 32 | 33 | Update 34 | --- 35 | 2024-06-15 36 | 37 | 1、修复animateddiff帧率上限为32的问题。感谢ShmuelRonen 的提醒 38 | 2、加入face_lora 及lora_adapter的条件控制,模型地址在下面的模型说明里。 39 | 3、加入diffuser 0.28.0以上版本的支持 40 | 41 | 1. Fix the issue of animateddiff with a maximum frame rate of 32. Thank you for [ShmuelRonen](https://github.com/ShmuelRonen) 42 | 's reminder 43 | 2. Add conditional control for "face_lora" and "lora-adapter", and the model address is provided in the model description below. 44 | 3. . Add support for diffuser versions 0.28.0 and above 45 | 46 | --- 既往更新 Previous updates 47 | 48 | 1、输出改成单帧图像,方便接其他的视频合成节点,取消原作保存gif动画的选项。 49 | 2、新增模型加载菜单,逻辑上更清晰一些,你可以多放几个动作模型进“.. ComfyUI_ID_Animator/models/animatediff_models”目录 50 | 51 | 1. Change the output to a single frame image for easy access to other video synthesis nodes, and remove the option to save the original GIF animation. 52 | 2. Add a new model loading menu to make the logic clearer. You can add a few more action models to the ".. ComfyUI-ID-Animator/models/animateddiff_models" directory 53 | 54 | 1.Installation 安装 55 | ---- 56 | ``` python 57 | git https://github.com/smthemex/ComfyUI_ID_Animator.git 58 | ``` 59 | 2 Dependencies 需求库 60 | ----- 61 | If the module is missing, please refer to the separate installation of the missing module in the "if miss module check this requirements.txt" file 62 | 63 | 如果缺失模块,请打开"if miss module check this requirements.txt",单独安装缺失的模块 64 | 65 | 66 | 3 Download the checkpoints 下载模型 67 | ---- 68 | 69 | 3.1 dir.. ComfyUI_ID_Animator/models 70 | - Download ID-Animator checkpoint:"animator.ckpt" [link](https://huggingface.co/spaces/ID-Animator/ID-Animator/blob/main/) 71 | 72 | 3.2 dir.. ComfyUI_ID_Animator/models/animatediff_models 73 | - Download AnimateDiff checkpoint like "/mm_sd_v15_v2.ckpt" [link](https://huggingface.co/spaces/ID-Animator/ID-Animator/blob/main/) 74 | 75 | 3.3 dir.. comfy/models/diffusers 76 | - Download Stable Diffusion V1.5 all files [link](https://huggingface.co/spaces/ID-Animator/ID-Animator/tree/main/animatediff/sd) 77 | - or 78 | - Download Stable Diffusion V1.5 most files [link](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main) 79 | 80 | 3.4 dir.. comfy/models/checkpoints 81 | - Download "realisticVisionV60B1_v51VAE.safetensors" [link](https://huggingface.co/spaces/ID-Animator/ID-Animator/blob/main/) 82 | or any other dreambooth models 83 | 84 | 3.5 dir.. ComfyUI_ID_Animator/models/image_encoder 85 | - Download CLIP Image encoder [link](https://huggingface.co/spaces/ID-Animator/ID-Animator/tree/main/image_encoder) 86 | 87 | 3.6 dir.. ComfyUI_ID_Animator/models/adapter 88 | - Download "v3_sd15_adapter.ckpt" [link](https://huggingface.co/guoyww/animatediff/tree/main) 89 | 90 | 3.7 other models 91 | The first run will download the insightface models to the "X/user/username/.insightface/models/buffalo_l" directory 92 | 93 | 4 other 其他 94 | ---- 95 | 因为"ID_Animator"作者没有标注开源许可协议,所以我暂时把开源许可协议设置为Apache-2.0 license 96 | Because "ID_Animator"does not indicate the open source license agreement, I have temporarily set the open source license agreement to Apache-2.0 license 97 | 98 | 5 example 示例 99 | ---- 100 | 101 | ![](https://github.com/smthemex/ComfyUI_ID_Animator/blob/main/demo/ComfyUI_ID_Animator.gif) 102 | 103 | 104 | 105 | 6 Contact "ID_Animator" 106 | ----- 107 | Xuanhua He: hexuanhua@mail.ustc.edu.cn 108 | 109 | Quande Liu: qdliu0226@gmail.com 110 | 111 | Shengju Qian: thesouthfrog@gmail.com 112 | 113 | AnimateDif 114 | --- 115 | ``` 116 | @article{guo2023animatediff, 117 | title={AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning}, 118 | author={Guo, Yuwei and Yang, Ceyuan and Rao, Anyi and Liang, Zhengyang and Wang, Yaohui and Qiao, Yu and Agrawala, Maneesh and Lin, Dahua and Dai, Bo}, 119 | journal={International Conference on Learning Representations}, 120 | year={2024} 121 | } 122 | 123 | @article{guo2023sparsectrl, 124 | title={SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion Models}, 125 | author={Guo, Yuwei and Yang, Ceyuan and Rao, Anyi and Agrawala, Maneesh and Lin, Dahua and Dai, Bo}, 126 | journal={arXiv preprint arXiv:2311.16933}, 127 | year={2023} 128 | } 129 | ``` 130 | 131 | 132 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | python = sys.executable 4 | 5 | from .ID_Animator_node import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 6 | 7 | 8 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] 9 | -------------------------------------------------------------------------------- /animatediff/models/attention.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py 2 | 3 | from dataclasses import dataclass 4 | from typing import Optional 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn 9 | 10 | from diffusers.configuration_utils import ConfigMixin, register_to_config 11 | from diffusers import ModelMixin 12 | from diffusers.utils import BaseOutput 13 | from diffusers.utils.import_utils import is_xformers_available 14 | from diffusers.models.attention import FeedForward, AdaLayerNorm,Attention 15 | 16 | from einops import rearrange, repeat 17 | import pdb 18 | 19 | from diffusers.models.attention_processor import AttnProcessor,AttnProcessor2_0 20 | @dataclass 21 | class Transformer3DModelOutput(BaseOutput): 22 | sample: torch.FloatTensor 23 | from diffusers.utils import logging 24 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 25 | 26 | if is_xformers_available(): 27 | import xformers 28 | import xformers.ops 29 | else: 30 | xformers = None 31 | 32 | 33 | class Transformer3DModel(ModelMixin, ConfigMixin): 34 | @register_to_config 35 | def __init__( 36 | self, 37 | num_attention_heads: int = 16, 38 | attention_head_dim: int = 88, 39 | in_channels: Optional[int] = None, 40 | num_layers: int = 1, 41 | dropout: float = 0.0, 42 | norm_num_groups: int = 32, 43 | cross_attention_dim: Optional[int] = None, 44 | attention_bias: bool = False, 45 | activation_fn: str = "geglu", 46 | num_embeds_ada_norm: Optional[int] = None, 47 | use_linear_projection: bool = False, 48 | only_cross_attention: bool = False, 49 | upcast_attention: bool = False, 50 | unet_use_cross_frame_attention=None, 51 | unet_use_temporal_attention=None, 52 | processor: Optional["AttnProcessor"] = None, 53 | ): 54 | super().__init__() 55 | self.use_linear_projection = use_linear_projection 56 | self.num_attention_heads = num_attention_heads 57 | self.attention_head_dim = attention_head_dim 58 | inner_dim = num_attention_heads * attention_head_dim 59 | 60 | # Define input layers 61 | self.in_channels = in_channels 62 | 63 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 64 | if use_linear_projection: 65 | self.proj_in = nn.Linear(in_channels, inner_dim) 66 | else: 67 | self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 68 | 69 | # Define transformers blocks 70 | self.transformer_blocks = nn.ModuleList( 71 | [ 72 | BasicTransformerBlock( 73 | inner_dim, 74 | num_attention_heads, 75 | attention_head_dim, 76 | dropout=dropout, 77 | cross_attention_dim=cross_attention_dim, 78 | activation_fn=activation_fn, 79 | num_embeds_ada_norm=num_embeds_ada_norm, 80 | attention_bias=attention_bias, 81 | only_cross_attention=only_cross_attention, 82 | upcast_attention=upcast_attention, 83 | 84 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 85 | unet_use_temporal_attention=unet_use_temporal_attention, 86 | ) 87 | for d in range(num_layers) 88 | ] 89 | ) 90 | 91 | # 4. Define output layers 92 | if use_linear_projection: 93 | self.proj_out = nn.Linear(in_channels, inner_dim) 94 | else: 95 | self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 96 | # if processor is None: 97 | # processor = ( 98 | # AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() 99 | # ) 100 | # self.set_processor(processor) 101 | # def set_processor(self, processor: "AttnProcessor") -> None: 102 | # r""" 103 | # Set the attention processor to use. 104 | 105 | # Args: 106 | # processor (`AttnProcessor`): 107 | # The attention processor to use. 108 | # """ 109 | # # if current processor is in `self._modules` and if passed `processor` is not, we need to 110 | # # pop `processor` from `self._modules` 111 | # if ( 112 | # hasattr(self, "processor") 113 | # and isinstance(self.processor, torch.nn.Module) 114 | # and not isinstance(processor, torch.nn.Module) 115 | # ): 116 | # logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") 117 | # self._modules.pop("processor") 118 | 119 | # self.processor = processor 120 | 121 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): 122 | # Input 123 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 124 | video_length = hidden_states.shape[2] 125 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 126 | encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length) 127 | 128 | batch, channel, height, weight = hidden_states.shape 129 | residual = hidden_states 130 | 131 | hidden_states = self.norm(hidden_states) 132 | if not self.use_linear_projection: 133 | hidden_states = self.proj_in(hidden_states) 134 | inner_dim = hidden_states.shape[1] 135 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 136 | else: 137 | inner_dim = hidden_states.shape[1] 138 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 139 | hidden_states = self.proj_in(hidden_states) 140 | 141 | # Blocks 142 | for block in self.transformer_blocks: 143 | hidden_states = block( 144 | hidden_states, 145 | encoder_hidden_states=encoder_hidden_states, 146 | timestep=timestep, 147 | video_length=video_length 148 | ) 149 | 150 | # Output 151 | if not self.use_linear_projection: 152 | hidden_states = ( 153 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 154 | ) 155 | hidden_states = self.proj_out(hidden_states) 156 | else: 157 | hidden_states = self.proj_out(hidden_states) 158 | hidden_states = ( 159 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 160 | ) 161 | 162 | output = hidden_states + residual 163 | 164 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 165 | if not return_dict: 166 | return (output,) 167 | 168 | return Transformer3DModelOutput(sample=output) 169 | 170 | 171 | class BasicTransformerBlock(nn.Module): 172 | def __init__( 173 | self, 174 | dim: int, 175 | num_attention_heads: int, 176 | attention_head_dim: int, 177 | dropout=0.0, 178 | cross_attention_dim: Optional[int] = None, 179 | activation_fn: str = "geglu", 180 | num_embeds_ada_norm: Optional[int] = None, 181 | attention_bias: bool = False, 182 | only_cross_attention: bool = False, 183 | upcast_attention: bool = False, 184 | 185 | unet_use_cross_frame_attention = None, 186 | unet_use_temporal_attention = None, 187 | ): 188 | super().__init__() 189 | self.only_cross_attention = only_cross_attention 190 | self.use_ada_layer_norm = num_embeds_ada_norm is not None 191 | self.unet_use_cross_frame_attention = unet_use_cross_frame_attention 192 | self.unet_use_temporal_attention = unet_use_temporal_attention 193 | 194 | # SC-Attn 195 | assert unet_use_cross_frame_attention is not None 196 | if unet_use_cross_frame_attention: 197 | self.attn1 = SparseCausalAttention2D( 198 | query_dim=dim, 199 | heads=num_attention_heads, 200 | dim_head=attention_head_dim, 201 | dropout=dropout, 202 | bias=attention_bias, 203 | cross_attention_dim=cross_attention_dim if only_cross_attention else None, 204 | upcast_attention=upcast_attention, 205 | ) 206 | else: 207 | #self-attention 208 | self.attn1 = Attention( 209 | query_dim=dim, 210 | heads=num_attention_heads, 211 | dim_head=attention_head_dim, 212 | dropout=dropout, 213 | bias=attention_bias, 214 | upcast_attention=upcast_attention, 215 | cross_attention_dim=None, 216 | ) 217 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 218 | 219 | # Cross-Attn 220 | if cross_attention_dim is not None: 221 | self.attn2 = Attention( 222 | query_dim=dim, 223 | cross_attention_dim=cross_attention_dim, 224 | heads=num_attention_heads, 225 | dim_head=attention_head_dim, 226 | dropout=dropout, 227 | bias=attention_bias, 228 | upcast_attention=upcast_attention, 229 | ) 230 | else: 231 | self.attn2 = None 232 | 233 | if cross_attention_dim is not None: 234 | self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 235 | else: 236 | self.norm2 = None 237 | 238 | # Feed-forward 239 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) 240 | self.norm3 = nn.LayerNorm(dim) 241 | 242 | # Temp-Attn 243 | assert unet_use_temporal_attention is not None 244 | if unet_use_temporal_attention: 245 | self.attn_temp = Attention( 246 | query_dim=dim, 247 | heads=num_attention_heads, 248 | dim_head=attention_head_dim, 249 | dropout=dropout, 250 | bias=attention_bias, 251 | upcast_attention=upcast_attention, 252 | ) 253 | nn.init.zeros_(self.attn_temp.to_out[0].weight.data) 254 | self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 255 | 256 | def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool,attention_op = None): 257 | if not is_xformers_available(): 258 | print("Here is how to install it") 259 | raise ModuleNotFoundError( 260 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install" 261 | " xformers", 262 | name="xformers", 263 | ) 264 | elif not torch.cuda.is_available(): 265 | raise ValueError( 266 | "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" 267 | " available for GPU " 268 | ) 269 | else: 270 | try: 271 | # Make sure we can run the memory efficient attention 272 | _ = xformers.ops.memory_efficient_attention( 273 | torch.randn((1, 2, 40), device="cuda"), 274 | torch.randn((1, 2, 40), device="cuda"), 275 | torch.randn((1, 2, 40), device="cuda"), 276 | ) 277 | except Exception as e: 278 | raise e 279 | self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers 280 | if self.attn2 is not None: 281 | self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers 282 | # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers 283 | 284 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None): 285 | # SparseCausal-Attention 286 | norm_hidden_states = ( 287 | self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) 288 | ) 289 | 290 | # if self.only_cross_attention: 291 | # hidden_states = ( 292 | # self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states 293 | # ) 294 | # else: 295 | # hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states 296 | 297 | # pdb.set_trace() 298 | if self.unet_use_cross_frame_attention: 299 | hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states 300 | else: 301 | hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states 302 | 303 | if self.attn2 is not None: 304 | # Cross-Attention 305 | norm_hidden_states = ( 306 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) 307 | ) 308 | hidden_states = ( 309 | self.attn2( 310 | norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask 311 | ) 312 | + hidden_states 313 | ) 314 | # Feed-forward 315 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states 316 | 317 | # Temporal-Attention 318 | if self.unet_use_temporal_attention: 319 | d = hidden_states.shape[1] 320 | hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) 321 | norm_hidden_states = ( 322 | self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states) 323 | ) 324 | hidden_states = self.attn_temp(norm_hidden_states) + hidden_states 325 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) 326 | 327 | return hidden_states 328 | -------------------------------------------------------------------------------- /animatediff/models/attention_bkp.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py 2 | 3 | from dataclasses import dataclass 4 | from typing import Optional 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn 9 | 10 | from diffusers.configuration_utils import ConfigMixin, register_to_config 11 | from diffusers import ModelMixin 12 | from diffusers.utils import BaseOutput 13 | from diffusers.utils.import_utils import is_xformers_available 14 | from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm 15 | 16 | from einops import rearrange, repeat 17 | import pdb 18 | 19 | from diffusers.models.attention_processor import AttnProcessor,AttnProcessor2_0 20 | @dataclass 21 | class Transformer3DModelOutput(BaseOutput): 22 | sample: torch.FloatTensor 23 | from diffusers.utils import logging 24 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 25 | 26 | if is_xformers_available(): 27 | import xformers 28 | import xformers.ops 29 | else: 30 | xformers = None 31 | 32 | 33 | class Transformer3DModel(ModelMixin, ConfigMixin): 34 | @register_to_config 35 | def __init__( 36 | self, 37 | num_attention_heads: int = 16, 38 | attention_head_dim: int = 88, 39 | in_channels: Optional[int] = None, 40 | num_layers: int = 1, 41 | dropout: float = 0.0, 42 | norm_num_groups: int = 32, 43 | cross_attention_dim: Optional[int] = None, 44 | attention_bias: bool = False, 45 | activation_fn: str = "geglu", 46 | num_embeds_ada_norm: Optional[int] = None, 47 | use_linear_projection: bool = False, 48 | only_cross_attention: bool = False, 49 | upcast_attention: bool = False, 50 | unet_use_cross_frame_attention=None, 51 | unet_use_temporal_attention=None, 52 | processor: Optional["AttnProcessor"] = None, 53 | ): 54 | super().__init__() 55 | self.use_linear_projection = use_linear_projection 56 | self.num_attention_heads = num_attention_heads 57 | self.attention_head_dim = attention_head_dim 58 | inner_dim = num_attention_heads * attention_head_dim 59 | 60 | # Define input layers 61 | self.in_channels = in_channels 62 | 63 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 64 | if use_linear_projection: 65 | self.proj_in = nn.Linear(in_channels, inner_dim) 66 | else: 67 | self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 68 | 69 | # Define transformers blocks 70 | self.transformer_blocks = nn.ModuleList( 71 | [ 72 | BasicTransformerBlock( 73 | inner_dim, 74 | num_attention_heads, 75 | attention_head_dim, 76 | dropout=dropout, 77 | cross_attention_dim=cross_attention_dim, 78 | activation_fn=activation_fn, 79 | num_embeds_ada_norm=num_embeds_ada_norm, 80 | attention_bias=attention_bias, 81 | only_cross_attention=only_cross_attention, 82 | upcast_attention=upcast_attention, 83 | 84 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 85 | unet_use_temporal_attention=unet_use_temporal_attention, 86 | ) 87 | for d in range(num_layers) 88 | ] 89 | ) 90 | 91 | # 4. Define output layers 92 | if use_linear_projection: 93 | self.proj_out = nn.Linear(in_channels, inner_dim) 94 | else: 95 | self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 96 | # if processor is None: 97 | # processor = ( 98 | # AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() 99 | # ) 100 | # self.set_processor(processor) 101 | def set_processor(self, processor: "AttnProcessor") -> None: 102 | r""" 103 | Set the attention processor to use. 104 | 105 | Args: 106 | processor (`AttnProcessor`): 107 | The attention processor to use. 108 | """ 109 | # if current processor is in `self._modules` and if passed `processor` is not, we need to 110 | # pop `processor` from `self._modules` 111 | if ( 112 | hasattr(self, "processor") 113 | and isinstance(self.processor, torch.nn.Module) 114 | and not isinstance(processor, torch.nn.Module) 115 | ): 116 | logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") 117 | self._modules.pop("processor") 118 | 119 | self.processor = processor 120 | 121 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): 122 | # Input 123 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 124 | video_length = hidden_states.shape[2] 125 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 126 | encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length) 127 | 128 | batch, channel, height, weight = hidden_states.shape 129 | residual = hidden_states 130 | 131 | hidden_states = self.norm(hidden_states) 132 | if not self.use_linear_projection: 133 | hidden_states = self.proj_in(hidden_states) 134 | inner_dim = hidden_states.shape[1] 135 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 136 | else: 137 | inner_dim = hidden_states.shape[1] 138 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 139 | hidden_states = self.proj_in(hidden_states) 140 | 141 | # Blocks 142 | for block in self.transformer_blocks: 143 | hidden_states = block( 144 | hidden_states, 145 | encoder_hidden_states=encoder_hidden_states, 146 | timestep=timestep, 147 | video_length=video_length 148 | ) 149 | 150 | # Output 151 | if not self.use_linear_projection: 152 | hidden_states = ( 153 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 154 | ) 155 | hidden_states = self.proj_out(hidden_states) 156 | else: 157 | hidden_states = self.proj_out(hidden_states) 158 | hidden_states = ( 159 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 160 | ) 161 | 162 | output = hidden_states + residual 163 | 164 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 165 | if not return_dict: 166 | return (output,) 167 | 168 | return Transformer3DModelOutput(sample=output) 169 | 170 | 171 | class BasicTransformerBlock(nn.Module): 172 | def __init__( 173 | self, 174 | dim: int, 175 | num_attention_heads: int, 176 | attention_head_dim: int, 177 | dropout=0.0, 178 | cross_attention_dim: Optional[int] = None, 179 | activation_fn: str = "geglu", 180 | num_embeds_ada_norm: Optional[int] = None, 181 | attention_bias: bool = False, 182 | only_cross_attention: bool = False, 183 | upcast_attention: bool = False, 184 | 185 | unet_use_cross_frame_attention = None, 186 | unet_use_temporal_attention = None, 187 | ): 188 | super().__init__() 189 | self.only_cross_attention = only_cross_attention 190 | self.use_ada_layer_norm = num_embeds_ada_norm is not None 191 | self.unet_use_cross_frame_attention = unet_use_cross_frame_attention 192 | self.unet_use_temporal_attention = unet_use_temporal_attention 193 | 194 | # SC-Attn 195 | assert unet_use_cross_frame_attention is not None 196 | if unet_use_cross_frame_attention: 197 | self.attn1 = SparseCausalAttention2D( 198 | query_dim=dim, 199 | heads=num_attention_heads, 200 | dim_head=attention_head_dim, 201 | dropout=dropout, 202 | bias=attention_bias, 203 | cross_attention_dim=cross_attention_dim if only_cross_attention else None, 204 | upcast_attention=upcast_attention, 205 | ) 206 | else: 207 | self.attn1 = CrossAttention( 208 | query_dim=dim, 209 | heads=num_attention_heads, 210 | dim_head=attention_head_dim, 211 | dropout=dropout, 212 | bias=attention_bias, 213 | upcast_attention=upcast_attention, 214 | ) 215 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 216 | 217 | # Cross-Attn 218 | if cross_attention_dim is not None: 219 | self.attn2 = CrossAttention( 220 | query_dim=dim, 221 | cross_attention_dim=cross_attention_dim, 222 | heads=num_attention_heads, 223 | dim_head=attention_head_dim, 224 | dropout=dropout, 225 | bias=attention_bias, 226 | upcast_attention=upcast_attention, 227 | ) 228 | else: 229 | self.attn2 = None 230 | 231 | if cross_attention_dim is not None: 232 | self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 233 | else: 234 | self.norm2 = None 235 | 236 | # Feed-forward 237 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) 238 | self.norm3 = nn.LayerNorm(dim) 239 | 240 | # Temp-Attn 241 | assert unet_use_temporal_attention is not None 242 | if unet_use_temporal_attention: 243 | self.attn_temp = CrossAttention( 244 | query_dim=dim, 245 | heads=num_attention_heads, 246 | dim_head=attention_head_dim, 247 | dropout=dropout, 248 | bias=attention_bias, 249 | upcast_attention=upcast_attention, 250 | ) 251 | nn.init.zeros_(self.attn_temp.to_out[0].weight.data) 252 | self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 253 | 254 | def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool,attention_op = None): 255 | if not is_xformers_available(): 256 | print("Here is how to install it") 257 | raise ModuleNotFoundError( 258 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install" 259 | " xformers", 260 | name="xformers", 261 | ) 262 | elif not torch.cuda.is_available(): 263 | raise ValueError( 264 | "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" 265 | " available for GPU " 266 | ) 267 | else: 268 | try: 269 | # Make sure we can run the memory efficient attention 270 | _ = xformers.ops.memory_efficient_attention( 271 | torch.randn((1, 2, 40), device="cuda"), 272 | torch.randn((1, 2, 40), device="cuda"), 273 | torch.randn((1, 2, 40), device="cuda"), 274 | ) 275 | except Exception as e: 276 | raise e 277 | self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers 278 | if self.attn2 is not None: 279 | self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers 280 | # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers 281 | 282 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None): 283 | # SparseCausal-Attention 284 | norm_hidden_states = ( 285 | self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) 286 | ) 287 | 288 | # if self.only_cross_attention: 289 | # hidden_states = ( 290 | # self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states 291 | # ) 292 | # else: 293 | # hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states 294 | 295 | # pdb.set_trace() 296 | if self.unet_use_cross_frame_attention: 297 | hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states 298 | else: 299 | hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states 300 | 301 | if self.attn2 is not None: 302 | # Cross-Attention 303 | norm_hidden_states = ( 304 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) 305 | ) 306 | hidden_states = ( 307 | self.attn2( 308 | norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask 309 | ) 310 | + hidden_states 311 | ) 312 | 313 | # Feed-forward 314 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states 315 | 316 | # Temporal-Attention 317 | if self.unet_use_temporal_attention: 318 | d = hidden_states.shape[1] 319 | hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) 320 | norm_hidden_states = ( 321 | self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states) 322 | ) 323 | hidden_states = self.attn_temp(norm_hidden_states) + hidden_states 324 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) 325 | 326 | return hidden_states 327 | -------------------------------------------------------------------------------- /animatediff/models/file: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /animatediff/models/motion_module.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Optional, Tuple, Union 3 | 4 | import torch 5 | import numpy as np 6 | import torch.nn.functional as F 7 | from torch import nn 8 | import torchvision 9 | 10 | from diffusers.configuration_utils import ConfigMixin, register_to_config 11 | from diffusers import ModelMixin 12 | from diffusers.utils import BaseOutput 13 | from diffusers.utils.import_utils import is_xformers_available 14 | from diffusers.models.attention import FeedForward,Attention 15 | 16 | from einops import rearrange, repeat 17 | import math 18 | 19 | 20 | def zero_module(module): 21 | # Zero out the parameters of a module and return it. 22 | for p in module.parameters(): 23 | p.detach().zero_() 24 | return module 25 | 26 | 27 | @dataclass 28 | class TemporalTransformer3DModelOutput(BaseOutput): 29 | sample: torch.FloatTensor 30 | 31 | 32 | if is_xformers_available(): 33 | import xformers 34 | import xformers.ops 35 | else: 36 | xformers = None 37 | 38 | 39 | def get_motion_module( 40 | in_channels, 41 | motion_module_type: str, 42 | motion_module_kwargs: dict 43 | ): 44 | if motion_module_type == "Vanilla": 45 | return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,) 46 | else: 47 | raise ValueError 48 | 49 | 50 | class VanillaTemporalModule(nn.Module): 51 | def __init__( 52 | self, 53 | in_channels, 54 | num_attention_heads = 8, 55 | num_transformer_block = 2, 56 | attention_block_types =( "Temporal_Self", "Temporal_Self" ), 57 | cross_frame_attention_mode = None, 58 | temporal_position_encoding = False, 59 | temporal_position_encoding_max_len = 24, 60 | temporal_attention_dim_div = 1, 61 | zero_initialize = True, 62 | ): 63 | super().__init__() 64 | 65 | self.temporal_transformer = TemporalTransformer3DModel( 66 | in_channels=in_channels, 67 | num_attention_heads=num_attention_heads, 68 | attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div, 69 | num_layers=num_transformer_block, 70 | attention_block_types=attention_block_types, 71 | cross_frame_attention_mode=cross_frame_attention_mode, 72 | temporal_position_encoding=temporal_position_encoding, 73 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 74 | ) 75 | 76 | if zero_initialize: 77 | self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out) 78 | 79 | def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None): 80 | hidden_states = input_tensor 81 | hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask) 82 | 83 | output = hidden_states 84 | return output 85 | 86 | 87 | class TemporalTransformer3DModel(nn.Module): 88 | def __init__( 89 | self, 90 | in_channels, 91 | num_attention_heads, 92 | attention_head_dim, 93 | 94 | num_layers, 95 | attention_block_types = ( "Temporal_Self", "Temporal_Self", ), 96 | dropout = 0.0, 97 | norm_num_groups = 32, 98 | cross_attention_dim = 768, 99 | activation_fn = "geglu", 100 | attention_bias = False, 101 | upcast_attention = False, 102 | 103 | cross_frame_attention_mode = None, 104 | temporal_position_encoding = False, 105 | temporal_position_encoding_max_len = 24, 106 | ): 107 | super().__init__() 108 | 109 | inner_dim = num_attention_heads * attention_head_dim 110 | 111 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 112 | self.proj_in = nn.Linear(in_channels, inner_dim) 113 | 114 | self.transformer_blocks = nn.ModuleList( 115 | [ 116 | TemporalTransformerBlock( 117 | dim=inner_dim, 118 | num_attention_heads=num_attention_heads, 119 | attention_head_dim=attention_head_dim, 120 | attention_block_types=attention_block_types, 121 | dropout=dropout, 122 | norm_num_groups=norm_num_groups, 123 | cross_attention_dim=cross_attention_dim, 124 | activation_fn=activation_fn, 125 | attention_bias=attention_bias, 126 | upcast_attention=upcast_attention, 127 | cross_frame_attention_mode=cross_frame_attention_mode, 128 | temporal_position_encoding=temporal_position_encoding, 129 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 130 | ) 131 | for d in range(num_layers) 132 | ] 133 | ) 134 | self.proj_out = nn.Linear(inner_dim, in_channels) 135 | 136 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): 137 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 138 | video_length = hidden_states.shape[2] 139 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 140 | 141 | batch, channel, height, weight = hidden_states.shape 142 | residual = hidden_states 143 | 144 | hidden_states = self.norm(hidden_states) 145 | inner_dim = hidden_states.shape[1] 146 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 147 | hidden_states = self.proj_in(hidden_states) 148 | 149 | # Transformer Blocks 150 | for block in self.transformer_blocks: 151 | hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length) 152 | 153 | # output 154 | hidden_states = self.proj_out(hidden_states) 155 | hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 156 | 157 | output = hidden_states + residual 158 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 159 | 160 | return output 161 | 162 | 163 | class TemporalTransformerBlock(nn.Module): 164 | def __init__( 165 | self, 166 | dim, 167 | num_attention_heads, 168 | attention_head_dim, 169 | attention_block_types = ( "Temporal_Self", "Temporal_Self", ), 170 | dropout = 0.0, 171 | norm_num_groups = 32, 172 | cross_attention_dim = 768, 173 | activation_fn = "geglu", 174 | attention_bias = False, 175 | upcast_attention = False, 176 | cross_frame_attention_mode = None, 177 | temporal_position_encoding = False, 178 | temporal_position_encoding_max_len = 24, 179 | ): 180 | super().__init__() 181 | 182 | attention_blocks = [] 183 | norms = [] 184 | 185 | for block_name in attention_block_types: 186 | attention_blocks.append( 187 | VersatileAttention( 188 | attention_mode=block_name.split("_")[0], 189 | cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None, 190 | 191 | query_dim=dim, 192 | heads=num_attention_heads, 193 | dim_head=attention_head_dim, 194 | dropout=dropout, 195 | bias=attention_bias, 196 | upcast_attention=upcast_attention, 197 | 198 | cross_frame_attention_mode=cross_frame_attention_mode, 199 | temporal_position_encoding=temporal_position_encoding, 200 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 201 | ) 202 | ) 203 | norms.append(nn.LayerNorm(dim)) 204 | 205 | self.attention_blocks = nn.ModuleList(attention_blocks) 206 | self.norms = nn.ModuleList(norms) 207 | 208 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) 209 | self.ff_norm = nn.LayerNorm(dim) 210 | 211 | 212 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): 213 | for attention_block, norm in zip(self.attention_blocks, self.norms): 214 | norm_hidden_states = norm(hidden_states) 215 | hidden_states = attention_block( 216 | norm_hidden_states, 217 | encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None, 218 | video_length=video_length, 219 | ) + hidden_states 220 | 221 | hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states 222 | 223 | output = hidden_states 224 | return output 225 | 226 | 227 | class PositionalEncoding(nn.Module): 228 | def __init__( 229 | self, 230 | d_model, 231 | dropout = 0., 232 | max_len = 24 233 | ): 234 | super().__init__() 235 | self.dropout = nn.Dropout(p=dropout) 236 | position = torch.arange(max_len).unsqueeze(1) 237 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) 238 | pe = torch.zeros(1, max_len, d_model) 239 | pe[0, :, 0::2] = torch.sin(position * div_term) 240 | pe[0, :, 1::2] = torch.cos(position * div_term) 241 | self.register_buffer('pe', pe) 242 | 243 | def forward(self, x): 244 | x = x + self.pe[:, :x.size(1)] 245 | return self.dropout(x) 246 | 247 | class CrossAttention(nn.Module): 248 | r""" 249 | A cross attention layer. 250 | 251 | Parameters: 252 | query_dim (`int`): The number of channels in the query. 253 | cross_attention_dim (`int`, *optional*): 254 | The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. 255 | heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. 256 | dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. 257 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 258 | bias (`bool`, *optional*, defaults to False): 259 | Set to `True` for the query, key, and value linear layers to contain a bias parameter. 260 | """ 261 | 262 | def __init__( 263 | self, 264 | query_dim: int, 265 | cross_attention_dim: Optional[int] = None, 266 | heads: int = 8, 267 | dim_head: int = 64, 268 | dropout: float = 0.0, 269 | bias=False, 270 | upcast_attention: bool = False, 271 | upcast_softmax: bool = False, 272 | added_kv_proj_dim: Optional[int] = None, 273 | norm_num_groups: Optional[int] = None, 274 | ): 275 | super().__init__() 276 | inner_dim = dim_head * heads 277 | cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim 278 | self.upcast_attention = upcast_attention 279 | self.upcast_softmax = upcast_softmax 280 | 281 | self.scale = dim_head**-0.5 282 | 283 | self.heads = heads 284 | # for slice_size > 0 the attention score computation 285 | # is split across the batch axis to save memory 286 | # You can set slice_size with `set_attention_slice` 287 | self.sliceable_head_dim = heads 288 | self._slice_size = None 289 | self._use_memory_efficient_attention_xformers = False 290 | self.added_kv_proj_dim = added_kv_proj_dim 291 | 292 | if norm_num_groups is not None: 293 | self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) 294 | else: 295 | self.group_norm = None 296 | 297 | self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) 298 | self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) 299 | self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) 300 | 301 | if self.added_kv_proj_dim is not None: 302 | self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) 303 | self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) 304 | 305 | self.to_out = nn.ModuleList([]) 306 | self.to_out.append(nn.Linear(inner_dim, query_dim)) 307 | self.to_out.append(nn.Dropout(dropout)) 308 | 309 | def reshape_heads_to_batch_dim(self, tensor): 310 | batch_size, seq_len, dim = tensor.shape 311 | head_size = self.heads 312 | tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) 313 | tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) 314 | return tensor 315 | 316 | def reshape_batch_dim_to_heads(self, tensor): 317 | batch_size, seq_len, dim = tensor.shape 318 | head_size = self.heads 319 | tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) 320 | tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) 321 | return tensor 322 | 323 | def set_attention_slice(self, slice_size): 324 | if slice_size is not None and slice_size > self.sliceable_head_dim: 325 | raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") 326 | 327 | self._slice_size = slice_size 328 | 329 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): 330 | batch_size, sequence_length, _ = hidden_states.shape 331 | 332 | encoder_hidden_states = encoder_hidden_states 333 | 334 | if self.group_norm is not None: 335 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 336 | 337 | query = self.to_q(hidden_states) 338 | dim = query.shape[-1] 339 | query = self.reshape_heads_to_batch_dim(query) 340 | 341 | if self.added_kv_proj_dim is not None: 342 | key = self.to_k(hidden_states) 343 | value = self.to_v(hidden_states) 344 | encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) 345 | encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) 346 | 347 | key = self.reshape_heads_to_batch_dim(key) 348 | value = self.reshape_heads_to_batch_dim(value) 349 | encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) 350 | encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) 351 | 352 | key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) 353 | value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) 354 | else: 355 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 356 | key = self.to_k(encoder_hidden_states) 357 | value = self.to_v(encoder_hidden_states) 358 | 359 | key = self.reshape_heads_to_batch_dim(key) 360 | value = self.reshape_heads_to_batch_dim(value) 361 | 362 | if attention_mask is not None: 363 | if attention_mask.shape[-1] != query.shape[1]: 364 | target_length = query.shape[1] 365 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) 366 | attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) 367 | 368 | # attention, what we cannot get enough of 369 | if self._use_memory_efficient_attention_xformers: 370 | hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) 371 | # Some versions of xformers return output in fp32, cast it back to the dtype of the input 372 | hidden_states = hidden_states.to(query.dtype) 373 | else: 374 | if self._slice_size is None or query.shape[0] // self._slice_size == 1: 375 | hidden_states = self._attention(query, key, value, attention_mask) 376 | else: 377 | hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) 378 | 379 | # linear proj 380 | hidden_states = self.to_out[0](hidden_states) 381 | 382 | # dropout 383 | hidden_states = self.to_out[1](hidden_states) 384 | return hidden_states 385 | 386 | def _attention(self, query, key, value, attention_mask=None): 387 | if self.upcast_attention: 388 | query = query.float() 389 | key = key.float() 390 | 391 | attention_scores = torch.baddbmm( 392 | torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), 393 | query, 394 | key.transpose(-1, -2), 395 | beta=0, 396 | alpha=self.scale, 397 | ) 398 | 399 | if attention_mask is not None: 400 | attention_scores = attention_scores + attention_mask 401 | 402 | if self.upcast_softmax: 403 | attention_scores = attention_scores.float() 404 | 405 | attention_probs = attention_scores.softmax(dim=-1) 406 | 407 | # cast back to the original dtype 408 | attention_probs = attention_probs.to(value.dtype) 409 | 410 | # compute attention output 411 | hidden_states = torch.bmm(attention_probs, value) 412 | 413 | # reshape hidden_states 414 | hidden_states = self.reshape_batch_dim_to_heads(hidden_states) 415 | return hidden_states 416 | 417 | def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask): 418 | batch_size_attention = query.shape[0] 419 | hidden_states = torch.zeros( 420 | (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype 421 | ) 422 | slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] 423 | for i in range(hidden_states.shape[0] // slice_size): 424 | start_idx = i * slice_size 425 | end_idx = (i + 1) * slice_size 426 | 427 | query_slice = query[start_idx:end_idx] 428 | key_slice = key[start_idx:end_idx] 429 | 430 | if self.upcast_attention: 431 | query_slice = query_slice.float() 432 | key_slice = key_slice.float() 433 | 434 | attn_slice = torch.baddbmm( 435 | torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device), 436 | query_slice, 437 | key_slice.transpose(-1, -2), 438 | beta=0, 439 | alpha=self.scale, 440 | ) 441 | 442 | if attention_mask is not None: 443 | attn_slice = attn_slice + attention_mask[start_idx:end_idx] 444 | 445 | if self.upcast_softmax: 446 | attn_slice = attn_slice.float() 447 | 448 | attn_slice = attn_slice.softmax(dim=-1) 449 | 450 | # cast back to the original dtype 451 | attn_slice = attn_slice.to(value.dtype) 452 | attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) 453 | 454 | hidden_states[start_idx:end_idx] = attn_slice 455 | 456 | # reshape hidden_states 457 | hidden_states = self.reshape_batch_dim_to_heads(hidden_states) 458 | return hidden_states 459 | 460 | def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): 461 | # TODO attention_mask 462 | query = query.contiguous() 463 | key = key.contiguous() 464 | value = value.contiguous() 465 | hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) 466 | hidden_states = self.reshape_batch_dim_to_heads(hidden_states) 467 | return hidden_states 468 | 469 | class VersatileAttention(CrossAttention): 470 | def __init__( 471 | self, 472 | attention_mode = None, 473 | cross_frame_attention_mode = None, 474 | temporal_position_encoding = False, 475 | temporal_position_encoding_max_len = 24, 476 | *args, **kwargs 477 | ): 478 | super().__init__(*args, **kwargs) 479 | assert attention_mode == "Temporal" 480 | 481 | self.attention_mode = attention_mode 482 | self.is_cross_attention = kwargs["cross_attention_dim"] is not None 483 | 484 | self.pos_encoder = PositionalEncoding( 485 | kwargs["query_dim"], 486 | dropout=0., 487 | max_len=temporal_position_encoding_max_len 488 | ) if (temporal_position_encoding and attention_mode == "Temporal") else None 489 | 490 | def extra_repr(self): 491 | return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}" 492 | 493 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): 494 | batch_size, sequence_length, _ = hidden_states.shape 495 | 496 | if self.attention_mode == "Temporal": 497 | d = hidden_states.shape[1] 498 | hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) 499 | 500 | if self.pos_encoder is not None: 501 | hidden_states = self.pos_encoder(hidden_states) 502 | 503 | encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states 504 | else: 505 | raise NotImplementedError 506 | 507 | encoder_hidden_states = encoder_hidden_states 508 | 509 | if self.group_norm is not None: 510 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 511 | 512 | query = self.to_q(hidden_states) 513 | dim = query.shape[-1] 514 | query = self.reshape_heads_to_batch_dim(query) 515 | 516 | if self.added_kv_proj_dim is not None: 517 | raise NotImplementedError 518 | 519 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 520 | key = self.to_k(encoder_hidden_states) 521 | value = self.to_v(encoder_hidden_states) 522 | 523 | key = self.reshape_heads_to_batch_dim(key) 524 | value = self.reshape_heads_to_batch_dim(value) 525 | 526 | if attention_mask is not None: 527 | if attention_mask.shape[-1] != query.shape[1]: 528 | target_length = query.shape[1] 529 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) 530 | attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) 531 | 532 | # attention, what we cannot get enough of 533 | if self._use_memory_efficient_attention_xformers: 534 | hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) 535 | # Some versions of xformers return output in fp32, cast it back to the dtype of the input 536 | hidden_states = hidden_states.to(query.dtype) 537 | else: 538 | if self._slice_size is None or query.shape[0] // self._slice_size == 1: 539 | hidden_states = self._attention(query, key, value, attention_mask) 540 | else: 541 | hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) 542 | 543 | # linear proj 544 | hidden_states = self.to_out[0](hidden_states) 545 | 546 | # dropout 547 | hidden_states = self.to_out[1](hidden_states) 548 | 549 | if self.attention_mode == "Temporal": 550 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) 551 | 552 | return hidden_states 553 | -------------------------------------------------------------------------------- /animatediff/models/motion_module_bkp.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Optional, Tuple, Union 3 | 4 | import torch 5 | import numpy as np 6 | import torch.nn.functional as F 7 | from torch import nn 8 | import torchvision 9 | 10 | from diffusers.configuration_utils import ConfigMixin, register_to_config 11 | from diffusers import ModelMixin 12 | from diffusers.utils import BaseOutput 13 | from diffusers.utils.import_utils import is_xformers_available 14 | from diffusers.models.attention import CrossAttention, FeedForward 15 | 16 | from einops import rearrange, repeat 17 | import math 18 | 19 | 20 | def zero_module(module): 21 | # Zero out the parameters of a module and return it. 22 | for p in module.parameters(): 23 | p.detach().zero_() 24 | return module 25 | 26 | 27 | @dataclass 28 | class TemporalTransformer3DModelOutput(BaseOutput): 29 | sample: torch.FloatTensor 30 | 31 | 32 | if is_xformers_available(): 33 | import xformers 34 | import xformers.ops 35 | else: 36 | xformers = None 37 | 38 | 39 | def get_motion_module( 40 | in_channels, 41 | motion_module_type: str, 42 | motion_module_kwargs: dict 43 | ): 44 | if motion_module_type == "Vanilla": 45 | return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,) 46 | else: 47 | raise ValueError 48 | 49 | 50 | class VanillaTemporalModule(nn.Module): 51 | def __init__( 52 | self, 53 | in_channels, 54 | num_attention_heads = 8, 55 | num_transformer_block = 2, 56 | attention_block_types =( "Temporal_Self", "Temporal_Self" ), 57 | cross_frame_attention_mode = None, 58 | temporal_position_encoding = False, 59 | temporal_position_encoding_max_len = 24, 60 | temporal_attention_dim_div = 1, 61 | zero_initialize = True, 62 | ): 63 | super().__init__() 64 | 65 | self.temporal_transformer = TemporalTransformer3DModel( 66 | in_channels=in_channels, 67 | num_attention_heads=num_attention_heads, 68 | attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div, 69 | num_layers=num_transformer_block, 70 | attention_block_types=attention_block_types, 71 | cross_frame_attention_mode=cross_frame_attention_mode, 72 | temporal_position_encoding=temporal_position_encoding, 73 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 74 | ) 75 | 76 | if zero_initialize: 77 | self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out) 78 | 79 | def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None): 80 | hidden_states = input_tensor 81 | hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask) 82 | 83 | output = hidden_states 84 | return output 85 | 86 | 87 | class TemporalTransformer3DModel(nn.Module): 88 | def __init__( 89 | self, 90 | in_channels, 91 | num_attention_heads, 92 | attention_head_dim, 93 | 94 | num_layers, 95 | attention_block_types = ( "Temporal_Self", "Temporal_Self", ), 96 | dropout = 0.0, 97 | norm_num_groups = 32, 98 | cross_attention_dim = 768, 99 | activation_fn = "geglu", 100 | attention_bias = False, 101 | upcast_attention = False, 102 | 103 | cross_frame_attention_mode = None, 104 | temporal_position_encoding = False, 105 | temporal_position_encoding_max_len = 24, 106 | ): 107 | super().__init__() 108 | 109 | inner_dim = num_attention_heads * attention_head_dim 110 | 111 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 112 | self.proj_in = nn.Linear(in_channels, inner_dim) 113 | 114 | self.transformer_blocks = nn.ModuleList( 115 | [ 116 | TemporalTransformerBlock( 117 | dim=inner_dim, 118 | num_attention_heads=num_attention_heads, 119 | attention_head_dim=attention_head_dim, 120 | attention_block_types=attention_block_types, 121 | dropout=dropout, 122 | norm_num_groups=norm_num_groups, 123 | cross_attention_dim=cross_attention_dim, 124 | activation_fn=activation_fn, 125 | attention_bias=attention_bias, 126 | upcast_attention=upcast_attention, 127 | cross_frame_attention_mode=cross_frame_attention_mode, 128 | temporal_position_encoding=temporal_position_encoding, 129 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 130 | ) 131 | for d in range(num_layers) 132 | ] 133 | ) 134 | self.proj_out = nn.Linear(inner_dim, in_channels) 135 | 136 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): 137 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 138 | video_length = hidden_states.shape[2] 139 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 140 | 141 | batch, channel, height, weight = hidden_states.shape 142 | residual = hidden_states 143 | 144 | hidden_states = self.norm(hidden_states) 145 | inner_dim = hidden_states.shape[1] 146 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 147 | hidden_states = self.proj_in(hidden_states) 148 | 149 | # Transformer Blocks 150 | for block in self.transformer_blocks: 151 | hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length) 152 | 153 | # output 154 | hidden_states = self.proj_out(hidden_states) 155 | hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 156 | 157 | output = hidden_states + residual 158 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 159 | 160 | return output 161 | 162 | 163 | class TemporalTransformerBlock(nn.Module): 164 | def __init__( 165 | self, 166 | dim, 167 | num_attention_heads, 168 | attention_head_dim, 169 | attention_block_types = ( "Temporal_Self", "Temporal_Self", ), 170 | dropout = 0.0, 171 | norm_num_groups = 32, 172 | cross_attention_dim = 768, 173 | activation_fn = "geglu", 174 | attention_bias = False, 175 | upcast_attention = False, 176 | cross_frame_attention_mode = None, 177 | temporal_position_encoding = False, 178 | temporal_position_encoding_max_len = 24, 179 | ): 180 | super().__init__() 181 | 182 | attention_blocks = [] 183 | norms = [] 184 | 185 | for block_name in attention_block_types: 186 | attention_blocks.append( 187 | VersatileAttention( 188 | attention_mode=block_name.split("_")[0], 189 | cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None, 190 | 191 | query_dim=dim, 192 | heads=num_attention_heads, 193 | dim_head=attention_head_dim, 194 | dropout=dropout, 195 | bias=attention_bias, 196 | upcast_attention=upcast_attention, 197 | 198 | cross_frame_attention_mode=cross_frame_attention_mode, 199 | temporal_position_encoding=temporal_position_encoding, 200 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 201 | ) 202 | ) 203 | norms.append(nn.LayerNorm(dim)) 204 | 205 | self.attention_blocks = nn.ModuleList(attention_blocks) 206 | self.norms = nn.ModuleList(norms) 207 | 208 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) 209 | self.ff_norm = nn.LayerNorm(dim) 210 | 211 | 212 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): 213 | for attention_block, norm in zip(self.attention_blocks, self.norms): 214 | norm_hidden_states = norm(hidden_states) 215 | hidden_states = attention_block( 216 | norm_hidden_states, 217 | encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None, 218 | video_length=video_length, 219 | ) + hidden_states 220 | 221 | hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states 222 | 223 | output = hidden_states 224 | return output 225 | 226 | 227 | class PositionalEncoding(nn.Module): 228 | def __init__( 229 | self, 230 | d_model, 231 | dropout = 0., 232 | max_len = 24 233 | ): 234 | super().__init__() 235 | self.dropout = nn.Dropout(p=dropout) 236 | position = torch.arange(max_len).unsqueeze(1) 237 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) 238 | pe = torch.zeros(1, max_len, d_model) 239 | pe[0, :, 0::2] = torch.sin(position * div_term) 240 | pe[0, :, 1::2] = torch.cos(position * div_term) 241 | self.register_buffer('pe', pe) 242 | 243 | def forward(self, x): 244 | x = x + self.pe[:, :x.size(1)] 245 | return self.dropout(x) 246 | 247 | 248 | class VersatileAttention(CrossAttention): 249 | def __init__( 250 | self, 251 | attention_mode = None, 252 | cross_frame_attention_mode = None, 253 | temporal_position_encoding = False, 254 | temporal_position_encoding_max_len = 24, 255 | *args, **kwargs 256 | ): 257 | super().__init__(*args, **kwargs) 258 | assert attention_mode == "Temporal" 259 | 260 | self.attention_mode = attention_mode 261 | self.is_cross_attention = kwargs["cross_attention_dim"] is not None 262 | 263 | self.pos_encoder = PositionalEncoding( 264 | kwargs["query_dim"], 265 | dropout=0., 266 | max_len=temporal_position_encoding_max_len 267 | ) if (temporal_position_encoding and attention_mode == "Temporal") else None 268 | 269 | def extra_repr(self): 270 | return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}" 271 | 272 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): 273 | batch_size, sequence_length, _ = hidden_states.shape 274 | 275 | if self.attention_mode == "Temporal": 276 | d = hidden_states.shape[1] 277 | hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) 278 | 279 | if self.pos_encoder is not None: 280 | hidden_states = self.pos_encoder(hidden_states) 281 | 282 | encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states 283 | else: 284 | raise NotImplementedError 285 | 286 | encoder_hidden_states = encoder_hidden_states 287 | 288 | if self.group_norm is not None: 289 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 290 | 291 | query = self.to_q(hidden_states) 292 | dim = query.shape[-1] 293 | query = self.reshape_heads_to_batch_dim(query) 294 | 295 | if self.added_kv_proj_dim is not None: 296 | raise NotImplementedError 297 | 298 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 299 | key = self.to_k(encoder_hidden_states) 300 | value = self.to_v(encoder_hidden_states) 301 | 302 | key = self.reshape_heads_to_batch_dim(key) 303 | value = self.reshape_heads_to_batch_dim(value) 304 | 305 | if attention_mask is not None: 306 | if attention_mask.shape[-1] != query.shape[1]: 307 | target_length = query.shape[1] 308 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) 309 | attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) 310 | 311 | # attention, what we cannot get enough of 312 | if self._use_memory_efficient_attention_xformers: 313 | hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) 314 | # Some versions of xformers return output in fp32, cast it back to the dtype of the input 315 | hidden_states = hidden_states.to(query.dtype) 316 | else: 317 | if self._slice_size is None or query.shape[0] // self._slice_size == 1: 318 | hidden_states = self._attention(query, key, value, attention_mask) 319 | else: 320 | hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) 321 | 322 | # linear proj 323 | hidden_states = self.to_out[0](hidden_states) 324 | 325 | # dropout 326 | hidden_states = self.to_out[1](hidden_states) 327 | 328 | if self.attention_mode == "Temporal": 329 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) 330 | 331 | return hidden_states 332 | -------------------------------------------------------------------------------- /animatediff/models/resnet.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from einops import rearrange 8 | 9 | 10 | class InflatedConv3d(nn.Conv2d): 11 | def forward(self, x): 12 | video_length = x.shape[2] 13 | 14 | x = rearrange(x, "b c f h w -> (b f) c h w") 15 | x = super().forward(x) 16 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 17 | 18 | return x 19 | 20 | 21 | class InflatedGroupNorm(nn.GroupNorm): 22 | def forward(self, x): 23 | video_length = x.shape[2] 24 | 25 | x = rearrange(x, "b c f h w -> (b f) c h w") 26 | x = super().forward(x) 27 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 28 | 29 | return x 30 | 31 | 32 | class Upsample3D(nn.Module): 33 | def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): 34 | super().__init__() 35 | self.channels = channels 36 | self.out_channels = out_channels or channels 37 | self.use_conv = use_conv 38 | self.use_conv_transpose = use_conv_transpose 39 | self.name = name 40 | 41 | conv = None 42 | if use_conv_transpose: 43 | raise NotImplementedError 44 | elif use_conv: 45 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) 46 | 47 | def forward(self, hidden_states, output_size=None): 48 | assert hidden_states.shape[1] == self.channels 49 | 50 | if self.use_conv_transpose: 51 | raise NotImplementedError 52 | 53 | # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 54 | dtype = hidden_states.dtype 55 | if dtype == torch.bfloat16: 56 | hidden_states = hidden_states.to(torch.float32) 57 | 58 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 59 | if hidden_states.shape[0] >= 64: 60 | hidden_states = hidden_states.contiguous() 61 | 62 | # if `output_size` is passed we force the interpolation output 63 | # size and do not make use of `scale_factor=2` 64 | if output_size is None: 65 | hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") 66 | else: 67 | hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") 68 | 69 | # If the input is bfloat16, we cast back to bfloat16 70 | if dtype == torch.bfloat16: 71 | hidden_states = hidden_states.to(dtype) 72 | 73 | # if self.use_conv: 74 | # if self.name == "conv": 75 | # hidden_states = self.conv(hidden_states) 76 | # else: 77 | # hidden_states = self.Conv2d_0(hidden_states) 78 | hidden_states = self.conv(hidden_states) 79 | 80 | return hidden_states 81 | 82 | 83 | class Downsample3D(nn.Module): 84 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): 85 | super().__init__() 86 | self.channels = channels 87 | self.out_channels = out_channels or channels 88 | self.use_conv = use_conv 89 | self.padding = padding 90 | stride = 2 91 | self.name = name 92 | 93 | if use_conv: 94 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) 95 | else: 96 | raise NotImplementedError 97 | 98 | def forward(self, hidden_states): 99 | assert hidden_states.shape[1] == self.channels 100 | if self.use_conv and self.padding == 0: 101 | raise NotImplementedError 102 | 103 | assert hidden_states.shape[1] == self.channels 104 | hidden_states = self.conv(hidden_states) 105 | 106 | return hidden_states 107 | 108 | 109 | class ResnetBlock3D(nn.Module): 110 | def __init__( 111 | self, 112 | *, 113 | in_channels, 114 | out_channels=None, 115 | conv_shortcut=False, 116 | dropout=0.0, 117 | temb_channels=512, 118 | groups=32, 119 | groups_out=None, 120 | pre_norm=True, 121 | eps=1e-6, 122 | non_linearity="swish", 123 | time_embedding_norm="default", 124 | output_scale_factor=1.0, 125 | use_in_shortcut=None, 126 | use_inflated_groupnorm=False, 127 | ): 128 | super().__init__() 129 | self.pre_norm = pre_norm 130 | self.pre_norm = True 131 | self.in_channels = in_channels 132 | out_channels = in_channels if out_channels is None else out_channels 133 | self.out_channels = out_channels 134 | self.use_conv_shortcut = conv_shortcut 135 | self.time_embedding_norm = time_embedding_norm 136 | self.output_scale_factor = output_scale_factor 137 | 138 | if groups_out is None: 139 | groups_out = groups 140 | 141 | assert use_inflated_groupnorm != None 142 | if use_inflated_groupnorm: 143 | self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 144 | else: 145 | self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 146 | 147 | self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 148 | 149 | if temb_channels is not None: 150 | if self.time_embedding_norm == "default": 151 | time_emb_proj_out_channels = out_channels 152 | elif self.time_embedding_norm == "scale_shift": 153 | time_emb_proj_out_channels = out_channels * 2 154 | else: 155 | raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") 156 | 157 | self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) 158 | else: 159 | self.time_emb_proj = None 160 | 161 | if use_inflated_groupnorm: 162 | self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) 163 | else: 164 | self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) 165 | 166 | self.dropout = torch.nn.Dropout(dropout) 167 | self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 168 | 169 | if non_linearity == "swish": 170 | self.nonlinearity = lambda x: F.silu(x) 171 | elif non_linearity == "mish": 172 | self.nonlinearity = Mish() 173 | elif non_linearity == "silu": 174 | self.nonlinearity = nn.SiLU() 175 | 176 | self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut 177 | 178 | self.conv_shortcut = None 179 | if self.use_in_shortcut: 180 | self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 181 | 182 | def forward(self, input_tensor, temb): 183 | hidden_states = input_tensor 184 | 185 | hidden_states = self.norm1(hidden_states) 186 | hidden_states = self.nonlinearity(hidden_states) 187 | 188 | hidden_states = self.conv1(hidden_states) 189 | 190 | if temb is not None: 191 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] 192 | 193 | if temb is not None and self.time_embedding_norm == "default": 194 | hidden_states = hidden_states + temb 195 | 196 | hidden_states = self.norm2(hidden_states) 197 | 198 | if temb is not None and self.time_embedding_norm == "scale_shift": 199 | scale, shift = torch.chunk(temb, 2, dim=1) 200 | hidden_states = hidden_states * (1 + scale) + shift 201 | 202 | hidden_states = self.nonlinearity(hidden_states) 203 | 204 | hidden_states = self.dropout(hidden_states) 205 | hidden_states = self.conv2(hidden_states) 206 | 207 | if self.conv_shortcut is not None: 208 | input_tensor = self.conv_shortcut(input_tensor) 209 | 210 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 211 | 212 | return output_tensor 213 | 214 | 215 | class Mish(torch.nn.Module): 216 | def forward(self, hidden_states): 217 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) -------------------------------------------------------------------------------- /animatediff/models/sparse_controlnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # Changes were made to this source code by Yuwei Guo. 16 | from dataclasses import dataclass 17 | from typing import Any, Dict, List, Optional, Tuple, Union 18 | 19 | import torch 20 | from torch import nn 21 | from torch.nn import functional as F 22 | 23 | from diffusers.configuration_utils import ConfigMixin, register_to_config 24 | from diffusers.utils import BaseOutput, logging 25 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps 26 | # from diffusers.modeling_utils import ModelMixin 27 | from diffusers import ModelMixin 28 | 29 | 30 | from .unet_blocks import ( 31 | CrossAttnDownBlock3D, 32 | DownBlock3D, 33 | UNetMidBlock3DCrossAttn, 34 | get_down_block, 35 | ) 36 | from einops import repeat, rearrange 37 | from .resnet import InflatedConv3d 38 | 39 | import diffusers 40 | dif_version = str(diffusers.__version__) 41 | dif_version_int= int(dif_version.split(".")[1]) 42 | if dif_version_int>=28: 43 | from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel 44 | else: 45 | from diffusers.models.unet_2d_condition import UNet2DConditionModel 46 | 47 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 48 | 49 | 50 | @dataclass 51 | class SparseControlNetOutput(BaseOutput): 52 | down_block_res_samples: Tuple[torch.Tensor] 53 | mid_block_res_sample: torch.Tensor 54 | 55 | 56 | class SparseControlNetConditioningEmbedding(nn.Module): 57 | def __init__( 58 | self, 59 | conditioning_embedding_channels: int, 60 | conditioning_channels: int = 3, 61 | block_out_channels: Tuple[int] = (16, 32, 96, 256), 62 | ): 63 | super().__init__() 64 | 65 | self.conv_in = InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) 66 | 67 | self.blocks = nn.ModuleList([]) 68 | 69 | for i in range(len(block_out_channels) - 1): 70 | channel_in = block_out_channels[i] 71 | channel_out = block_out_channels[i + 1] 72 | self.blocks.append(InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)) 73 | self.blocks.append(InflatedConv3d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) 74 | 75 | self.conv_out = zero_module( 76 | InflatedConv3d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) 77 | ) 78 | 79 | def forward(self, conditioning): 80 | embedding = self.conv_in(conditioning) 81 | embedding = F.silu(embedding) 82 | 83 | for block in self.blocks: 84 | embedding = block(embedding) 85 | embedding = F.silu(embedding) 86 | 87 | embedding = self.conv_out(embedding) 88 | 89 | return embedding 90 | 91 | 92 | class SparseControlNetModel(ModelMixin, ConfigMixin): 93 | _supports_gradient_checkpointing = True 94 | 95 | @register_to_config 96 | def __init__( 97 | self, 98 | in_channels: int = 4, 99 | conditioning_channels: int = 3, 100 | flip_sin_to_cos: bool = True, 101 | freq_shift: int = 0, 102 | down_block_types: Tuple[str] = ( 103 | "CrossAttnDownBlock2D", 104 | "CrossAttnDownBlock2D", 105 | "CrossAttnDownBlock2D", 106 | "DownBlock2D", 107 | ), 108 | only_cross_attention: Union[bool, Tuple[bool]] = False, 109 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280), 110 | layers_per_block: int = 2, 111 | downsample_padding: int = 1, 112 | mid_block_scale_factor: float = 1, 113 | act_fn: str = "silu", 114 | norm_num_groups: Optional[int] = 32, 115 | norm_eps: float = 1e-5, 116 | cross_attention_dim: int = 1280, 117 | attention_head_dim: Union[int, Tuple[int]] = 8, 118 | num_attention_heads: Optional[Union[int, Tuple[int]]] = None, 119 | use_linear_projection: bool = False, 120 | class_embed_type: Optional[str] = None, 121 | num_class_embeds: Optional[int] = None, 122 | upcast_attention: bool = False, 123 | resnet_time_scale_shift: str = "default", 124 | projection_class_embeddings_input_dim: Optional[int] = None, 125 | controlnet_conditioning_channel_order: str = "rgb", 126 | conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), 127 | global_pool_conditions: bool = False, 128 | 129 | use_motion_module = True, 130 | motion_module_resolutions = ( 1,2,4,8 ), 131 | motion_module_mid_block = False, 132 | motion_module_type = "Vanilla", 133 | motion_module_kwargs = { 134 | "num_attention_heads": 8, 135 | "num_transformer_block": 1, 136 | "attention_block_types": ["Temporal_Self"], 137 | "temporal_position_encoding": True, 138 | "temporal_position_encoding_max_len": 32, 139 | "temporal_attention_dim_div": 1, 140 | "causal_temporal_attention": False, 141 | }, 142 | 143 | concate_conditioning_mask: bool = True, 144 | use_simplified_condition_embedding: bool = False, 145 | 146 | set_noisy_sample_input_to_zero: bool = False, 147 | ): 148 | super().__init__() 149 | 150 | # If `num_attention_heads` is not defined (which is the case for most models) 151 | # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. 152 | # The reason for this behavior is to correct for incorrectly named variables that were introduced 153 | # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 154 | # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking 155 | # which is why we correct for the naming here. 156 | num_attention_heads = num_attention_heads or attention_head_dim 157 | 158 | # Check inputs 159 | if len(block_out_channels) != len(down_block_types): 160 | raise ValueError( 161 | f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." 162 | ) 163 | 164 | if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): 165 | raise ValueError( 166 | f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." 167 | ) 168 | 169 | if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): 170 | raise ValueError( 171 | f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 172 | ) 173 | 174 | # input 175 | self.set_noisy_sample_input_to_zero = set_noisy_sample_input_to_zero 176 | 177 | conv_in_kernel = 3 178 | conv_in_padding = (conv_in_kernel - 1) // 2 179 | self.conv_in = InflatedConv3d( 180 | in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding 181 | ) 182 | 183 | if concate_conditioning_mask: 184 | conditioning_channels = conditioning_channels + 1 185 | self.concate_conditioning_mask = concate_conditioning_mask 186 | 187 | # control net conditioning embedding 188 | if use_simplified_condition_embedding: 189 | self.controlnet_cond_embedding = zero_module( 190 | InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding) 191 | ) 192 | else: 193 | self.controlnet_cond_embedding = SparseControlNetConditioningEmbedding( 194 | conditioning_embedding_channels=block_out_channels[0], 195 | block_out_channels=conditioning_embedding_out_channels, 196 | conditioning_channels=conditioning_channels, 197 | ) 198 | self.use_simplified_condition_embedding = use_simplified_condition_embedding 199 | 200 | # time 201 | time_embed_dim = block_out_channels[0] * 4 202 | 203 | self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) 204 | timestep_input_dim = block_out_channels[0] 205 | 206 | self.time_embedding = TimestepEmbedding( 207 | timestep_input_dim, 208 | time_embed_dim, 209 | act_fn=act_fn, 210 | ) 211 | 212 | # class embedding 213 | if class_embed_type is None and num_class_embeds is not None: 214 | self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) 215 | elif class_embed_type == "timestep": 216 | self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 217 | elif class_embed_type == "identity": 218 | self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) 219 | elif class_embed_type == "projection": 220 | if projection_class_embeddings_input_dim is None: 221 | raise ValueError( 222 | "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" 223 | ) 224 | # The projection `class_embed_type` is the same as the timestep `class_embed_type` except 225 | # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings 226 | # 2. it projects from an arbitrary input dimension. 227 | # 228 | # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. 229 | # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. 230 | # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 231 | self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) 232 | else: 233 | self.class_embedding = None 234 | 235 | 236 | self.down_blocks = nn.ModuleList([]) 237 | self.controlnet_down_blocks = nn.ModuleList([]) 238 | 239 | if isinstance(only_cross_attention, bool): 240 | only_cross_attention = [only_cross_attention] * len(down_block_types) 241 | 242 | if isinstance(attention_head_dim, int): 243 | attention_head_dim = (attention_head_dim,) * len(down_block_types) 244 | 245 | if isinstance(num_attention_heads, int): 246 | num_attention_heads = (num_attention_heads,) * len(down_block_types) 247 | 248 | # down 249 | output_channel = block_out_channels[0] 250 | 251 | controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1) 252 | controlnet_block = zero_module(controlnet_block) 253 | self.controlnet_down_blocks.append(controlnet_block) 254 | 255 | for i, down_block_type in enumerate(down_block_types): 256 | res = 2 ** i 257 | input_channel = output_channel 258 | output_channel = block_out_channels[i] 259 | is_final_block = i == len(block_out_channels) - 1 260 | 261 | down_block = get_down_block( 262 | down_block_type, 263 | num_layers=layers_per_block, 264 | in_channels=input_channel, 265 | out_channels=output_channel, 266 | temb_channels=time_embed_dim, 267 | add_downsample=not is_final_block, 268 | resnet_eps=norm_eps, 269 | resnet_act_fn=act_fn, 270 | resnet_groups=norm_num_groups, 271 | cross_attention_dim=cross_attention_dim, 272 | attn_num_head_channels=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, 273 | downsample_padding=downsample_padding, 274 | use_linear_projection=use_linear_projection, 275 | only_cross_attention=only_cross_attention[i], 276 | upcast_attention=upcast_attention, 277 | resnet_time_scale_shift=resnet_time_scale_shift, 278 | 279 | use_inflated_groupnorm=True, 280 | 281 | use_motion_module=use_motion_module and (res in motion_module_resolutions), 282 | motion_module_type=motion_module_type, 283 | motion_module_kwargs=motion_module_kwargs, 284 | ) 285 | self.down_blocks.append(down_block) 286 | 287 | for _ in range(layers_per_block): 288 | controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1) 289 | controlnet_block = zero_module(controlnet_block) 290 | self.controlnet_down_blocks.append(controlnet_block) 291 | 292 | if not is_final_block: 293 | controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1) 294 | controlnet_block = zero_module(controlnet_block) 295 | self.controlnet_down_blocks.append(controlnet_block) 296 | 297 | # mid 298 | mid_block_channel = block_out_channels[-1] 299 | 300 | controlnet_block = InflatedConv3d(mid_block_channel, mid_block_channel, kernel_size=1) 301 | controlnet_block = zero_module(controlnet_block) 302 | self.controlnet_mid_block = controlnet_block 303 | 304 | self.mid_block = UNetMidBlock3DCrossAttn( 305 | in_channels=mid_block_channel, 306 | temb_channels=time_embed_dim, 307 | resnet_eps=norm_eps, 308 | resnet_act_fn=act_fn, 309 | output_scale_factor=mid_block_scale_factor, 310 | resnet_time_scale_shift=resnet_time_scale_shift, 311 | cross_attention_dim=cross_attention_dim, 312 | attn_num_head_channels=num_attention_heads[-1], 313 | resnet_groups=norm_num_groups, 314 | use_linear_projection=use_linear_projection, 315 | upcast_attention=upcast_attention, 316 | 317 | use_inflated_groupnorm=True, 318 | use_motion_module=use_motion_module and motion_module_mid_block, 319 | motion_module_type=motion_module_type, 320 | motion_module_kwargs=motion_module_kwargs, 321 | ) 322 | 323 | @classmethod 324 | def from_unet( 325 | cls, 326 | unet: UNet2DConditionModel, 327 | controlnet_conditioning_channel_order: str = "rgb", 328 | conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), 329 | load_weights_from_unet: bool = True, 330 | 331 | controlnet_additional_kwargs: dict = {}, 332 | ): 333 | controlnet = cls( 334 | in_channels=unet.config.in_channels, 335 | flip_sin_to_cos=unet.config.flip_sin_to_cos, 336 | freq_shift=unet.config.freq_shift, 337 | down_block_types=unet.config.down_block_types, 338 | only_cross_attention=unet.config.only_cross_attention, 339 | block_out_channels=unet.config.block_out_channels, 340 | layers_per_block=unet.config.layers_per_block, 341 | downsample_padding=unet.config.downsample_padding, 342 | mid_block_scale_factor=unet.config.mid_block_scale_factor, 343 | act_fn=unet.config.act_fn, 344 | norm_num_groups=unet.config.norm_num_groups, 345 | norm_eps=unet.config.norm_eps, 346 | cross_attention_dim=unet.config.cross_attention_dim, 347 | attention_head_dim=unet.config.attention_head_dim, 348 | num_attention_heads=unet.config.num_attention_heads, 349 | use_linear_projection=unet.config.use_linear_projection, 350 | class_embed_type=unet.config.class_embed_type, 351 | num_class_embeds=unet.config.num_class_embeds, 352 | upcast_attention=unet.config.upcast_attention, 353 | resnet_time_scale_shift=unet.config.resnet_time_scale_shift, 354 | projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, 355 | controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, 356 | conditioning_embedding_out_channels=conditioning_embedding_out_channels, 357 | 358 | **controlnet_additional_kwargs, 359 | ) 360 | 361 | if load_weights_from_unet: 362 | m, u = controlnet.conv_in.load_state_dict(cls.image_layer_filter(unet.conv_in.state_dict()), strict=False) 363 | assert len(u) == 0 364 | m, u = controlnet.time_proj.load_state_dict(cls.image_layer_filter(unet.time_proj.state_dict()), strict=False) 365 | assert len(u) == 0 366 | m, u = controlnet.time_embedding.load_state_dict(cls.image_layer_filter(unet.time_embedding.state_dict()), strict=False) 367 | assert len(u) == 0 368 | 369 | if controlnet.class_embedding: 370 | m, u = controlnet.class_embedding.load_state_dict(cls.image_layer_filter(unet.class_embedding.state_dict()), strict=False) 371 | assert len(u) == 0 372 | m, u = controlnet.down_blocks.load_state_dict(cls.image_layer_filter(unet.down_blocks.state_dict()), strict=False) 373 | assert len(u) == 0 374 | m, u = controlnet.mid_block.load_state_dict(cls.image_layer_filter(unet.mid_block.state_dict()), strict=False) 375 | assert len(u) == 0 376 | 377 | return controlnet 378 | 379 | @staticmethod 380 | def image_layer_filter(state_dict): 381 | new_state_dict = {} 382 | for name, param in state_dict.items(): 383 | if "motion_modules." in name or "lora" in name: continue 384 | new_state_dict[name] = param 385 | return new_state_dict 386 | 387 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice 388 | def set_attention_slice(self, slice_size): 389 | r""" 390 | Enable sliced attention computation. 391 | 392 | When this option is enabled, the attention module splits the input tensor in slices to compute attention in 393 | several steps. This is useful for saving some memory in exchange for a small decrease in speed. 394 | 395 | Args: 396 | slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): 397 | When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If 398 | `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is 399 | provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` 400 | must be a multiple of `slice_size`. 401 | """ 402 | sliceable_head_dims = [] 403 | 404 | def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): 405 | if hasattr(module, "set_attention_slice"): 406 | sliceable_head_dims.append(module.sliceable_head_dim) 407 | 408 | for child in module.children(): 409 | fn_recursive_retrieve_sliceable_dims(child) 410 | 411 | # retrieve number of attention layers 412 | for module in self.children(): 413 | fn_recursive_retrieve_sliceable_dims(module) 414 | 415 | num_sliceable_layers = len(sliceable_head_dims) 416 | 417 | if slice_size == "auto": 418 | # half the attention head size is usually a good trade-off between 419 | # speed and memory 420 | slice_size = [dim // 2 for dim in sliceable_head_dims] 421 | elif slice_size == "max": 422 | # make smallest slice possible 423 | slice_size = num_sliceable_layers * [1] 424 | 425 | slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size 426 | 427 | if len(slice_size) != len(sliceable_head_dims): 428 | raise ValueError( 429 | f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" 430 | f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." 431 | ) 432 | 433 | for i in range(len(slice_size)): 434 | size = slice_size[i] 435 | dim = sliceable_head_dims[i] 436 | if size is not None and size > dim: 437 | raise ValueError(f"size {size} has to be smaller or equal to {dim}.") 438 | 439 | # Recursively walk through all the children. 440 | # Any children which exposes the set_attention_slice method 441 | # gets the message 442 | def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): 443 | if hasattr(module, "set_attention_slice"): 444 | module.set_attention_slice(slice_size.pop()) 445 | 446 | for child in module.children(): 447 | fn_recursive_set_attention_slice(child, slice_size) 448 | 449 | reversed_slice_size = list(reversed(slice_size)) 450 | for module in self.children(): 451 | fn_recursive_set_attention_slice(module, reversed_slice_size) 452 | 453 | def _set_gradient_checkpointing(self, module, value=False): 454 | if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): 455 | module.gradient_checkpointing = value 456 | 457 | def forward( 458 | self, 459 | sample: torch.FloatTensor, 460 | timestep: Union[torch.Tensor, float, int], 461 | encoder_hidden_states: torch.Tensor, 462 | 463 | controlnet_cond: torch.FloatTensor, 464 | conditioning_mask: Optional[torch.FloatTensor] = None, 465 | 466 | conditioning_scale: float = 1.0, 467 | class_labels: Optional[torch.Tensor] = None, 468 | attention_mask: Optional[torch.Tensor] = None, 469 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 470 | guess_mode: bool = False, 471 | return_dict: bool = True, 472 | ) -> Union[SparseControlNetOutput, Tuple]: 473 | 474 | # set input noise to zero 475 | if self.set_noisy_sample_input_to_zero: 476 | sample = torch.zeros_like(sample).to(sample.device) 477 | 478 | # prepare attention_mask 479 | if attention_mask is not None: 480 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 481 | attention_mask = attention_mask.unsqueeze(1) 482 | 483 | # 1. time 484 | timesteps = timestep 485 | if not torch.is_tensor(timesteps): 486 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 487 | # This would be a good case for the `match` statement (Python 3.10+) 488 | is_mps = sample.device.type == "mps" 489 | if isinstance(timestep, float): 490 | dtype = torch.float32 if is_mps else torch.float64 491 | else: 492 | dtype = torch.int32 if is_mps else torch.int64 493 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 494 | elif len(timesteps.shape) == 0: 495 | timesteps = timesteps[None].to(sample.device) 496 | 497 | timesteps = timesteps.repeat(sample.shape[0] // timesteps.shape[0]) 498 | encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0] // encoder_hidden_states.shape[0], 1, 1) 499 | 500 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 501 | timesteps = timesteps.expand(sample.shape[0]) 502 | 503 | t_emb = self.time_proj(timesteps) 504 | 505 | # timesteps does not contain any weights and will always return f32 tensors 506 | # but time_embedding might actually be running in fp16. so we need to cast here. 507 | # there might be better ways to encapsulate this. 508 | t_emb = t_emb.to(dtype=self.dtype) 509 | emb = self.time_embedding(t_emb) 510 | 511 | if self.class_embedding is not None: 512 | if class_labels is None: 513 | raise ValueError("class_labels should be provided when num_class_embeds > 0") 514 | 515 | if self.config.class_embed_type == "timestep": 516 | class_labels = self.time_proj(class_labels) 517 | 518 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) 519 | emb = emb + class_emb 520 | 521 | # 2. pre-process 522 | sample = self.conv_in(sample) 523 | if self.concate_conditioning_mask: 524 | controlnet_cond = torch.cat([controlnet_cond, conditioning_mask], dim=1) 525 | controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) 526 | 527 | sample = sample + controlnet_cond 528 | 529 | # 3. down 530 | down_block_res_samples = (sample,) 531 | for downsample_block in self.down_blocks: 532 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 533 | sample, res_samples = downsample_block( 534 | hidden_states=sample, 535 | temb=emb, 536 | encoder_hidden_states=encoder_hidden_states, 537 | attention_mask=attention_mask, 538 | # cross_attention_kwargs=cross_attention_kwargs, 539 | ) 540 | else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 541 | 542 | down_block_res_samples += res_samples 543 | 544 | # 4. mid 545 | if self.mid_block is not None: 546 | sample = self.mid_block( 547 | sample, 548 | emb, 549 | encoder_hidden_states=encoder_hidden_states, 550 | attention_mask=attention_mask, 551 | # cross_attention_kwargs=cross_attention_kwargs, 552 | ) 553 | 554 | # 5. controlnet blocks 555 | controlnet_down_block_res_samples = () 556 | 557 | for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): 558 | down_block_res_sample = controlnet_block(down_block_res_sample) 559 | controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) 560 | 561 | down_block_res_samples = controlnet_down_block_res_samples 562 | 563 | mid_block_res_sample = self.controlnet_mid_block(sample) 564 | 565 | # 6. scaling 566 | if guess_mode and not self.config.global_pool_conditions: 567 | scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 568 | 569 | scales = scales * conditioning_scale 570 | down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] 571 | mid_block_res_sample = mid_block_res_sample * scales[-1] # last one 572 | else: 573 | down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] 574 | mid_block_res_sample = mid_block_res_sample * conditioning_scale 575 | 576 | if self.config.global_pool_conditions: 577 | down_block_res_samples = [ 578 | torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples 579 | ] 580 | mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) 581 | 582 | if not return_dict: 583 | return (down_block_res_samples, mid_block_res_sample) 584 | 585 | return SparseControlNetOutput( 586 | down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample 587 | ) 588 | 589 | 590 | def zero_module(module): 591 | for p in module.parameters(): 592 | nn.init.zeros_(p) 593 | return module 594 | -------------------------------------------------------------------------------- /animatediff/utils/convert_lora_safetensor_to_diffusers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | # Changes were made to this source code by Yuwei Guo. 17 | """ Conversion script for the LoRA's safetensors checkpoints. """ 18 | 19 | import argparse 20 | 21 | import torch 22 | from safetensors.torch import load_file 23 | 24 | from diffusers import StableDiffusionPipeline 25 | 26 | 27 | def load_diffusers_lora(pipeline, state_dict, alpha=1.0): 28 | # directly update weight in diffusers model 29 | for key in state_dict: 30 | # only process lora down key 31 | if "up." in key: continue 32 | 33 | up_key = key.replace(".down.", ".up.") 34 | model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "") 35 | model_key = model_key.replace("to_out.", "to_out.0.") 36 | layer_infos = model_key.split(".")[:-1] 37 | 38 | curr_layer = pipeline.unet 39 | while len(layer_infos) > 0: 40 | temp_name = layer_infos.pop(0) 41 | curr_layer = curr_layer.__getattr__(temp_name) 42 | 43 | weight_down = state_dict[key] 44 | weight_up = state_dict[up_key] 45 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) 46 | 47 | return pipeline 48 | 49 | 50 | def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6): 51 | # load base model 52 | # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32) 53 | 54 | # load LoRA weight from .safetensors 55 | # state_dict = load_file(checkpoint_path) 56 | 57 | visited = [] 58 | 59 | # directly update weight in diffusers model 60 | for key in state_dict: 61 | # it is suggested to print out the key, it usually will be something like below 62 | # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" 63 | 64 | # as we have set the alpha beforehand, so just skip 65 | if ".alpha" in key or key in visited: 66 | continue 67 | 68 | if "text" in key: 69 | layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") 70 | curr_layer = pipeline.text_encoder 71 | else: 72 | layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_") 73 | curr_layer = pipeline.unet 74 | 75 | # find the target layer 76 | temp_name = layer_infos.pop(0) 77 | while len(layer_infos) > -1: 78 | try: 79 | curr_layer = curr_layer.__getattr__(temp_name) 80 | if len(layer_infos) > 0: 81 | temp_name = layer_infos.pop(0) 82 | elif len(layer_infos) == 0: 83 | break 84 | except Exception: 85 | if len(temp_name) > 0: 86 | temp_name += "_" + layer_infos.pop(0) 87 | else: 88 | temp_name = layer_infos.pop(0) 89 | 90 | pair_keys = [] 91 | if "lora_down" in key: 92 | pair_keys.append(key.replace("lora_down", "lora_up")) 93 | pair_keys.append(key) 94 | else: 95 | pair_keys.append(key) 96 | pair_keys.append(key.replace("lora_up", "lora_down")) 97 | 98 | # update weight 99 | if len(state_dict[pair_keys[0]].shape) == 4: 100 | weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) 101 | weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) 102 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device) 103 | else: 104 | weight_up = state_dict[pair_keys[0]].to(torch.float32) 105 | weight_down = state_dict[pair_keys[1]].to(torch.float32) 106 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) 107 | 108 | # update visited list 109 | for item in pair_keys: 110 | visited.append(item) 111 | 112 | return pipeline 113 | 114 | 115 | if __name__ == "__main__": 116 | parser = argparse.ArgumentParser() 117 | 118 | parser.add_argument( 119 | "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format." 120 | ) 121 | parser.add_argument( 122 | "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." 123 | ) 124 | parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") 125 | parser.add_argument( 126 | "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors" 127 | ) 128 | parser.add_argument( 129 | "--lora_prefix_text_encoder", 130 | default="lora_te", 131 | type=str, 132 | help="The prefix of text encoder weight in safetensors", 133 | ) 134 | parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW") 135 | parser.add_argument( 136 | "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not." 137 | ) 138 | parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") 139 | 140 | args = parser.parse_args() 141 | 142 | base_model_path = args.base_model_path 143 | checkpoint_path = args.checkpoint_path 144 | dump_path = args.dump_path 145 | lora_prefix_unet = args.lora_prefix_unet 146 | lora_prefix_text_encoder = args.lora_prefix_text_encoder 147 | alpha = args.alpha 148 | 149 | pipe = convert(base_model_path, checkpoint_path, lora_prefix_unet, lora_prefix_text_encoder, alpha) 150 | 151 | pipe = pipe.to(args.device) 152 | pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) 153 | -------------------------------------------------------------------------------- /animatediff/utils/convert_original_stable_diffusion_to_diffusers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Conversion script for the LDM checkpoints.""" 16 | 17 | import argparse 18 | import importlib 19 | 20 | import torch 21 | 22 | from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt 23 | 24 | 25 | if __name__ == "__main__": 26 | parser = argparse.ArgumentParser() 27 | 28 | parser.add_argument( 29 | "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." 30 | ) 31 | # !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml 32 | parser.add_argument( 33 | "--original_config_file", 34 | default=None, 35 | type=str, 36 | help="The YAML config file corresponding to the original architecture.", 37 | ) 38 | parser.add_argument( 39 | "--config_files", 40 | default=None, 41 | type=str, 42 | help="The YAML config file corresponding to the architecture.", 43 | ) 44 | parser.add_argument( 45 | "--num_in_channels", 46 | default=None, 47 | type=int, 48 | help="The number of input channels. If `None` number of input channels will be automatically inferred.", 49 | ) 50 | parser.add_argument( 51 | "--scheduler_type", 52 | default="pndm", 53 | type=str, 54 | help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm']", 55 | ) 56 | parser.add_argument( 57 | "--pipeline_type", 58 | default=None, 59 | type=str, 60 | help=( 61 | "The pipeline type. One of 'FrozenOpenCLIPEmbedder', 'FrozenCLIPEmbedder', 'PaintByExample'" 62 | ". If `None` pipeline will be automatically inferred." 63 | ), 64 | ) 65 | parser.add_argument( 66 | "--image_size", 67 | default=None, 68 | type=int, 69 | help=( 70 | "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Siffusion v2" 71 | " Base. Use 768 for Stable Diffusion v2." 72 | ), 73 | ) 74 | parser.add_argument( 75 | "--prediction_type", 76 | default=None, 77 | type=str, 78 | help=( 79 | "The prediction type that the model was trained on. Use 'epsilon' for Stable Diffusion v1.X and Stable" 80 | " Diffusion v2 Base. Use 'v_prediction' for Stable Diffusion v2." 81 | ), 82 | ) 83 | parser.add_argument( 84 | "--extract_ema", 85 | action="store_true", 86 | help=( 87 | "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights" 88 | " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield" 89 | " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning." 90 | ), 91 | ) 92 | parser.add_argument( 93 | "--upcast_attention", 94 | action="store_true", 95 | help=( 96 | "Whether the attention computation should always be upcasted. This is necessary when running stable" 97 | " diffusion 2.1." 98 | ), 99 | ) 100 | parser.add_argument( 101 | "--from_safetensors", 102 | action="store_true", 103 | help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.", 104 | ) 105 | parser.add_argument( 106 | "--to_safetensors", 107 | action="store_true", 108 | help="Whether to store pipeline in safetensors format or not.", 109 | ) 110 | parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") 111 | parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") 112 | parser.add_argument( 113 | "--stable_unclip", 114 | type=str, 115 | default=None, 116 | required=False, 117 | help="Set if this is a stable unCLIP model. One of 'txt2img' or 'img2img'.", 118 | ) 119 | parser.add_argument( 120 | "--stable_unclip_prior", 121 | type=str, 122 | default=None, 123 | required=False, 124 | help="Set if this is a stable unCLIP txt2img model. Selects which prior to use. If `--stable_unclip` is set to `txt2img`, the karlo prior (https://huggingface.co/kakaobrain/karlo-v1-alpha/tree/main/prior) is selected by default.", 125 | ) 126 | parser.add_argument( 127 | "--clip_stats_path", 128 | type=str, 129 | help="Path to the clip stats file. Only required if the stable unclip model's config specifies `model.params.noise_aug_config.params.clip_stats_path`.", 130 | required=False, 131 | ) 132 | parser.add_argument( 133 | "--controlnet", action="store_true", default=None, help="Set flag if this is a controlnet checkpoint." 134 | ) 135 | parser.add_argument("--half", action="store_true", help="Save weights in half precision.") 136 | parser.add_argument( 137 | "--vae_path", 138 | type=str, 139 | default=None, 140 | required=False, 141 | help="Set to a path, hub id to an already converted vae to not convert it again.", 142 | ) 143 | parser.add_argument( 144 | "--pipeline_class_name", 145 | type=str, 146 | default=None, 147 | required=False, 148 | help="Specify the pipeline class name", 149 | ) 150 | 151 | args = parser.parse_args() 152 | 153 | if args.pipeline_class_name is not None: 154 | library = importlib.import_module("diffusers") 155 | class_obj = getattr(library, args.pipeline_class_name) 156 | pipeline_class = class_obj 157 | else: 158 | pipeline_class = None 159 | 160 | pipe = download_from_original_stable_diffusion_ckpt( 161 | checkpoint_path_or_dict=args.checkpoint_path, 162 | original_config_file=args.original_config_file, 163 | config_files=args.config_files, 164 | image_size=args.image_size, 165 | prediction_type=args.prediction_type, 166 | model_type=args.pipeline_type, 167 | extract_ema=args.extract_ema, 168 | scheduler_type=args.scheduler_type, 169 | num_in_channels=args.num_in_channels, 170 | upcast_attention=args.upcast_attention, 171 | from_safetensors=args.from_safetensors, 172 | device=args.device, 173 | stable_unclip=args.stable_unclip, 174 | stable_unclip_prior=args.stable_unclip_prior, 175 | clip_stats_path=args.clip_stats_path, 176 | controlnet=args.controlnet, 177 | vae_path=args.vae_path, 178 | pipeline_class=pipeline_class, 179 | ) 180 | 181 | if args.half: 182 | pipe.to(dtype=torch.float16) 183 | 184 | if args.controlnet: 185 | # only save the controlnet model 186 | pipe.controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) 187 | else: 188 | pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) -------------------------------------------------------------------------------- /animatediff/utils/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import imageio 3 | import numpy as np 4 | from typing import Union 5 | import cv2 6 | 7 | import torch 8 | import torchvision 9 | import torch.distributed as dist 10 | 11 | from safetensors import safe_open 12 | from tqdm import tqdm 13 | from einops import rearrange 14 | from .convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint 15 | from .convert_lora_safetensor_to_diffusers import convert_lora, load_diffusers_lora 16 | 17 | 18 | def zero_rank_print(s): 19 | if (not dist.is_initialized()) and (dist.is_initialized() and dist.get_rank() == 0): print("### " + s) 20 | from typing import List 21 | import PIL 22 | def export_to_video( 23 | video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8 24 | ) -> str: 25 | # if output_video_path is None: 26 | # output_video_path = tempfile.NamedTemporaryFile(suffix=".webm").name 27 | 28 | if isinstance(video_frames[0], PIL.Image.Image): 29 | video_frames = [np.array(frame) for frame in video_frames] 30 | 31 | fourcc = cv2.VideoWriter_fourcc(*"mp4v") 32 | # fourcc = cv2.VideoWriter_fourcc(*'VP90') 33 | h, w, c = video_frames[0].shape 34 | video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=fps, frameSize=(w, h)) 35 | for i in range(len(video_frames)): 36 | img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR) 37 | video_writer.write(img) 38 | 39 | return output_video_path 40 | 41 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=9): 42 | videos = rearrange(videos, "b c t h w -> t b c h w") 43 | outputs = [] 44 | for x in videos: 45 | x = torchvision.utils.make_grid(x, nrow=n_rows) 46 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 47 | if rescale: 48 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 49 | x = (x * 255).numpy().astype(np.uint8) 50 | outputs.append(x) 51 | os.makedirs(os.path.dirname(path), exist_ok=True) 52 | # export_to_video(outputs, output_video_path=path, fps=fps) 53 | 54 | imageio.mimsave(path, outputs, fps=fps) 55 | 56 | 57 | # DDIM Inversion 58 | @torch.no_grad() 59 | def init_prompt(prompt, pipeline): 60 | uncond_input = pipeline.tokenizer( 61 | [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length, 62 | return_tensors="pt" 63 | ) 64 | uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0] 65 | text_input = pipeline.tokenizer( 66 | [prompt], 67 | padding="max_length", 68 | max_length=pipeline.tokenizer.model_max_length, 69 | truncation=True, 70 | return_tensors="pt", 71 | ) 72 | text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0] 73 | context = torch.cat([uncond_embeddings, text_embeddings]) 74 | 75 | return context 76 | 77 | 78 | def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, 79 | sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler): 80 | timestep, next_timestep = min( 81 | timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep 82 | alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod 83 | alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep] 84 | beta_prod_t = 1 - alpha_prod_t 85 | next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 86 | next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output 87 | next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction 88 | return next_sample 89 | 90 | 91 | def get_noise_pred_single(latents, t, context, unet): 92 | noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"] 93 | return noise_pred 94 | 95 | 96 | @torch.no_grad() 97 | def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt): 98 | context = init_prompt(prompt, pipeline) 99 | uncond_embeddings, cond_embeddings = context.chunk(2) 100 | all_latent = [latent] 101 | latent = latent.clone().detach() 102 | for i in tqdm(range(num_inv_steps)): 103 | t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1] 104 | noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet) 105 | latent = next_step(noise_pred, t, latent, ddim_scheduler) 106 | all_latent.append(latent) 107 | return all_latent 108 | 109 | 110 | @torch.no_grad() 111 | def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""): 112 | ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt) 113 | return ddim_latents 114 | 115 | def load_weights( 116 | animation_pipeline, 117 | # motion module 118 | motion_module_path = "", 119 | motion_module_lora_configs = [], 120 | # domain adapter 121 | adapter_lora_path = "", 122 | adapter_lora_scale = 1.0, 123 | # image layers 124 | dreambooth_model_path = "", 125 | lora_model_path = "", 126 | lora_alpha = 0.8, 127 | ): 128 | # motion module 129 | unet_state_dict = {} 130 | if motion_module_path != "": 131 | print(f"load motion module from {motion_module_path}") 132 | motion_module_state_dict = torch.load(motion_module_path, map_location="cpu") 133 | motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict 134 | unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." in name}) 135 | unet_state_dict.pop("animatediff_config", "") 136 | 137 | missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False) 138 | print("motion_module missing:",len(missing)) 139 | print("motion_module unexpe:",len(unexpected)) 140 | assert len(unexpected) == 0 141 | del unet_state_dict 142 | 143 | # base model 144 | # if dreambooth_model_path != "": 145 | # print(f"load dreambooth model from {dreambooth_model_path}") 146 | # # if dreambooth_model_path.endswith(".safetensors"): 147 | # # dreambooth_state_dict = {} 148 | # # with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f: 149 | # # for key in f.keys(): 150 | # # dreambooth_state_dict[key] = f.get_tensor(key) 151 | # # elif dreambooth_model_path.endswith(".ckpt"): 152 | # # dreambooth_state_dict = torch.load(dreambooth_model_path, map_location="cpu") 153 | 154 | # # # 1. vae 155 | # # converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, animation_pipeline.vae.config) 156 | # # animation_pipeline.vae.load_state_dict(converted_vae_checkpoint) 157 | # # # 2. unet 158 | # # converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, animation_pipeline.unet.config) 159 | # # animation_pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False) 160 | # # # 3. text_model 161 | # # animation_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict) 162 | # # del dreambooth_state_dict 163 | # dreambooth_state_dict = {} 164 | # with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f: 165 | # for key in f.keys(): 166 | # dreambooth_state_dict[key] = f.get_tensor(key) 167 | 168 | # converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, animation_pipeline.vae.config) 169 | # # print(vae) 170 | # #vae ->to_q,to_k,to_v 171 | # # print(converted_vae_checkpoint) 172 | # convert_vae_keys = list(converted_vae_checkpoint.keys()) 173 | # for key in convert_vae_keys: 174 | # if "encoder.mid_block.attentions" in key or "decoder.mid_block.attentions" in key: 175 | # new_key = None 176 | # if "key" in key: 177 | # new_key = key.replace("key","to_k") 178 | # elif "query" in key: 179 | # new_key = key.replace("query","to_q") 180 | # elif "value" in key: 181 | # new_key = key.replace("value","to_v") 182 | # elif "proj_attn" in key: 183 | # new_key = key.replace("proj_attn","to_out.0") 184 | # if new_key: 185 | # converted_vae_checkpoint[new_key] = converted_vae_checkpoint.pop(key) 186 | 187 | # animation_pipeline.vae.load_state_dict(converted_vae_checkpoint) 188 | 189 | # converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, animation_pipeline.unet.config) 190 | # animation_pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False) 191 | 192 | # animation_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict) 193 | # del dreambooth_state_dict 194 | # lora layers 195 | if lora_model_path != "": 196 | print(f"load lora model from {lora_model_path}") 197 | assert lora_model_path.endswith(".safetensors") 198 | lora_state_dict = {} 199 | with safe_open(lora_model_path, framework="pt", device="cpu") as f: 200 | for key in f.keys(): 201 | lora_state_dict[key] = f.get_tensor(key) 202 | 203 | animation_pipeline = convert_lora(animation_pipeline, lora_state_dict, alpha=lora_alpha) 204 | del lora_state_dict 205 | 206 | # domain adapter lora 207 | if adapter_lora_path != "": 208 | print(f"load domain lora from {adapter_lora_path}") 209 | domain_lora_state_dict = torch.load(adapter_lora_path, map_location="cpu") 210 | domain_lora_state_dict = domain_lora_state_dict["state_dict"] if "state_dict" in domain_lora_state_dict else domain_lora_state_dict 211 | domain_lora_state_dict.pop("animatediff_config", "") 212 | 213 | animation_pipeline = load_diffusers_lora(animation_pipeline, domain_lora_state_dict, alpha=adapter_lora_scale) 214 | 215 | # motion module lora 216 | for motion_module_lora_config in motion_module_lora_configs: 217 | path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"] 218 | print(f"load motion LoRA from {path}") 219 | motion_lora_state_dict = torch.load(path, map_location="cpu") 220 | motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict 221 | motion_lora_state_dict.pop("animatediff_config", "") 222 | 223 | animation_pipeline = load_diffusers_lora(animation_pipeline, motion_lora_state_dict, alpha) 224 | 225 | return animation_pipeline 226 | -------------------------------------------------------------------------------- /demo/ComfyUI_ID_Animator.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_ID_Animator/005fc24083280862f861db92faec7f886cbe55e8/demo/ComfyUI_ID_Animator.gif -------------------------------------------------------------------------------- /demo/example.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 10, 3 | "last_link_id": 10, 4 | "nodes": [ 5 | { 6 | "id": 8, 7 | "type": "LoadImage", 8 | "pos": [ 9 | 636, 10 | 109 11 | ], 12 | "size": { 13 | "0": 315, 14 | "1": 314 15 | }, 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "outputs": [ 20 | { 21 | "name": "IMAGE", 22 | "type": "IMAGE", 23 | "links": [ 24 | 8 25 | ], 26 | "shape": 3, 27 | "label": "图像", 28 | "slot_index": 0 29 | }, 30 | { 31 | "name": "MASK", 32 | "type": "MASK", 33 | "links": null, 34 | "shape": 3, 35 | "label": "遮罩" 36 | } 37 | ], 38 | "properties": { 39 | "Node name for S&R": "LoadImage" 40 | }, 41 | "widgets_values": [ 42 | "gakki.png", 43 | "image" 44 | ] 45 | }, 46 | { 47 | "id": 7, 48 | "type": "ID_Repo_Choice", 49 | "pos": [ 50 | 715, 51 | 529 52 | ], 53 | "size": { 54 | "0": 315, 55 | "1": 130 56 | }, 57 | "flags": {}, 58 | "order": 1, 59 | "mode": 0, 60 | "outputs": [ 61 | { 62 | "name": "repo_id", 63 | "type": "STRING", 64 | "links": [ 65 | 10 66 | ], 67 | "shape": 3, 68 | "label": "repo_id", 69 | "slot_index": 0 70 | } 71 | ], 72 | "properties": { 73 | "Node name for S&R": "ID_Repo_Choice" 74 | }, 75 | "widgets_values": [ 76 | "stable-diffusion-v1-5", 77 | "", 78 | "1SD1.5\\DreamShaper_8_pruned.safetensors", 79 | "mm_sd_v15_v2.ckpt" 80 | ] 81 | }, 82 | { 83 | "id": 10, 84 | "type": "ID_Animator", 85 | "pos": [ 86 | 1133, 87 | 354 88 | ], 89 | "size": [ 90 | 423.5861962980666, 91 | 494.66712251607464 92 | ], 93 | "flags": {}, 94 | "order": 2, 95 | "mode": 0, 96 | "inputs": [ 97 | { 98 | "name": "image", 99 | "type": "IMAGE", 100 | "link": 8, 101 | "label": "image" 102 | }, 103 | { 104 | "name": "repo_id", 105 | "type": "STRING", 106 | "link": 10, 107 | "widget": { 108 | "name": "repo_id" 109 | } 110 | } 111 | ], 112 | "outputs": [ 113 | { 114 | "name": "image", 115 | "type": "IMAGE", 116 | "links": [ 117 | 9 118 | ], 119 | "shape": 3, 120 | "label": "image", 121 | "slot_index": 0 122 | } 123 | ], 124 | "properties": { 125 | "Node name for S&R": "ID_Animator" 126 | }, 127 | "widgets_values": [ 128 | "", 129 | "Iron Man soars through the clouds, his repulsors blazing", 130 | "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck", 131 | "DDIM", 132 | "v3_sd15_adapter.ckpt", 133 | 1, 134 | "2d3dstyle\\3DMM_V11.safetensors", 135 | 0.8, 136 | 30, 137 | 251618243807904, 138 | "randomize", 139 | 8, 140 | 512, 141 | 512, 142 | 16, 143 | 0.8 144 | ] 145 | }, 146 | { 147 | "id": 9, 148 | "type": "SaveImage", 149 | "pos": [ 150 | 1651, 151 | 435 152 | ], 153 | "size": { 154 | "0": 315, 155 | "1": 270 156 | }, 157 | "flags": {}, 158 | "order": 3, 159 | "mode": 0, 160 | "inputs": [ 161 | { 162 | "name": "images", 163 | "type": "IMAGE", 164 | "link": 9, 165 | "label": "图像" 166 | } 167 | ], 168 | "properties": {}, 169 | "widgets_values": [ 170 | "ComfyUI" 171 | ] 172 | } 173 | ], 174 | "links": [ 175 | [ 176 | 8, 177 | 8, 178 | 0, 179 | 10, 180 | 0, 181 | "IMAGE" 182 | ], 183 | [ 184 | 9, 185 | 10, 186 | 0, 187 | 9, 188 | 0, 189 | "IMAGE" 190 | ], 191 | [ 192 | 10, 193 | 7, 194 | 0, 195 | 10, 196 | 1, 197 | "STRING" 198 | ] 199 | ], 200 | "groups": [], 201 | "config": {}, 202 | "extra": { 203 | "ds": { 204 | "scale": 1.2100000000000004, 205 | "offset": [ 206 | -434.98589577214676, 207 | -11.138948211041045 208 | ] 209 | } 210 | }, 211 | "version": 0.4 212 | } -------------------------------------------------------------------------------- /demo/lecun.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_ID_Animator/005fc24083280862f861db92faec7f886cbe55e8/demo/lecun.png -------------------------------------------------------------------------------- /faceadapter/attention_processor.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class AttnProcessor(nn.Module): 8 | r""" 9 | Default processor for performing attention-related computations. 10 | """ 11 | 12 | def __init__( 13 | self, 14 | hidden_size=None, 15 | cross_attention_dim=None, 16 | ): 17 | super().__init__() 18 | 19 | def __call__( 20 | self, 21 | attn, 22 | hidden_states, 23 | encoder_hidden_states=None, 24 | attention_mask=None, 25 | temb=None, 26 | ): 27 | residual = hidden_states 28 | 29 | if attn.spatial_norm is not None: 30 | hidden_states = attn.spatial_norm(hidden_states, temb) 31 | 32 | input_ndim = hidden_states.ndim 33 | 34 | if input_ndim == 4: 35 | batch_size, channel, height, width = hidden_states.shape 36 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 37 | 38 | batch_size, sequence_length, _ = ( 39 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 40 | ) 41 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 42 | 43 | if attn.group_norm is not None: 44 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 45 | 46 | query = attn.to_q(hidden_states) 47 | 48 | if encoder_hidden_states is None: 49 | encoder_hidden_states = hidden_states 50 | elif attn.norm_cross: 51 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 52 | 53 | key = attn.to_k(encoder_hidden_states) 54 | value = attn.to_v(encoder_hidden_states) 55 | 56 | query = attn.head_to_batch_dim(query) 57 | key = attn.head_to_batch_dim(key) 58 | value = attn.head_to_batch_dim(value) 59 | 60 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 61 | hidden_states = torch.bmm(attention_probs, value) 62 | hidden_states = attn.batch_to_head_dim(hidden_states) 63 | 64 | # linear proj 65 | hidden_states = attn.to_out[0](hidden_states) 66 | # dropout 67 | hidden_states = attn.to_out[1](hidden_states) 68 | 69 | if input_ndim == 4: 70 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 71 | 72 | if attn.residual_connection: 73 | hidden_states = hidden_states + residual 74 | 75 | hidden_states = hidden_states / attn.rescale_output_factor 76 | 77 | return hidden_states 78 | 79 | class AttnProcessor2_0(torch.nn.Module): 80 | r""" 81 | Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 82 | """ 83 | 84 | def __init__( 85 | self, 86 | hidden_size=None, 87 | cross_attention_dim=None, 88 | ): 89 | super().__init__() 90 | if not hasattr(F, "scaled_dot_product_attention"): 91 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") 92 | 93 | def __call__( 94 | self, 95 | attn, 96 | hidden_states, 97 | encoder_hidden_states=None, 98 | attention_mask=None, 99 | temb=None, 100 | ): 101 | residual = hidden_states 102 | 103 | if attn.spatial_norm is not None: 104 | hidden_states = attn.spatial_norm(hidden_states, temb) 105 | 106 | input_ndim = hidden_states.ndim 107 | 108 | if input_ndim == 4: 109 | batch_size, channel, height, width = hidden_states.shape 110 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 111 | 112 | batch_size, sequence_length, _ = ( 113 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 114 | ) 115 | 116 | if attention_mask is not None: 117 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 118 | # scaled_dot_product_attention expects attention_mask shape to be 119 | # (batch, heads, source_length, target_length) 120 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) 121 | 122 | if attn.group_norm is not None: 123 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 124 | 125 | query = attn.to_q(hidden_states) 126 | 127 | if encoder_hidden_states is None: 128 | encoder_hidden_states = hidden_states 129 | elif attn.norm_cross: 130 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 131 | 132 | key = attn.to_k(encoder_hidden_states) 133 | value = attn.to_v(encoder_hidden_states) 134 | 135 | inner_dim = key.shape[-1] 136 | head_dim = inner_dim // attn.heads 137 | 138 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 139 | 140 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 141 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 142 | 143 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 144 | # TODO: add support for attn.scale when we move to Torch 2.1 145 | hidden_states = F.scaled_dot_product_attention( 146 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 147 | ) 148 | 149 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 150 | hidden_states = hidden_states.to(query.dtype) 151 | 152 | # linear proj 153 | hidden_states = attn.to_out[0](hidden_states) 154 | # dropout 155 | hidden_states = attn.to_out[1](hidden_states) 156 | 157 | if input_ndim == 4: 158 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 159 | 160 | if attn.residual_connection: 161 | hidden_states = hidden_states + residual 162 | 163 | hidden_states = hidden_states / attn.rescale_output_factor 164 | 165 | return hidden_states 166 | 167 | import torch 168 | import torch.nn as nn 169 | import torch.nn.functional as F 170 | 171 | from diffusers.models.lora import LoRALinearLayer 172 | 173 | 174 | 175 | from torch.utils.checkpoint import checkpoint 176 | 177 | class LoRAFaceAttnProcessor(nn.Module): 178 | r""" 179 | Attention processor for Face-Adapater. 180 | Args: 181 | hidden_size (`int`): 182 | The hidden size of the attention layer. 183 | cross_attention_dim (`int`): 184 | The number of channels in the `encoder_hidden_states`. 185 | scale (`float`, defaults to 1.0): 186 | the weight scale of image prompt. 187 | num_tokens (`int`, defaults to 4 when do face_adapter_plus it should be 16): 188 | The context length of the image features. 189 | """ 190 | 191 | def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, lora_scale=1.0, scale=1.0, num_tokens=4): 192 | super().__init__() 193 | 194 | self.rank = rank 195 | self.lora_scale = lora_scale 196 | 197 | self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) 198 | self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) 199 | self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) 200 | self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) 201 | 202 | self.hidden_size = hidden_size 203 | self.cross_attention_dim = cross_attention_dim 204 | self.scale = scale 205 | self.num_tokens = num_tokens 206 | 207 | self.to_k_face = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) 208 | self.to_v_face = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) 209 | self.gradient_checkpointing=False 210 | def __call__( 211 | self, 212 | attn, 213 | hidden_states, 214 | encoder_hidden_states=None, 215 | attention_mask=None, 216 | temb=None, 217 | ): 218 | if self.gradient_checkpointing and self.training: 219 | def create_custom_forward(module, return_dict=None): 220 | def custom_forward(*inputs): 221 | if return_dict is not None: 222 | return module(*inputs, return_dict=return_dict) 223 | else: 224 | return module(*inputs) 225 | 226 | return custom_forward 227 | if attn.spatial_norm is not None: 228 | hidden_states = attn.spatial_norm(hidden_states, temb) 229 | 230 | input_ndim = hidden_states.ndim 231 | 232 | if input_ndim == 4: 233 | batch_size, channel, height, width = hidden_states.shape 234 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 235 | 236 | batch_size, sequence_length, _ = ( 237 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 238 | ) 239 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 240 | 241 | if attn.group_norm is not None: 242 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 243 | query = attn.to_q(hidden_states) + self.lora_scale *torch.utils.checkpoint.checkpoint( 244 | create_custom_forward(self.to_q_lora), 245 | hidden_states ) 246 | if encoder_hidden_states is None: 247 | encoder_hidden_states = hidden_states 248 | else: 249 | # get encoder_hidden_states, face_hidden_states 250 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens 251 | encoder_hidden_states, face_hidden_states = ( 252 | encoder_hidden_states[:, :end_pos, :], 253 | encoder_hidden_states[:, end_pos:, :], 254 | ) 255 | if attn.norm_cross: 256 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 257 | key = attn.to_k(encoder_hidden_states) + self.lora_scale * torch.utils.checkpoint.checkpoint( 258 | create_custom_forward(self.to_k_lora), 259 | encoder_hidden_states ) 260 | value = attn.to_v(encoder_hidden_states) + self.lora_scale * torch.utils.checkpoint.checkpoint( 261 | create_custom_forward(self.to_v_lora), 262 | encoder_hidden_states ) 263 | query = attn.head_to_batch_dim(query) 264 | key = attn.head_to_batch_dim(key) 265 | value = attn.head_to_batch_dim(value) 266 | 267 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 268 | hidden_states = torch.bmm(attention_probs, value) 269 | hidden_states = attn.batch_to_head_dim(hidden_states) 270 | 271 | # for face-adapter 272 | face_key = self.to_k_face(face_hidden_states) 273 | face_value = self.to_v_face(face_hidden_states) 274 | 275 | face_key = attn.head_to_batch_dim(face_key) 276 | face_value = attn.head_to_batch_dim(face_value) 277 | 278 | face_attention_probs = attn.get_attention_scores(query, face_key, None) 279 | self.attn_map = face_attention_probs 280 | face_hidden_states = torch.bmm(face_attention_probs, face_value) 281 | face_hidden_states = attn.batch_to_head_dim(face_hidden_states) 282 | 283 | hidden_states = hidden_states + self.scale * face_hidden_states 284 | 285 | # linear proj 286 | hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) 287 | # dropout 288 | hidden_states = attn.to_out[1](hidden_states) 289 | 290 | if input_ndim == 4: 291 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 292 | 293 | if attn.residual_connection: 294 | hidden_states = hidden_states + residual 295 | 296 | hidden_states = hidden_states / attn.rescale_output_factor 297 | 298 | else: 299 | if attn.spatial_norm is not None: 300 | hidden_states = attn.spatial_norm(hidden_states, temb) 301 | 302 | input_ndim = hidden_states.ndim 303 | 304 | if input_ndim == 4: 305 | batch_size, channel, height, width = hidden_states.shape 306 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 307 | 308 | batch_size, sequence_length, _ = ( 309 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 310 | ) 311 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 312 | 313 | if attn.group_norm is not None: 314 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 315 | 316 | query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states) 317 | 318 | if encoder_hidden_states is None: 319 | encoder_hidden_states = hidden_states 320 | else: 321 | # get encoder_hidden_states, face_hidden_states 322 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens 323 | encoder_hidden_states, face_hidden_states = ( 324 | encoder_hidden_states[:, :end_pos, :], 325 | encoder_hidden_states[:, end_pos:, :], 326 | ) 327 | if attn.norm_cross: 328 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 329 | 330 | key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) 331 | value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states) 332 | 333 | query = attn.head_to_batch_dim(query) 334 | key = attn.head_to_batch_dim(key) 335 | value = attn.head_to_batch_dim(value) 336 | 337 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 338 | hidden_states = torch.bmm(attention_probs, value) 339 | hidden_states = attn.batch_to_head_dim(hidden_states) 340 | 341 | # for face-adapter 342 | face_key = self.to_k_face(face_hidden_states) 343 | face_value = self.to_v_face(face_hidden_states) 344 | 345 | face_key = attn.head_to_batch_dim(face_key) 346 | face_value = attn.head_to_batch_dim(face_value) 347 | 348 | face_attention_probs = attn.get_attention_scores(query, face_key, None) 349 | self.attn_map = face_attention_probs 350 | face_hidden_states = torch.bmm(face_attention_probs, face_value) 351 | face_hidden_states = attn.batch_to_head_dim(face_hidden_states) 352 | 353 | hidden_states = hidden_states + self.scale * face_hidden_states 354 | 355 | # linear proj 356 | hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) 357 | # dropout 358 | hidden_states = attn.to_out[1](hidden_states) 359 | 360 | if input_ndim == 4: 361 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 362 | 363 | if attn.residual_connection: 364 | hidden_states = hidden_states + residual 365 | 366 | hidden_states = hidden_states / attn.rescale_output_factor 367 | return hidden_states 368 | -------------------------------------------------------------------------------- /faceadapter/face_adapter.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | 4 | import torch 5 | from diffusers import StableDiffusionPipeline 6 | from diffusers.pipelines.controlnet import MultiControlNetModel 7 | from PIL import Image 8 | from safetensors import safe_open 9 | from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection 10 | from .attention_processor import LoRAFaceAttnProcessor 11 | 12 | from .utils import is_torch2_available, get_generator 13 | 14 | if is_torch2_available(): 15 | from .attention_processor import ( 16 | AttnProcessor2_0 as AttnProcessor, 17 | ) 18 | else: 19 | from .attention_processor import AttnProcessor 20 | from .resampler import Resampler 21 | 22 | 23 | class ImageProjModel(torch.nn.Module): 24 | """Projection Model""" 25 | 26 | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): 27 | super().__init__() 28 | 29 | self.generator = None 30 | self.cross_attention_dim = cross_attention_dim 31 | self.clip_extra_context_tokens = clip_extra_context_tokens 32 | self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) 33 | self.norm = torch.nn.LayerNorm(cross_attention_dim) 34 | 35 | def forward(self, image_embeds): 36 | embeds = image_embeds 37 | clip_extra_context_tokens = self.proj(embeds).reshape( 38 | -1, self.clip_extra_context_tokens, self.cross_attention_dim 39 | ) 40 | clip_extra_context_tokens = self.norm(clip_extra_context_tokens) 41 | return clip_extra_context_tokens 42 | 43 | 44 | class MLPProjModel(torch.nn.Module): 45 | """SD model with image prompt""" 46 | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024): 47 | super().__init__() 48 | 49 | self.proj = torch.nn.Sequential( 50 | torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim), 51 | torch.nn.GELU(), 52 | torch.nn.Linear(clip_embeddings_dim, cross_attention_dim), 53 | torch.nn.LayerNorm(cross_attention_dim) 54 | ) 55 | 56 | def forward(self, image_embeds): 57 | clip_extra_context_tokens = self.proj(image_embeds) 58 | return clip_extra_context_tokens 59 | 60 | 61 | class FaceAdapterLora: 62 | def __init__(self, sd_pipe, image_encoder_path, id_ckpt, device, num_tokens=4,torch_type=torch.float32): 63 | self.device = device 64 | self.image_encoder_path = image_encoder_path 65 | self.id_ckpt = id_ckpt 66 | self.num_tokens = num_tokens 67 | self.torch_type = torch_type 68 | 69 | self.pipe = sd_pipe.to(self.device) 70 | self.set_face_adapter() 71 | # load image encoder 72 | self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( 73 | self.device, dtype=self.torch_type 74 | ) 75 | self.clip_image_processor = CLIPImageProcessor() 76 | # image proj model 77 | self.image_proj_model = self.init_proj() 78 | 79 | self.load_face_adapter() 80 | 81 | def init_proj(self): 82 | image_proj_model = ImageProjModel( 83 | cross_attention_dim=self.pipe.unet.config.cross_attention_dim, 84 | clip_embeddings_dim=self.image_encoder.config.projection_dim, 85 | clip_extra_context_tokens=self.num_tokens, 86 | ).to(self.device, dtype=self.torch_type) 87 | return image_proj_model 88 | 89 | def set_face_adapter(self): 90 | unet = self.pipe.unet 91 | attn_procs = {} 92 | for name in unet.attn_processors.keys(): 93 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim 94 | if name.startswith("mid_block"): 95 | hidden_size = unet.config.block_out_channels[-1] 96 | elif name.startswith("up_blocks"): 97 | block_id = int(name[len("up_blocks.")]) 98 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id] 99 | elif name.startswith("down_blocks"): 100 | block_id = int(name[len("down_blocks.")]) 101 | hidden_size = unet.config.block_out_channels[block_id] 102 | if cross_attention_dim is None: 103 | attn_procs[name] = AttnProcessor().to(self.device, dtype=self.torch_type) 104 | else: 105 | attn_procs[name] = LoRAFaceAttnProcessor( 106 | hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=128, num_tokens=self.num_tokens, 107 | ).to(self.device, dtype=self.torch_type) 108 | unet.set_attn_processor(attn_procs) 109 | def load_face_adapter(self): 110 | state_dict = torch.load(self.id_ckpt, map_location="cpu") 111 | if 'state_dict' in state_dict: 112 | state_dict = state_dict['state_dict'] 113 | image_proj_dict={} 114 | face_adapter_proj={} 115 | for k,v in state_dict.items(): 116 | if k.startswith("module.image_proj_model"): 117 | image_proj_dict[k.replace("module.image_proj_model.", "")] = state_dict[k] 118 | elif k.startswith("module.adapter_modules."): 119 | face_adapter_proj[k.replace("module.adapter_modules.", "")] = state_dict[k] 120 | elif k.startswith("image_proj_model"): 121 | image_proj_dict[k.replace("image_proj_model.", "")] = state_dict[k] 122 | elif k.startswith("adapter_modules."): 123 | face_adapter_proj[k.replace("adapter_modules.", "")] = state_dict[k] 124 | else: 125 | print("ERROR!") 126 | return 127 | state_dict = {} 128 | state_dict['image_proj'] = image_proj_dict 129 | state_dict["face_adapter"] = face_adapter_proj 130 | self.image_proj_model.load_state_dict(state_dict["image_proj"]) 131 | adapter_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) 132 | adapter_layers.load_state_dict(state_dict["face_adapter"],strict=False) 133 | @torch.inference_mode() 134 | def get_image_embeds(self, pil_image=None, clip_image_embeds=None): 135 | if pil_image is not None: 136 | if isinstance(pil_image, Image.Image): 137 | pil_image = [pil_image] 138 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values 139 | clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=self.torch_type)).image_embeds 140 | else: 141 | clip_image_embeds = clip_image_embeds.to(self.device, dtype=self.torch_type) 142 | image_prompt_embeds = self.image_proj_model(clip_image_embeds) 143 | uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds)) 144 | return image_prompt_embeds, uncond_image_prompt_embeds 145 | 146 | def set_scale(self, scale): 147 | for attn_processor in self.pipe.unet.attn_processors.values(): 148 | if isinstance(attn_processor, LoRAFaceAttnProcessor): 149 | attn_processor.scale = scale 150 | 151 | def generate( 152 | self, 153 | pil_image=None, 154 | clip_image_embeds=None, 155 | prompt=None, 156 | negative_prompt=None, 157 | scale=1, 158 | num_samples=4, 159 | seed=None, 160 | guidance_scale=7.5, 161 | num_inference_steps=30, 162 | **kwargs, 163 | ): 164 | self.set_scale(scale) 165 | 166 | if pil_image is not None: 167 | num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) 168 | else: 169 | num_prompts = clip_image_embeds.size(0) 170 | 171 | if prompt is None: 172 | prompt = "best quality, high quality" 173 | if negative_prompt is None: 174 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 175 | 176 | if not isinstance(prompt, List): 177 | prompt = [prompt] * num_prompts 178 | if not isinstance(negative_prompt, List): 179 | negative_prompt = [negative_prompt] * num_prompts 180 | 181 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds( 182 | pil_image=pil_image, clip_image_embeds=clip_image_embeds 183 | ) 184 | bs_embed, seq_len, _ = image_prompt_embeds.shape 185 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) 186 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 187 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) 188 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 189 | 190 | with torch.inference_mode(): 191 | prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( 192 | prompt, 193 | device=self.device, 194 | num_images_per_prompt=num_samples, 195 | do_classifier_free_guidance=True, 196 | negative_prompt=negative_prompt, 197 | ) 198 | prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) 199 | negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) 200 | 201 | generator = get_generator(seed, self.device) 202 | 203 | images = self.pipe( 204 | prompt_embeds=prompt_embeds, 205 | negative_prompt_embeds=negative_prompt_embeds, 206 | guidance_scale=guidance_scale, 207 | num_inference_steps=num_inference_steps, 208 | generator=generator, 209 | **kwargs, 210 | ).images 211 | 212 | return images 213 | 214 | 215 | class FaceAdapterPlusForVideoLora(FaceAdapterLora): 216 | def init_proj(self): 217 | image_proj_model = Resampler( 218 | dim=self.pipe.unet.config.cross_attention_dim, 219 | depth=4, 220 | dim_head=64, 221 | heads=12, 222 | num_queries=self.num_tokens, 223 | embedding_dim=self.image_encoder.config.hidden_size, 224 | output_dim=self.pipe.unet.config.cross_attention_dim, 225 | ff_mult=4, 226 | ).to(self.device, dtype=self.torch_type) 227 | return image_proj_model 228 | 229 | @torch.inference_mode() 230 | def get_image_embeds(self, pil_image=None, clip_image_embeds=None): 231 | if isinstance(pil_image, Image.Image): 232 | pil_image = [pil_image] 233 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values 234 | clip_image = clip_image.to(self.device, dtype=self.torch_type) 235 | clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] 236 | image_prompt_embeds = self.image_proj_model(clip_image_embeds) 237 | uncond_clip_image_embeds = self.image_encoder( 238 | torch.zeros_like(clip_image), output_hidden_states=True 239 | ).hidden_states[-2] 240 | uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) 241 | return image_prompt_embeds, uncond_image_prompt_embeds 242 | 243 | def generate( 244 | self, 245 | pil_image=None, 246 | pil_image2=None, 247 | clip_image_embeds=None, 248 | prompt=None, 249 | negative_prompt=None, 250 | scale=1.0, 251 | num_samples=1, 252 | seed=None, 253 | guidance_scale=7.5, 254 | num_inference_steps=30, 255 | width=512, 256 | height=512, 257 | video_length=16, 258 | image_scale=0, 259 | controlnet_images: torch.FloatTensor = None, 260 | controlnet_image_index: list = [0], 261 | **kwargs, 262 | ): 263 | self.set_scale(scale) 264 | num_prompts=1 265 | 266 | if prompt is None: 267 | prompt = "best quality, high quality" 268 | if negative_prompt is None: 269 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 270 | 271 | if not isinstance(prompt, List): 272 | prompt = [prompt] * num_prompts 273 | if not isinstance(negative_prompt, List): 274 | negative_prompt = [negative_prompt] * num_prompts 275 | num_prompt_img =len(pil_image) 276 | total_image_prompt_embeds = 0 277 | for i in range(num_prompt_img): 278 | prompt_img = pil_image[i] 279 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds( 280 | pil_image=prompt_img, clip_image_embeds=clip_image_embeds 281 | ) 282 | bs_embed, seq_len, _ = image_prompt_embeds.shape 283 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) 284 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 285 | total_image_prompt_embeds += image_prompt_embeds 286 | total_image_prompt_embeds/=num_prompt_img 287 | image_prompt_embeds = total_image_prompt_embeds 288 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) 289 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 290 | with torch.inference_mode(): 291 | prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( 292 | prompt, 293 | device=self.device, 294 | num_videos_per_prompt=num_samples, 295 | do_classifier_free_guidance=True, 296 | negative_prompt=negative_prompt, 297 | ) 298 | prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) 299 | negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) 300 | 301 | generator = get_generator(seed, self.device) 302 | 303 | video = self.pipe( 304 | prompt = "", 305 | prompt_embeds = prompt_embeds, 306 | negative_prompt_embeds=negative_prompt_embeds, 307 | guidance_scale=guidance_scale, 308 | num_inference_steps=num_inference_steps, 309 | generator=generator, 310 | width = width, 311 | height=height, 312 | video_length = video_length, 313 | controlnet_images = controlnet_images, 314 | controlnet_image_index=controlnet_image_index, 315 | **kwargs, 316 | ).videos 317 | 318 | return video 319 | 320 | def generate_video_edit( 321 | self, 322 | pil_image=None, 323 | clip_image_embeds=None, 324 | prompt=None, 325 | negative_prompt=None, 326 | scale=1.0, 327 | num_samples=1, 328 | seed=None, 329 | guidance_scale=7.5, 330 | num_inference_steps=30, 331 | width=512, 332 | height=512, 333 | video_length=16, 334 | video_latents=None, 335 | **kwargs, 336 | ): 337 | self.set_scale(scale) 338 | 339 | if pil_image is not None: 340 | num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) 341 | else: 342 | num_prompts = clip_image_embeds.size(0) 343 | 344 | if prompt is None: 345 | prompt = "best quality, high quality" 346 | if negative_prompt is None: 347 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 348 | 349 | if not isinstance(prompt, List): 350 | prompt = [prompt] * num_prompts 351 | if not isinstance(negative_prompt, List): 352 | negative_prompt = [negative_prompt] * num_prompts 353 | 354 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds( 355 | pil_image=pil_image, clip_image_embeds=clip_image_embeds 356 | ) 357 | bs_embed, seq_len, _ = image_prompt_embeds.shape 358 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) 359 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 360 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) 361 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 362 | with torch.inference_mode(): 363 | prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( 364 | prompt, 365 | device=self.device, 366 | num_videos_per_prompt=num_samples, 367 | do_classifier_free_guidance=True, 368 | negative_prompt=negative_prompt, 369 | ) 370 | prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) 371 | negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) 372 | 373 | generator = get_generator(seed, self.device) 374 | 375 | video = self.pipe.video_edit( 376 | prompt = "", 377 | prompt_embeds = prompt_embeds, 378 | negative_prompt_embeds=negative_prompt_embeds, 379 | guidance_scale=guidance_scale, 380 | num_inference_steps=num_inference_steps, 381 | generator=generator, 382 | width = width, 383 | height=height, 384 | video_length = video_length, 385 | latents=video_latents, 386 | **kwargs, 387 | ).videos 388 | 389 | return video -------------------------------------------------------------------------------- /faceadapter/init.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /faceadapter/resampler.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py 2 | # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py 3 | 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | from einops import rearrange 9 | from einops.layers.torch import Rearrange 10 | 11 | 12 | # FFN 13 | def FeedForward(dim, mult=4): 14 | inner_dim = int(dim * mult) 15 | return nn.Sequential( 16 | nn.LayerNorm(dim), 17 | nn.Linear(dim, inner_dim, bias=False), 18 | nn.GELU(), 19 | nn.Linear(inner_dim, dim, bias=False), 20 | ) 21 | 22 | 23 | def reshape_tensor(x, heads): 24 | bs, length, width = x.shape 25 | # (bs, length, width) --> (bs, length, n_heads, dim_per_head) 26 | x = x.view(bs, length, heads, -1) 27 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) 28 | x = x.transpose(1, 2) 29 | # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) 30 | x = x.reshape(bs, heads, length, -1) 31 | return x 32 | 33 | 34 | class PerceiverAttention(nn.Module): 35 | def __init__(self, *, dim, dim_head=64, heads=8): 36 | super().__init__() 37 | self.scale = dim_head**-0.5 38 | self.dim_head = dim_head 39 | self.heads = heads 40 | inner_dim = dim_head * heads 41 | 42 | self.norm1 = nn.LayerNorm(dim) 43 | self.norm2 = nn.LayerNorm(dim) 44 | 45 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 46 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 47 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 48 | 49 | def forward(self, x, latents): 50 | """ 51 | Args: 52 | x (torch.Tensor): image features 53 | shape (b, n1, D) 54 | latent (torch.Tensor): latent features 55 | shape (b, n2, D) 56 | """ 57 | x = self.norm1(x) 58 | latents = self.norm2(latents) 59 | 60 | b, l, _ = latents.shape 61 | 62 | q = self.to_q(latents) 63 | kv_input = torch.cat((x, latents), dim=-2) 64 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 65 | 66 | q = reshape_tensor(q, self.heads) 67 | k = reshape_tensor(k, self.heads) 68 | v = reshape_tensor(v, self.heads) 69 | 70 | # attention 71 | scale = 1 / math.sqrt(math.sqrt(self.dim_head)) 72 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards 73 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 74 | out = weight @ v 75 | 76 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1) 77 | 78 | return self.to_out(out) 79 | 80 | 81 | 82 | class ResamplerFaceID(nn.Module): 83 | def __init__( 84 | self, 85 | dim=1024, 86 | depth=8, 87 | dim_head=64, 88 | heads=16, 89 | num_queries=8, 90 | embedding_dim=768, 91 | output_dim=1024, 92 | ff_mult=4, 93 | max_seq_len: int = 257, # CLIP tokens + CLS token 94 | apply_pos_emb: bool = False, 95 | num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence 96 | ): 97 | super().__init__() 98 | self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None 99 | 100 | self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) 101 | 102 | self.proj_in = nn.Linear(embedding_dim, dim) 103 | 104 | self.proj_out = nn.Linear(dim, output_dim) 105 | self.norm_out = nn.LayerNorm(output_dim) 106 | 107 | self.to_latents_from_mean_pooled_seq = ( 108 | nn.Sequential( 109 | nn.LayerNorm(dim), 110 | nn.Linear(dim, dim * num_latents_mean_pooled), 111 | Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled), 112 | ) 113 | if num_latents_mean_pooled > 0 114 | else None 115 | ) 116 | 117 | self.layers = nn.ModuleList([]) 118 | for _ in range(depth): 119 | self.layers.append( 120 | nn.ModuleList( 121 | [ 122 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 123 | FeedForward(dim=dim, mult=ff_mult), 124 | ] 125 | ) 126 | ) 127 | 128 | def forward(self, x): 129 | if self.pos_emb is not None: 130 | n, device = x.shape[1], x.device 131 | pos_emb = self.pos_emb(torch.arange(n, device=device)) 132 | x = x + pos_emb 133 | #(x:2,1,512) 134 | #(latents:2,16,768) 135 | # x= x.unsqueeze(1) 136 | # x= x.repeat(x.size(0), latents.size(1), 1) 137 | latents = self.latents.repeat(x.size(0), 1, 1) 138 | 139 | x = self.proj_in(x) 140 | 141 | if self.to_latents_from_mean_pooled_seq: 142 | meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool)) 143 | meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq) 144 | latents = torch.cat((meanpooled_latents, latents), dim=-2) 145 | 146 | for attn, ff in self.layers: 147 | latents = attn(x, latents) + latents 148 | latents = ff(latents) + latents 149 | latents = self.proj_out(latents) 150 | return self.norm_out(latents) 151 | 152 | 153 | class Resampler(nn.Module): 154 | def __init__( 155 | self, 156 | dim=1024, 157 | depth=8, 158 | dim_head=64, 159 | heads=16, 160 | num_queries=8, 161 | embedding_dim=768, 162 | output_dim=1024, 163 | ff_mult=4, 164 | max_seq_len: int = 257, # CLIP tokens + CLS token 165 | apply_pos_emb: bool = False, 166 | num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence 167 | ): 168 | super().__init__() 169 | self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None 170 | 171 | self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) 172 | 173 | self.proj_in = nn.Linear(embedding_dim, dim) 174 | 175 | self.proj_out = nn.Linear(dim, output_dim) 176 | self.norm_out = nn.LayerNorm(output_dim) 177 | 178 | self.to_latents_from_mean_pooled_seq = ( 179 | nn.Sequential( 180 | nn.LayerNorm(dim), 181 | nn.Linear(dim, dim * num_latents_mean_pooled), 182 | Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled), 183 | ) 184 | if num_latents_mean_pooled > 0 185 | else None 186 | ) 187 | 188 | self.layers = nn.ModuleList([]) 189 | for _ in range(depth): 190 | self.layers.append( 191 | nn.ModuleList( 192 | [ 193 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 194 | FeedForward(dim=dim, mult=ff_mult), 195 | ] 196 | ) 197 | ) 198 | 199 | def forward(self, x): 200 | if self.pos_emb is not None: 201 | n, device = x.shape[1], x.device 202 | pos_emb = self.pos_emb(torch.arange(n, device=device)) 203 | x = x + pos_emb 204 | latents = self.latents.repeat(x.size(0), 1, 1) 205 | 206 | x = self.proj_in(x) 207 | 208 | if self.to_latents_from_mean_pooled_seq: 209 | meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool)) 210 | meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq) 211 | latents = torch.cat((meanpooled_latents, latents), dim=-2) 212 | 213 | for attn, ff in self.layers: 214 | latents = attn(x, latents) + latents 215 | latents = ff(latents) + latents 216 | latents = self.proj_out(latents) 217 | return self.norm_out(latents) 218 | 219 | 220 | def masked_mean(t, *, dim, mask=None): 221 | if mask is None: 222 | return t.mean(dim=dim) 223 | 224 | denom = mask.sum(dim=dim, keepdim=True) 225 | mask = rearrange(mask, "b n -> b n 1") 226 | masked_t = t.masked_fill(~mask, 0.0) 227 | 228 | return masked_t.sum(dim=dim) / denom.clamp(min=1e-5) 229 | -------------------------------------------------------------------------------- /faceadapter/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from PIL import Image 5 | 6 | attn_maps = {} 7 | def hook_fn(name): 8 | def forward_hook(module, input, output): 9 | if hasattr(module.processor, "attn_map"): 10 | attn_maps[name] = module.processor.attn_map 11 | del module.processor.attn_map 12 | 13 | return forward_hook 14 | 15 | def register_cross_attention_hook(unet): 16 | for name, module in unet.named_modules(): 17 | if name.split('.')[-1].startswith('attn2'): 18 | module.register_forward_hook(hook_fn(name)) 19 | 20 | return unet 21 | 22 | def upscale(attn_map, target_size): 23 | attn_map = torch.mean(attn_map, dim=0) 24 | attn_map = attn_map.permute(1,0) 25 | temp_size = None 26 | 27 | for i in range(0,5): 28 | scale = 2 ** i 29 | if ( target_size[0] // scale ) * ( target_size[1] // scale) == attn_map.shape[1]*64: 30 | temp_size = (target_size[0]//(scale*8), target_size[1]//(scale*8)) 31 | break 32 | 33 | assert temp_size is not None, "temp_size cannot is None" 34 | 35 | attn_map = attn_map.view(attn_map.shape[0], *temp_size) 36 | 37 | attn_map = F.interpolate( 38 | attn_map.unsqueeze(0).to(dtype=torch.float32), 39 | size=target_size, 40 | mode='bilinear', 41 | align_corners=False 42 | )[0] 43 | 44 | attn_map = torch.softmax(attn_map, dim=0) 45 | return attn_map 46 | def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True): 47 | 48 | idx = 0 if instance_or_negative else 1 49 | net_attn_maps = [] 50 | 51 | for name, attn_map in attn_maps.items(): 52 | attn_map = attn_map.cpu() if detach else attn_map 53 | attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze() 54 | attn_map = upscale(attn_map, image_size) 55 | net_attn_maps.append(attn_map) 56 | 57 | net_attn_maps = torch.mean(torch.stack(net_attn_maps,dim=0),dim=0) 58 | 59 | return net_attn_maps 60 | 61 | def attnmaps2images(net_attn_maps): 62 | 63 | #total_attn_scores = 0 64 | images = [] 65 | 66 | for attn_map in net_attn_maps: 67 | attn_map = attn_map.cpu().numpy() 68 | #total_attn_scores += attn_map.mean().item() 69 | 70 | normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255 71 | normalized_attn_map = normalized_attn_map.astype(np.uint8) 72 | #print("norm: ", normalized_attn_map.shape) 73 | image = Image.fromarray(normalized_attn_map) 74 | 75 | #image = fix_save_attn_map(attn_map) 76 | images.append(image) 77 | 78 | #print(total_attn_scores) 79 | return images 80 | def is_torch2_available(): 81 | return hasattr(F, "scaled_dot_product_attention") 82 | 83 | def get_generator(seed, device): 84 | 85 | if seed is not None: 86 | if isinstance(seed, list): 87 | generator = [torch.Generator(device).manual_seed(seed_item) for seed_item in seed] 88 | else: 89 | generator = torch.Generator(device).manual_seed(seed) 90 | else: 91 | generator = None 92 | 93 | return generator -------------------------------------------------------------------------------- /if miss module check this requirements.txt: -------------------------------------------------------------------------------- 1 | imageio==2.34.0 2 | imageio-ffmpeg==0.4.9 3 | accelerate==0.27.2 4 | addict==2.4.0 5 | einops==0.6.1 6 | einops-exts==0.0.4 7 | ffmpeg-python==0.2.0 8 | imageio-ffmpeg==0.4.9 9 | transformers==4.31.0 10 | torch==2.2.0 11 | torchvision==0.17.0 12 | xformers==0.0.24 13 | wandb==0.15.12 14 | yapf==0.40.2 15 | insightface==0.7.3 16 | omegaconf==2.3.0 17 | onnx==1.15.0 18 | onnxruntime 19 | opencv-python==4.9.0.80 20 | diffusers 21 | -------------------------------------------------------------------------------- /inference-v2.yaml: -------------------------------------------------------------------------------- 1 | unet_additional_kwargs: 2 | use_inflated_groupnorm: true 3 | unet_use_cross_frame_attention: false 4 | unet_use_temporal_attention: false 5 | use_motion_module: true 6 | motion_module_resolutions: [1,2,4,8] 7 | motion_module_mid_block: true 8 | motion_module_decoder_only: false 9 | motion_module_type: "Vanilla" 10 | 11 | motion_module_kwargs: 12 | num_attention_heads: 8 13 | num_transformer_block: 1 14 | attention_block_types: [ "Temporal_Self", "Temporal_Self" ] 15 | temporal_position_encoding: true 16 | temporal_position_encoding_max_len: 32 17 | temporal_attention_dim_div: 1 18 | 19 | noise_scheduler_kwargs: 20 | beta_start: 0.00085 21 | beta_end: 0.012 22 | beta_schedule: "linear" 23 | steps_offset: 1 24 | clip_sample: False 25 | -------------------------------------------------------------------------------- /models/adapter/put adapter file here: -------------------------------------------------------------------------------- 1 | put adapter file here 2 | -------------------------------------------------------------------------------- /models/animatediff_models/put animatediff_models here: -------------------------------------------------------------------------------- 1 | put animatediff_models here 2 | -------------------------------------------------------------------------------- /models/image_encoder/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "./image_encoder", 3 | "architectures": [ 4 | "CLIPVisionModelWithProjection" 5 | ], 6 | "attention_dropout": 0.0, 7 | "dropout": 0.0, 8 | "hidden_act": "gelu", 9 | "hidden_size": 1280, 10 | "image_size": 224, 11 | "initializer_factor": 1.0, 12 | "initializer_range": 0.02, 13 | "intermediate_size": 5120, 14 | "layer_norm_eps": 1e-05, 15 | "model_type": "clip_vision_model", 16 | "num_attention_heads": 16, 17 | "num_channels": 3, 18 | "num_hidden_layers": 32, 19 | "patch_size": 14, 20 | "projection_dim": 1024, 21 | "torch_dtype": "float16", 22 | "transformers_version": "4.28.0.dev0" 23 | } 24 | -------------------------------------------------------------------------------- /models/image_encoder/put image_encoder model here: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /models/text_encoder/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "openai/clip-vit-large-patch14", 3 | "architectures": [ 4 | "CLIPTextModel" 5 | ], 6 | "attention_dropout": 0.0, 7 | "bos_token_id": 0, 8 | "dropout": 0.0, 9 | "eos_token_id": 2, 10 | "hidden_act": "quick_gelu", 11 | "hidden_size": 768, 12 | "initializer_factor": 1.0, 13 | "initializer_range": 0.02, 14 | "intermediate_size": 3072, 15 | "layer_norm_eps": 1e-05, 16 | "max_position_embeddings": 77, 17 | "model_type": "clip_text_model", 18 | "num_attention_heads": 12, 19 | "num_hidden_layers": 12, 20 | "pad_token_id": 1, 21 | "projection_dim": 768, 22 | "torch_dtype": "float32", 23 | "transformers_version": "4.22.0.dev0", 24 | "vocab_size": 49408 25 | } 26 | -------------------------------------------------------------------------------- /models/text_encoder/put text_encoder model here: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "comfyui_id_animator" 3 | description = "you can using id_animator in comfyUI" 4 | version = "1.0.0" 5 | license = { file = "LICENSE" } 6 | 7 | [project.urls] 8 | Repository = "https://github.com/smthemex/ComfyUI_ID_Animator" 9 | # Used by Comfy Registry https://comfyregistry.org 10 | 11 | [tool.comfy] 12 | PublisherId = "smthemx" 13 | DisplayName = "ComfyUI_ID_Animator" 14 | Icon = "" 15 | --------------------------------------------------------------------------------