├── README.md ├── __init__.py ├── assets ├── Aragaki.png ├── audio2video_workflow.json ├── face_reenacment_workflow.json ├── lyl.wav ├── pose2video_workflow.json ├── pose_ref_video.mp4 ├── solo.png └── woman.jpg ├── configs ├── inference │ ├── inference_audio.yaml │ └── inference_v2.yaml └── prompts │ ├── animation.yaml │ ├── animation_audio.yaml │ └── animation_facereenac.yaml ├── nodes.py ├── pyproject.toml ├── requirements.txt └── src ├── __init__.py ├── __pycache__ └── __init__.cpython-310.pyc ├── audio_models ├── __pycache__ │ ├── model.cpython-310.pyc │ ├── torch_utils.cpython-310.pyc │ └── wav2vec2.cpython-310.pyc ├── model.py ├── pose_model.py ├── torch_utils.py └── wav2vec2.py ├── models ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── attention.cpython-310.pyc │ ├── motion_module.cpython-310.pyc │ ├── mutual_self_attention.cpython-310.pyc │ ├── pose_guider.cpython-310.pyc │ ├── resnet.cpython-310.pyc │ ├── transformer_2d.cpython-310.pyc │ ├── transformer_3d.cpython-310.pyc │ ├── unet_2d_blocks.cpython-310.pyc │ ├── unet_2d_condition.cpython-310.pyc │ ├── unet_3d.cpython-310.pyc │ └── unet_3d_blocks.cpython-310.pyc ├── attention.py ├── motion_module.py ├── mutual_self_attention.py ├── pose_guider.py ├── resnet.py ├── transformer_2d.py ├── transformer_3d.py ├── unet_2d_blocks.py ├── unet_2d_condition.py ├── unet_3d.py └── unet_3d_blocks.py ├── pipelines ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── context.cpython-310.pyc │ ├── pipeline_pose2vid_long.cpython-310.pyc │ └── utils.cpython-310.pyc ├── context.py ├── pipeline_pose2vid.py ├── pipeline_pose2vid_long.py └── utils.py └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-310.pyc ├── audio_util.cpython-310.pyc ├── draw_util.cpython-310.pyc ├── face_landmark.cpython-310.pyc ├── logger.cpython-310.pyc ├── mp_utils.cpython-310.pyc ├── pose_util.cpython-310.pyc └── util.cpython-310.pyc ├── audio_util.py ├── draw_util.py ├── face_landmark.py ├── frame_interpolation.py ├── logger.py ├── mp_models ├── blaze_face_short_range.tflite ├── face_landmarker_v2_with_blendshapes.task └── pose_landmarker_heavy.task ├── mp_utils.py ├── pose_util.py └── util.py
/README.md: --------------------------------------------------------------------------------
1 | #### Updates:
2 | ① Implemented frame_interpolation to speed up generation
3 |
4 | ② Modified the code to support chaining with the [VHS nodes](https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite). The ComfyUI IMAGE type expects torch float32 tensors, while AniPortrait works heavily with uint8 numpy images, so I replaced my own image/video upload and generation nodes with the widely used VHS image/video upload and Video Combine nodes; they are WYSIWYG, interact well with the rest of the graph, and render the result instantly (see the conversion sketch below the update list).
5 | - ✅ [2024/04/09] raw video to pose video with reference image (aka self-driven)
6 | - ✅ [2024/04/09] audio driven
7 | - ✅ [2024/04/09] face reenactment
8 | - ✅ [2024/04/22] Implemented the audio2pose model and [pre-trained weight](https://huggingface.co/ZJYang/AniPortrait/tree/main) for audio2video; the face reenactment and audio2video workflows have been updated. Inference up to a maximum length of 10 seconds is now supported; you can experiment with the length hyperparameter.
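For reference, a minimal sketch of the datatype conversion described in update ②: ComfyUI passes IMAGE data as float32 torch tensors in [0, 1] with shape [batch, height, width, channels], while AniPortrait's OpenCV/MediaPipe helpers work on uint8 arrays in [0, 255]. The helper names below are illustrative only, not part of this node pack:

```python
import numpy as np
import torch

def comfy_image_to_uint8(image: torch.Tensor) -> np.ndarray:
    # ComfyUI IMAGE: float32 in [0, 1], shape [B, H, W, C] -> uint8 frames for OpenCV/MediaPipe
    return (image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)

def uint8_to_comfy_image(frames: np.ndarray) -> torch.Tensor:
    # uint8 frames in [0, 255] -> float32 tensor in [0, 1] that VHS/ComfyUI nodes accept
    return torch.from_numpy(frames.astype(np.float32) / 255.0)
```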
9 |
10 | You can contact me via ![twitter_1](https://github.com/frankchieng/ComfyUI_Aniportrait/assets/130369523/27b4fcae-e50c-477d-86f4-dacf7fd052f4)[twitter](https://twitter.com/kurtqian) or ![wechat_1](https://github.com/frankchieng/ComfyUI_Aniportrait/assets/130369523/b95cd0a2-4188-4eb3-b1de-5f6eeab71045) Weixin: GalaticKing
11 |
12 |
13 | ### audio driven combined with reference image and reference video
14 | ![Screenshot 2024-08-30 12-04-53](https://github.com/user-attachments/assets/10b73c50-a046-41d5-abd1-5ea40a23ad3a)
15 | [audio2video workflow](https://github.com/frankchieng/ComfyUI_Aniportrait/blob/main/assets/audio2video_workflow.json)
16 | 17 | 18 | 21 | 22 |
19 | 20 |
23 | 24 | ### raw video to pose video with reference image 25 | ![pose2video](https://github.com/frankchieng/ComfyUI_Aniportrait/assets/130369523/882e3685-ee13-4798-9f90-d195d6595a97) 26 | 27 | 28 | 31 | 32 |
29 | 30 |
33 |
34 | ### face reenactment
35 | ![face_reenacment](https://github.com/frankchieng/ComfyUI_Aniportrait/assets/130369523/82f2ae7c-b7c2-49a7-8f13-4456ebff55e6)
36 | [face reenactment workflow](https://github.com/frankchieng/ComfyUI_Aniportrait/blob/main/assets/face_reenacment_workflow.json)
37 | 38 | 39 | 42 | 43 |
40 | 41 |
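Before the installation notes below, here is a condensed sketch of the pose-extraction step that the self-driven and face reenactment workflows above share. It is distilled from the VideoGenPose node in __init__.py further down; import paths assume you run it from the repo root, and it assumes a face is actually detected in the frame:

```python
import cv2
import numpy as np
from src.utils.mp_utils import LMKExtractor
from src.utils.draw_util import FaceMeshVisualizer

lmk_extractor = LMKExtractor()
vis = FaceMeshVisualizer(forehead_edge=False)

def frame_to_pose(frame_rgb: np.ndarray) -> np.ndarray:
    # the landmark extractor is fed BGR frames, as in the node implementation
    frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
    face_result = lmk_extractor(frame_bgr)
    lmks = face_result['lmks'].astype(np.float32)
    # draw_landmarks takes (width, height) plus normalized landmarks and returns a BGR pose image
    pose_bgr = vis.draw_landmarks((frame_bgr.shape[1], frame_bgr.shape[0]), lmks, normed=True)
    return cv2.cvtColor(pose_bgr, cv2.COLOR_BGR2RGB)
```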
44 |
45 | This is an unofficial implementation of AniPortrait as a ComfyUI custom node. Since I have a regular day job, I will update this project when I have time.
46 | > [Aniportrait_pose2video.json](https://github.com/frankchieng/ComfyUI_Aniportrait/blob/main/assets/pose2video_workflow.json)
47 |
48 | > [Audio driven](https://github.com/frankchieng/ComfyUI_Aniportrait/blob/main/assets/audio2video_workflow.json)
49 |
50 | > [face reenactment](https://github.com/frankchieng/ComfyUI_Aniportrait/blob/main/assets/face_reenacment_workflow.json)
51 |
52 | you should run
53 | ```shell
54 | git clone https://github.com/frankchieng/ComfyUI_Aniportrait.git
55 | ```
56 | then run
57 | ```shell
58 | pip install -r requirements.txt
59 | ```
60 | Download the pretrained models:
61 | > [StableDiffusion V1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5)
62 |
63 | > [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse)
64 |
65 | > [image_encoder](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/tree/main/image_encoder)
66 |
67 | > [wav2vec2-base-960h](https://huggingface.co/facebook/wav2vec2-base-960h)
68 |
69 | Download the weights:
70 | > [denoising_unet.pth](https://huggingface.co/ZJYang/AniPortrait/tree/main)
71 | > [reference_unet.pth](https://huggingface.co/ZJYang/AniPortrait/tree/main)
72 | > [pose_guider.pth](https://huggingface.co/ZJYang/AniPortrait/tree/main)
73 | > [motion_module.pth](https://huggingface.co/ZJYang/AniPortrait/tree/main)
74 | > [audio2mesh.pt](https://huggingface.co/ZJYang/AniPortrait/tree/main)
75 | > [audio2pose.pt](https://huggingface.co/ZJYang/AniPortrait/tree/main)
76 | > [film_net_fp16.pt](https://huggingface.co/ZJYang/AniPortrait/tree/main)
77 | ```text
78 | ./pretrained_model/
79 | |-- image_encoder
80 | |   |-- config.json
81 | |   `-- pytorch_model.bin
82 | |-- sd-vae-ft-mse
83 | |   |-- config.json
84 | |   |-- diffusion_pytorch_model.bin
85 | |   `-- diffusion_pytorch_model.safetensors
86 | |-- stable-diffusion-v1-5
87 | |   |-- feature_extractor
88 | |   |   `-- preprocessor_config.json
89 | |   |-- model_index.json
90 | |   |-- unet
91 | |   |   |-- config.json
92 | |   |   `-- diffusion_pytorch_model.bin
93 | |   `-- v1-inference.yaml
94 | |-- wav2vec2-base-960h
95 | |   |-- config.json
96 | |   |-- feature_extractor_config.json
97 | |   |-- preprocessor_config.json
98 | |   |-- pytorch_model.bin
99 | |   |-- README.md
100 | |   |-- special_tokens_map.json
101 | |   |-- tokenizer_config.json
102 | |   `-- vocab.json
103 | |-- audio2mesh.pt
104 | |-- audio2pose.pt
105 | |-- denoising_unet.pth
106 | |-- motion_module.pth
107 | |-- pose_guider.pth
108 | |-- reference_unet.pth
109 | |-- film_net_fp16.pt
110 | ```
111 |
112 | Tips:
113 | The intermediate audio file will be generated and then deleted; the raw-video-to-pose video (with audio) and the pose2video mp4 files will be located in the ComfyUI output directory.
114 | The uploaded mp4 video should be square (e.g. 512x512), otherwise the result will look distorted.
115 | #### I've updated diffusers from 0.24.x to 0.26.2. In diffusers/models/embeddings.py the class PositionNet was renamed to GLIGENTextBoundingboxProjection and CaptionProjection was renamed to PixArtAlphaTextProjection; if you have an older version of diffusers installed, pay attention to this and modify the corresponding files such as src/models/transformer_2d.py accordingly.
116 |
-------------------------------------------------------------------------------- /__init__.py: --------------------------------------------------------------------------------
1 | import folder_paths
2 | import os
3 | import ffmpeg
4 | from PIL import Image
5 | import cv2
6 | from tqdm import tqdm
7 | import re
8 | import torch
9 | from .nodes import PoseGenVideo, RefImagePath, Audio2Video, AudioPath #,GenerateRefPose
10 |
11 | from .src.utils.util import get_fps, read_frames, save_videos_from_pil, calculate_file_hash, get_sorted_dir_files_from_directory, get_audio, lazy_eval, hash_path, validate_path
12 | import numpy as np
13 | from .src.utils.draw_util import FaceMeshVisualizer
14 | from .src.utils.mp_utils import LMKExtractor
15 |
16 | video_extensions = ['webm', 'mp4', 'mkv', 'gif']
17 |
18 | class VideoGenPose:
19 |     @classmethod
20 |     def INPUT_TYPES(s):
21 |         return {
22 |             "required": {
23 |                 "image": ("IMAGE",),
24 |                 "filename_prefix": ("STRING", {"default": "AniPortrait"}),
25 |                 "height": ("INT", {"default": 512, "min": 0, "max": 1024, "step": 1}),
26 |                 "width": ("INT", {"default": 512, "min": 0, "max": 1024, "step": 1}),
27 |             },
28 |         }
29 |
30 |     RETURN_TYPES = ("IMAGE",)
31 |     RETURN_NAMES = ("pose_images",)
32 |     OUTPUT_NODE = True
33 |     CATEGORY = "AniPortrait 🎥Video"
34 |     FUNCTION = "generate_pose_video"
35 |
36 |     def generate_pose_video(self, image, filename_prefix, height, width):
37 |
38 |         frames = (image.numpy() * 255).astype(np.uint8)  # ComfyUI IMAGE is float32 in [0, 1]; convert to uint8 for OpenCV/MediaPipe
39 |         lmk_extractor = LMKExtractor()
40 |         vis = FaceMeshVisualizer(forehead_edge=False)
41 |
42 |         kps_results = []
43 |         for i, frame_pil in enumerate(tqdm(frames)):
44 |             image_np = cv2.cvtColor(np.array(frame_pil), cv2.COLOR_RGB2BGR)
45 |             image_np = cv2.resize(image_np, (width, height))  # cv2.resize expects dsize as (width, height)
46 |             face_result = lmk_extractor(image_np)
47 |             try:
48 |                 lmks = face_result['lmks'].astype(np.float32)
49 |                 pose_img = vis.draw_landmarks((image_np.shape[1], image_np.shape[0]), lmks, normed=True)
50 |                 pose_img = Image.fromarray(cv2.cvtColor(pose_img, cv2.COLOR_BGR2RGB))
51 |             except Exception:
52 |                 pose_img = kps_results[-1]  # no face detected: reuse the previous frame's pose image
53 |
54 |             kps_results.append(pose_img)
55 |
56 |         iterable = (x for x in kps_results)
57 |         images = torch.from_numpy(np.fromiter(iterable, np.dtype((np.float32, (height, width, 3)))))
58 |         return (images,)
59 |
60 |
61 | class LoadVideoPath:
62 |     @classmethod
63 |     def INPUT_TYPES(s):
64 |         return {
65 |             "required": {
66 |                 "video": ("STRING", {"default": "X://insert/path/here.mp4", "aniportrait_path_extensions": video_extensions}),
67 |             },
68 |         }
69 |
70 |     CATEGORY = "AniPortrait 🎥Video"
71 |
72 |     RETURN_TYPES = ("AniPortrait_Video", "IMAGE", "Frame_per_second", "AniPortrait_Audio", )
73 |     RETURN_NAMES = ("video", "frames", "frame_per_second", "audio",)
74 |     FUNCTION = "load_video"
75 |
76 |     def load_video(self, **kwargs):
77 |         if kwargs['video'] is None or validate_path(kwargs['video']) != True:
78 |             raise Exception("video is not a valid path: " + kwargs['video'])
79 |         return load_video_av(**kwargs)
80 |
81 |     @classmethod
82 |     def IS_CHANGED(s, video, **kwargs):
83 |         return hash_path(video)
84 |
85 |     @classmethod
86 |     def VALIDATE_INPUTS(s, video, **kwargs):
87 |         return validate_path(video, allow_none=True)
88 |
89 |
90 | def load_video_av(video: str):
91 |     fps = get_fps(video)
92 |     frames = read_frames(video)
93 |     input_dir = folder_paths.get_output_directory()
94 |     audio_output = os.path.join(input_dir, 'audio_from_video.aac')  # intermediate audio file location (see the README tips)
95 |
96 |     return (video, frames, fps, audio_output)
97 |
98 | NODE_CLASS_MAPPINGS = {
99 |     "AniPortrait_Video_Gen_Pose": VideoGenPose,
100 |     "AniPortrait_LoadVideoPath": LoadVideoPath,
101 |     "AniPortrait_Pose_Gen_Video": PoseGenVideo,
102 |     "AniPortrait_Ref_Image_Path": RefImagePath,
103 |     # "AniPortrait_Generate_Ref_Pose": 
GenerateRefPose, 104 | "AniPortrait_Audio2Video": Audio2Video, 105 | "AniPortrait_Audio_Path": AudioPath, 106 | } 107 | 108 | NODE_DISPLAY_NAME_MAPPINGS = { 109 | "AniPortrait_Video_Gen_Pose": "Video MediaPipe Face Detection🎥AniPortrait", 110 | "AniPortrait_LoadVideoPath": "Load Video (Path) 🎥AniPortrait", 111 | "AniPortrait_Pose_Gen_Video": "Pose Generate Video 🎥AniPortrait", 112 | "AniPortrait_Ref_Image_Path": "Ref Image Path 🎥AniPortrait", 113 | # "AniPortrait_Generate_Ref_Pose": "Generate Ref Pose 🎥AniPortrait", 114 | "AniPortrait_Audio2Video": "Audio Gen Video 🎥AniPortrait", 115 | "AniPortrait_Audio_Path": "Audio Path 🎥AniPortrait", 116 | } 117 | -------------------------------------------------------------------------------- /assets/Aragaki.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/assets/Aragaki.png -------------------------------------------------------------------------------- /assets/audio2video_workflow.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 17, 3 | "last_link_id": 31, 4 | "nodes": [ 5 | { 6 | "id": 10, 7 | "type": "AniPortrait_Audio2Video", 8 | "pos": [ 9 | -1712, 10 | -63 11 | ], 12 | "size": { 13 | "0": 315, 14 | "1": 506 15 | }, 16 | "flags": {}, 17 | "order": 3, 18 | "mode": 0, 19 | "inputs": [ 20 | { 21 | "name": "ref_image", 22 | "type": "IMAGE", 23 | "link": 12, 24 | "label": "ref_image" 25 | }, 26 | { 27 | "name": "images", 28 | "type": "IMAGE", 29 | "link": 14, 30 | "label": "images" 31 | }, 32 | { 33 | "name": "audio_path", 34 | "type": "Audio_Path", 35 | "link": 15, 36 | "label": "audio_path" 37 | }, 38 | { 39 | "name": "fps", 40 | "type": "INT", 41 | "link": null, 42 | "widget": { 43 | "name": "fps" 44 | }, 45 | "label": "fps" 46 | } 47 | ], 48 | "outputs": [ 49 | { 50 | "name": "images", 51 | "type": "IMAGE", 52 | "links": [ 53 | 31 54 | ], 55 | "shape": 3, 56 | "label": "images" 57 | } 58 | ], 59 | "properties": { 60 | "Node name for S&R": "AniPortrait_Audio2Video" 61 | }, 62 | "widgets_values": [ 63 | 512, 64 | 512, 65 | 1748, 66 | "randomize", 67 | 3.5, 68 | 25, 69 | "pretrained_model/sd-vae-ft-mse", 70 | "pretrained_model/stable-diffusion-v1-5", 71 | "fp16", 72 | true, 73 | 0, 74 | 3, 75 | "pretrained_model/motion_module.pth", 76 | "pretrained_model/image_encoder", 77 | "pretrained_model/denoising_unet.pth", 78 | "pretrained_model/reference_unet.pth", 79 | "pretrained_model/pose_guider.pth", 80 | 0 81 | ] 82 | }, 83 | { 84 | "id": 1, 85 | "type": "VHS_LoadVideo", 86 | "pos": [ 87 | -2366, 88 | -60 89 | ], 90 | "size": [ 91 | 251.52520751953125, 92 | 507.52520751953125 93 | ], 94 | "flags": {}, 95 | "order": 0, 96 | "mode": 0, 97 | "inputs": [ 98 | { 99 | "name": "meta_batch", 100 | "type": "VHS_BatchManager", 101 | "link": null, 102 | "label": "meta_batch" 103 | }, 104 | { 105 | "name": "vae", 106 | "type": "VAE", 107 | "link": null, 108 | "label": "vae" 109 | } 110 | ], 111 | "outputs": [ 112 | { 113 | "name": "IMAGE", 114 | "type": "IMAGE", 115 | "links": [ 116 | 14 117 | ], 118 | "slot_index": 0, 119 | "shape": 3, 120 | "label": "IMAGE" 121 | }, 122 | { 123 | "name": "frame_count", 124 | "type": "INT", 125 | "links": null, 126 | "shape": 3, 127 | "label": "frame_count" 128 | }, 129 | { 130 | "name": "audio", 131 | "type": "VHS_AUDIO", 132 | "links": null, 133 | "shape": 3, 134 | "label": "audio" 135 | }, 136 | { 137 | "name": 
"video_info", 138 | "type": "VHS_VIDEOINFO", 139 | "links": [], 140 | "slot_index": 3, 141 | "shape": 3, 142 | "label": "video_info" 143 | } 144 | ], 145 | "properties": { 146 | "Node name for S&R": "VHS_LoadVideo" 147 | }, 148 | "widgets_values": { 149 | "video": "pose_ref_video.mp4", 150 | "force_rate": 0, 151 | "force_size": "Disabled", 152 | "custom_width": 512, 153 | "custom_height": 512, 154 | "frame_load_cap": 0, 155 | "skip_first_frames": 0, 156 | "select_every_nth": 1, 157 | "choose video to upload": "image", 158 | "videopreview": { 159 | "hidden": false, 160 | "paused": false, 161 | "params": { 162 | "frame_load_cap": 0, 163 | "skip_first_frames": 0, 164 | "force_rate": 0, 165 | "filename": "pose_ref_video.mp4", 166 | "type": "input", 167 | "format": "video/mp4", 168 | "select_every_nth": 1 169 | } 170 | } 171 | } 172 | }, 173 | { 174 | "id": 7, 175 | "type": "LoadImage", 176 | "pos": [ 177 | -2082, 178 | 122 179 | ], 180 | "size": { 181 | "0": 315, 182 | "1": 314 183 | }, 184 | "flags": {}, 185 | "order": 1, 186 | "mode": 0, 187 | "outputs": [ 188 | { 189 | "name": "IMAGE", 190 | "type": "IMAGE", 191 | "links": [ 192 | 12 193 | ], 194 | "slot_index": 0, 195 | "shape": 3, 196 | "label": "IMAGE" 197 | }, 198 | { 199 | "name": "MASK", 200 | "type": "MASK", 201 | "links": null, 202 | "shape": 3, 203 | "label": "MASK" 204 | } 205 | ], 206 | "properties": { 207 | "Node name for S&R": "LoadImage" 208 | }, 209 | "widgets_values": [ 210 | "man.jpg", 211 | "image" 212 | ] 213 | }, 214 | { 215 | "id": 8, 216 | "type": "AniPortrait_Audio_Path", 217 | "pos": [ 218 | -2114, 219 | -238 220 | ], 221 | "size": { 222 | "0": 315, 223 | "1": 102 224 | }, 225 | "flags": {}, 226 | "order": 2, 227 | "mode": 0, 228 | "outputs": [ 229 | { 230 | "name": "audio_path", 231 | "type": "Audio_Path", 232 | "links": [ 233 | 15 234 | ], 235 | "slot_index": 0, 236 | "shape": 3, 237 | "label": "audio_path" 238 | }, 239 | { 240 | "name": "audio", 241 | "type": "VHS_AUDIO", 242 | "links": [ 243 | 30 244 | ], 245 | "shape": 3, 246 | "label": "audio", 247 | "slot_index": 1 248 | } 249 | ], 250 | "properties": { 251 | "Node name for S&R": "AniPortrait_Audio_Path" 252 | }, 253 | "widgets_values": [ 254 | "/home/qm/test.wav", 255 | 0 256 | ] 257 | }, 258 | { 259 | "id": 17, 260 | "type": "VHS_VideoCombine", 261 | "pos": [ 262 | -1297, 263 | -16 264 | ], 265 | "size": [ 266 | 218.82891845703125, 267 | 523.3297469669985 268 | ], 269 | "flags": {}, 270 | "order": 5, 271 | "mode": 0, 272 | "inputs": [ 273 | { 274 | "name": "images", 275 | "type": "IMAGE", 276 | "link": 31, 277 | "label": "images" 278 | }, 279 | { 280 | "name": "audio", 281 | "type": "AUDIO", 282 | "link": 29, 283 | "label": "audio" 284 | }, 285 | { 286 | "name": "meta_batch", 287 | "type": "VHS_BatchManager", 288 | "link": null, 289 | "label": "meta_batch" 290 | }, 291 | { 292 | "name": "vae", 293 | "type": "VAE", 294 | "link": null, 295 | "label": "vae" 296 | } 297 | ], 298 | "outputs": [ 299 | { 300 | "name": "Filenames", 301 | "type": "VHS_FILENAMES", 302 | "links": null, 303 | "shape": 3, 304 | "label": "Filenames" 305 | } 306 | ], 307 | "properties": { 308 | "Node name for S&R": "VHS_VideoCombine" 309 | }, 310 | "widgets_values": { 311 | "frame_rate": 30, 312 | "loop_count": 0, 313 | "filename_prefix": "Aniportrait", 314 | "format": "video/h264-mp4", 315 | "pix_fmt": "yuv420p", 316 | "crf": 19, 317 | "save_metadata": true, 318 | "pingpong": false, 319 | "save_output": true, 320 | "videopreview": { 321 | "hidden": false, 322 | "paused": false, 323 | 
"params": { 324 | "filename": "Aniportrait_00002-audio.mp4", 325 | "subfolder": "", 326 | "type": "output", 327 | "format": "video/h264-mp4", 328 | "frame_rate": 30 329 | }, 330 | "muted": false 331 | } 332 | } 333 | }, 334 | { 335 | "id": 13, 336 | "type": "VHS_VHSAudioToAudio", 337 | "pos": [ 338 | -1682, 339 | -200 340 | ], 341 | "size": { 342 | "0": 304.67462158203125, 343 | "1": 26 344 | }, 345 | "flags": {}, 346 | "order": 4, 347 | "mode": 0, 348 | "inputs": [ 349 | { 350 | "name": "vhs_audio", 351 | "type": "VHS_AUDIO", 352 | "link": 30, 353 | "label": "vhs_audio" 354 | } 355 | ], 356 | "outputs": [ 357 | { 358 | "name": "audio", 359 | "type": "AUDIO", 360 | "links": [ 361 | 29 362 | ], 363 | "shape": 3, 364 | "label": "audio", 365 | "slot_index": 0 366 | } 367 | ], 368 | "properties": { 369 | "Node name for S&R": "VHS_VHSAudioToAudio" 370 | }, 371 | "widgets_values": {} 372 | } 373 | ], 374 | "links": [ 375 | [ 376 | 12, 377 | 7, 378 | 0, 379 | 10, 380 | 0, 381 | "IMAGE" 382 | ], 383 | [ 384 | 14, 385 | 1, 386 | 0, 387 | 10, 388 | 1, 389 | "IMAGE" 390 | ], 391 | [ 392 | 15, 393 | 8, 394 | 0, 395 | 10, 396 | 2, 397 | "Audio_Path" 398 | ], 399 | [ 400 | 29, 401 | 13, 402 | 0, 403 | 17, 404 | 1, 405 | "AUDIO" 406 | ], 407 | [ 408 | 30, 409 | 8, 410 | 1, 411 | 13, 412 | 0, 413 | "VHS_AUDIO" 414 | ], 415 | [ 416 | 31, 417 | 10, 418 | 0, 419 | 17, 420 | 0, 421 | "IMAGE" 422 | ] 423 | ], 424 | "groups": [], 425 | "config": {}, 426 | "extra": { 427 | "ds": { 428 | "scale": 1.061076460950001, 429 | "offset": [ 430 | 2608.218390205725, 431 | 532.0839393915426 432 | ] 433 | } 434 | }, 435 | "version": 0.4 436 | } 437 | -------------------------------------------------------------------------------- /assets/face_reenacment_workflow.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 12, 3 | "last_link_id": 16, 4 | "nodes": [ 5 | { 6 | "id": 7, 7 | "type": "VHS_LoadVideo", 8 | "pos": [ 9 | -2279, 10 | 347 11 | ], 12 | "size": [ 13 | 235.1999969482422, 14 | 491.1999969482422 15 | ], 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "inputs": [ 20 | { 21 | "name": "meta_batch", 22 | "type": "VHS_BatchManager", 23 | "link": null, 24 | "label": "meta_batch" 25 | } 26 | ], 27 | "outputs": [ 28 | { 29 | "name": "IMAGE", 30 | "type": "IMAGE", 31 | "links": [ 32 | 15 33 | ], 34 | "shape": 3, 35 | "label": "IMAGE", 36 | "slot_index": 0 37 | }, 38 | { 39 | "name": "frame_count", 40 | "type": "INT", 41 | "links": null, 42 | "shape": 3, 43 | "label": "frame_count" 44 | }, 45 | { 46 | "name": "audio", 47 | "type": "VHS_AUDIO", 48 | "links": [ 49 | 8 50 | ], 51 | "shape": 3, 52 | "label": "audio", 53 | "slot_index": 2 54 | }, 55 | { 56 | "name": "video_info", 57 | "type": "VHS_VIDEOINFO", 58 | "links": [ 59 | 9 60 | ], 61 | "shape": 3, 62 | "label": "video_info", 63 | "slot_index": 3 64 | } 65 | ], 66 | "properties": { 67 | "Node name for S&R": "VHS_LoadVideo" 68 | }, 69 | "widgets_values": { 70 | "video": "pose_ref_video.mp4", 71 | "force_rate": 0, 72 | "force_size": "Disabled", 73 | "custom_width": 512, 74 | "custom_height": 512, 75 | "frame_load_cap": 0, 76 | "skip_first_frames": 0, 77 | "select_every_nth": 1, 78 | "choose video to upload": "image", 79 | "videopreview": { 80 | "hidden": false, 81 | "paused": false, 82 | "params": { 83 | "frame_load_cap": 0, 84 | "skip_first_frames": 0, 85 | "force_rate": 0, 86 | "filename": "pose_ref_video.mp4", 87 | "type": "input", 88 | "format": "video/mp4", 89 | "select_every_nth": 1 90 | } 91 | } 92 | 
} 93 | }, 94 | { 95 | "id": 10, 96 | "type": "VHS_VideoInfo", 97 | "pos": [ 98 | -1987, 99 | 728 100 | ], 101 | "size": { 102 | "0": 393, 103 | "1": 206 104 | }, 105 | "flags": {}, 106 | "order": 2, 107 | "mode": 0, 108 | "inputs": [ 109 | { 110 | "name": "video_info", 111 | "type": "VHS_VIDEOINFO", 112 | "link": 9, 113 | "label": "video_info" 114 | } 115 | ], 116 | "outputs": [ 117 | { 118 | "name": "source_fps🟨", 119 | "type": "FLOAT", 120 | "links": [ 121 | 10 122 | ], 123 | "shape": 3, 124 | "label": "source_fps🟨", 125 | "slot_index": 0 126 | }, 127 | { 128 | "name": "source_frame_count🟨", 129 | "type": "INT", 130 | "links": null, 131 | "shape": 3, 132 | "label": "source_frame_count🟨" 133 | }, 134 | { 135 | "name": "source_duration🟨", 136 | "type": "FLOAT", 137 | "links": null, 138 | "shape": 3, 139 | "label": "source_duration🟨" 140 | }, 141 | { 142 | "name": "source_width🟨", 143 | "type": "INT", 144 | "links": null, 145 | "shape": 3, 146 | "label": "source_width🟨" 147 | }, 148 | { 149 | "name": "source_height🟨", 150 | "type": "INT", 151 | "links": null, 152 | "shape": 3, 153 | "label": "source_height🟨" 154 | }, 155 | { 156 | "name": "loaded_fps🟦", 157 | "type": "FLOAT", 158 | "links": null, 159 | "shape": 3, 160 | "label": "loaded_fps🟦" 161 | }, 162 | { 163 | "name": "loaded_frame_count🟦", 164 | "type": "INT", 165 | "links": null, 166 | "shape": 3, 167 | "label": "loaded_frame_count🟦" 168 | }, 169 | { 170 | "name": "loaded_duration🟦", 171 | "type": "FLOAT", 172 | "links": null, 173 | "shape": 3, 174 | "label": "loaded_duration🟦" 175 | }, 176 | { 177 | "name": "loaded_width🟦", 178 | "type": "INT", 179 | "links": null, 180 | "shape": 3, 181 | "label": "loaded_width🟦" 182 | }, 183 | { 184 | "name": "loaded_height🟦", 185 | "type": "INT", 186 | "links": null, 187 | "shape": 3, 188 | "label": "loaded_height🟦" 189 | } 190 | ], 191 | "properties": { 192 | "Node name for S&R": "VHS_VideoInfo" 193 | }, 194 | "widgets_values": {} 195 | }, 196 | { 197 | "id": 11, 198 | "type": "CR Float To Integer", 199 | "pos": [ 200 | -1545, 201 | 854 202 | ], 203 | "size": { 204 | "0": 315, 205 | "1": 78 206 | }, 207 | "flags": {}, 208 | "order": 3, 209 | "mode": 0, 210 | "inputs": [ 211 | { 212 | "name": "_float", 213 | "type": "FLOAT", 214 | "link": 10, 215 | "widget": { 216 | "name": "_float" 217 | }, 218 | "label": "_float" 219 | } 220 | ], 221 | "outputs": [ 222 | { 223 | "name": "INT", 224 | "type": "INT", 225 | "links": [ 226 | 13 227 | ], 228 | "shape": 3, 229 | "label": "INT", 230 | "slot_index": 0 231 | }, 232 | { 233 | "name": "show_help", 234 | "type": "STRING", 235 | "links": null, 236 | "shape": 3, 237 | "label": "show_help" 238 | } 239 | ], 240 | "properties": { 241 | "Node name for S&R": "CR Float To Integer" 242 | }, 243 | "widgets_values": [ 244 | 0 245 | ] 246 | }, 247 | { 248 | "id": 8, 249 | "type": "LoadImage", 250 | "pos": [ 251 | -1964, 252 | 326 253 | ], 254 | "size": { 255 | "0": 315, 256 | "1": 314 257 | }, 258 | "flags": {}, 259 | "order": 1, 260 | "mode": 0, 261 | "outputs": [ 262 | { 263 | "name": "IMAGE", 264 | "type": "IMAGE", 265 | "links": [ 266 | 14 267 | ], 268 | "shape": 3, 269 | "label": "IMAGE", 270 | "slot_index": 0 271 | }, 272 | { 273 | "name": "MASK", 274 | "type": "MASK", 275 | "links": null, 276 | "shape": 3, 277 | "label": "MASK" 278 | } 279 | ], 280 | "properties": { 281 | "Node name for S&R": "LoadImage" 282 | }, 283 | "widgets_values": [ 284 | "solo (2).png", 285 | "image" 286 | ] 287 | }, 288 | { 289 | "id": 9, 290 | "type": "VHS_VideoCombine", 291 | "pos": [ 
292 | -1166, 293 | 282 294 | ], 295 | "size": [ 296 | 315, 297 | 599 298 | ], 299 | "flags": {}, 300 | "order": 5, 301 | "mode": 0, 302 | "inputs": [ 303 | { 304 | "name": "images", 305 | "type": "IMAGE", 306 | "link": 16, 307 | "label": "images" 308 | }, 309 | { 310 | "name": "audio", 311 | "type": "VHS_AUDIO", 312 | "link": 8, 313 | "label": "audio" 314 | }, 315 | { 316 | "name": "meta_batch", 317 | "type": "VHS_BatchManager", 318 | "link": null, 319 | "label": "meta_batch" 320 | } 321 | ], 322 | "outputs": [ 323 | { 324 | "name": "Filenames", 325 | "type": "VHS_FILENAMES", 326 | "links": null, 327 | "shape": 3, 328 | "label": "Filenames" 329 | } 330 | ], 331 | "properties": { 332 | "Node name for S&R": "VHS_VideoCombine" 333 | }, 334 | "widgets_values": { 335 | "frame_rate": 30, 336 | "loop_count": 0, 337 | "filename_prefix": "Aniportrait", 338 | "format": "video/h264-mp4", 339 | "pix_fmt": "yuv420p", 340 | "crf": 19, 341 | "save_metadata": true, 342 | "pingpong": false, 343 | "save_output": true, 344 | "videopreview": { 345 | "hidden": false, 346 | "paused": false, 347 | "params": { 348 | "filename": "Aniportrait_00003-audio.mp4", 349 | "subfolder": "", 350 | "type": "output", 351 | "format": "video/h264-mp4" 352 | } 353 | } 354 | } 355 | }, 356 | { 357 | "id": 12, 358 | "type": "AniPortrait_Audio2Video", 359 | "pos": [ 360 | -1545, 361 | 280 362 | ], 363 | "size": { 364 | "0": 315, 365 | "1": 506 366 | }, 367 | "flags": {}, 368 | "order": 4, 369 | "mode": 0, 370 | "inputs": [ 371 | { 372 | "name": "ref_image", 373 | "type": "IMAGE", 374 | "link": 14, 375 | "label": "ref_image" 376 | }, 377 | { 378 | "name": "images", 379 | "type": "IMAGE", 380 | "link": 15, 381 | "label": "images", 382 | "slot_index": 1 383 | }, 384 | { 385 | "name": "audio_path", 386 | "type": "Audio_Path", 387 | "link": null, 388 | "label": "audio_path" 389 | }, 390 | { 391 | "name": "fps", 392 | "type": "INT", 393 | "link": 13, 394 | "widget": { 395 | "name": "fps" 396 | }, 397 | "label": "fps" 398 | } 399 | ], 400 | "outputs": [ 401 | { 402 | "name": "images", 403 | "type": "IMAGE", 404 | "links": [ 405 | 16 406 | ], 407 | "shape": 3, 408 | "label": "images", 409 | "slot_index": 0 410 | } 411 | ], 412 | "properties": { 413 | "Node name for S&R": "AniPortrait_Audio2Video" 414 | }, 415 | "widgets_values": [ 416 | 512, 417 | 512, 418 | 713, 419 | "randomize", 420 | 3.5, 421 | 25, 422 | "pretrained_model/sd-vae-ft-mse", 423 | "pretrained_model/stable-diffusion-v1-5", 424 | "fp16", 425 | true, 426 | 60, 427 | 3, 428 | "pretrained_model/motion_module.pth", 429 | "pretrained_model/image_encoder", 430 | "pretrained_model/denoising_unet.pth", 431 | "pretrained_model/reference_unet.pth", 432 | "pretrained_model/pose_guider.pth", 433 | 0 434 | ] 435 | } 436 | ], 437 | "links": [ 438 | [ 439 | 8, 440 | 7, 441 | 2, 442 | 9, 443 | 1, 444 | "VHS_AUDIO" 445 | ], 446 | [ 447 | 9, 448 | 7, 449 | 3, 450 | 10, 451 | 0, 452 | "VHS_VIDEOINFO" 453 | ], 454 | [ 455 | 10, 456 | 10, 457 | 0, 458 | 11, 459 | 0, 460 | "FLOAT" 461 | ], 462 | [ 463 | 13, 464 | 11, 465 | 0, 466 | 12, 467 | 3, 468 | "INT" 469 | ], 470 | [ 471 | 14, 472 | 8, 473 | 0, 474 | 12, 475 | 0, 476 | "IMAGE" 477 | ], 478 | [ 479 | 15, 480 | 7, 481 | 0, 482 | 12, 483 | 1, 484 | "IMAGE" 485 | ], 486 | [ 487 | 16, 488 | 12, 489 | 0, 490 | 9, 491 | 0, 492 | "IMAGE" 493 | ] 494 | ], 495 | "groups": [], 496 | "config": {}, 497 | "extra": {}, 498 | "version": 0.4 499 | } -------------------------------------------------------------------------------- /assets/lyl.wav: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/assets/lyl.wav -------------------------------------------------------------------------------- /assets/pose2video_workflow.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 59, 3 | "last_link_id": 47, 4 | "nodes": [ 5 | { 6 | "id": 42, 7 | "type": "VHS_LoadVideo", 8 | "pos": [ 9 | -1964, 10 | 875 11 | ], 12 | "size": [ 13 | 235.1999969482422, 14 | 491.1999969482422 15 | ], 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "inputs": [ 20 | { 21 | "name": "batch_manager", 22 | "type": "VHS_BatchManager", 23 | "link": null, 24 | "label": "batch_manager" 25 | } 26 | ], 27 | "outputs": [ 28 | { 29 | "name": "IMAGE", 30 | "type": "IMAGE", 31 | "links": [ 32 | 38 33 | ], 34 | "shape": 3, 35 | "label": "IMAGE", 36 | "slot_index": 0 37 | }, 38 | { 39 | "name": "frame_count", 40 | "type": "INT", 41 | "links": null, 42 | "shape": 3, 43 | "label": "frame_count" 44 | }, 45 | { 46 | "name": "audio", 47 | "type": "VHS_AUDIO", 48 | "links": [ 49 | 34 50 | ], 51 | "shape": 3, 52 | "label": "audio" 53 | }, 54 | { 55 | "name": "video_info", 56 | "type": "VHS_VIDEOINFO", 57 | "links": [ 58 | 31 59 | ], 60 | "shape": 3, 61 | "label": "video_info", 62 | "slot_index": 3 63 | } 64 | ], 65 | "properties": { 66 | "Node name for S&R": "VHS_LoadVideo" 67 | }, 68 | "widgets_values": { 69 | "video": "pose_ref_video.mp4", 70 | "force_rate": 0, 71 | "force_size": "Disabled", 72 | "custom_width": 512, 73 | "custom_height": 512, 74 | "frame_load_cap": 0, 75 | "skip_first_frames": 0, 76 | "select_every_nth": 1, 77 | "choose video to upload": "image", 78 | "videopreview": { 79 | "hidden": false, 80 | "paused": false, 81 | "params": { 82 | "frame_load_cap": 0, 83 | "skip_first_frames": 0, 84 | "force_rate": 0, 85 | "filename": "pose_ref_video.mp4", 86 | "type": "input", 87 | "format": "video/mp4", 88 | "select_every_nth": 1 89 | } 90 | } 91 | } 92 | }, 93 | { 94 | "id": 29, 95 | "type": "VHS_VideoInfo", 96 | "pos": [ 97 | -1642, 98 | 1154 99 | ], 100 | "size": { 101 | "0": 393, 102 | "1": 206 103 | }, 104 | "flags": {}, 105 | "order": 3, 106 | "mode": 0, 107 | "inputs": [ 108 | { 109 | "name": "video_info", 110 | "type": "VHS_VIDEOINFO", 111 | "link": 31, 112 | "label": "video_info" 113 | } 114 | ], 115 | "outputs": [ 116 | { 117 | "name": "source_fps🟨", 118 | "type": "FLOAT", 119 | "links": [], 120 | "shape": 3, 121 | "label": "source_fps🟨", 122 | "slot_index": 0 123 | }, 124 | { 125 | "name": "source_frame_count🟨", 126 | "type": "INT", 127 | "links": [ 128 | 42 129 | ], 130 | "shape": 3, 131 | "label": "source_frame_count🟨", 132 | "slot_index": 1 133 | }, 134 | { 135 | "name": "source_duration🟨", 136 | "type": "FLOAT", 137 | "links": null, 138 | "shape": 3, 139 | "label": "source_duration🟨" 140 | }, 141 | { 142 | "name": "source_width🟨", 143 | "type": "INT", 144 | "links": null, 145 | "shape": 3, 146 | "label": "source_width🟨", 147 | "slot_index": 3 148 | }, 149 | { 150 | "name": "source_height🟨", 151 | "type": "INT", 152 | "links": null, 153 | "shape": 3, 154 | "label": "source_height🟨" 155 | }, 156 | { 157 | "name": "loaded_fps🟦", 158 | "type": "FLOAT", 159 | "links": null, 160 | "shape": 3, 161 | "label": "loaded_fps🟦" 162 | }, 163 | { 164 | "name": "loaded_frame_count🟦", 165 | "type": "INT", 166 | "links": null, 167 | "shape": 3, 168 | "label": "loaded_frame_count🟦" 
169 | }, 170 | { 171 | "name": "loaded_duration🟦", 172 | "type": "FLOAT", 173 | "links": null, 174 | "shape": 3, 175 | "label": "loaded_duration🟦" 176 | }, 177 | { 178 | "name": "loaded_width🟦", 179 | "type": "INT", 180 | "links": null, 181 | "shape": 3, 182 | "label": "loaded_width🟦" 183 | }, 184 | { 185 | "name": "loaded_height🟦", 186 | "type": "INT", 187 | "links": null, 188 | "shape": 3, 189 | "label": "loaded_height🟦" 190 | } 191 | ], 192 | "properties": { 193 | "Node name for S&R": "VHS_VideoInfo" 194 | }, 195 | "widgets_values": {} 196 | }, 197 | { 198 | "id": 52, 199 | "type": "AniPortrait_Video_Gen_Pose", 200 | "pos": [ 201 | -1642, 202 | 1005 203 | ], 204 | "size": { 205 | "0": 361.20001220703125, 206 | "1": 106 207 | }, 208 | "flags": {}, 209 | "order": 2, 210 | "mode": 0, 211 | "inputs": [ 212 | { 213 | "name": "image", 214 | "type": "IMAGE", 215 | "link": 38, 216 | "label": "image" 217 | } 218 | ], 219 | "outputs": [ 220 | { 221 | "name": "pose_images", 222 | "type": "IMAGE", 223 | "links": [ 224 | 41, 225 | 47 226 | ], 227 | "shape": 3, 228 | "label": "pose_images", 229 | "slot_index": 0 230 | } 231 | ], 232 | "properties": { 233 | "Node name for S&R": "AniPortrait_Video_Gen_Pose" 234 | }, 235 | "widgets_values": [ 236 | "AniPortrait", 237 | 512, 238 | 512 239 | ] 240 | }, 241 | { 242 | "id": 56, 243 | "type": "LoadImage", 244 | "pos": [ 245 | -1649, 246 | 633 247 | ], 248 | "size": { 249 | "0": 315, 250 | "1": 314 251 | }, 252 | "flags": {}, 253 | "order": 1, 254 | "mode": 0, 255 | "outputs": [ 256 | { 257 | "name": "IMAGE", 258 | "type": "IMAGE", 259 | "links": [ 260 | 45 261 | ], 262 | "shape": 3, 263 | "label": "IMAGE", 264 | "slot_index": 0 265 | }, 266 | { 267 | "name": "MASK", 268 | "type": "MASK", 269 | "links": null, 270 | "shape": 3, 271 | "label": "MASK" 272 | } 273 | ], 274 | "properties": { 275 | "Node name for S&R": "LoadImage" 276 | }, 277 | "widgets_values": [ 278 | "solo.png", 279 | "image" 280 | ] 281 | }, 282 | { 283 | "id": 53, 284 | "type": "AniPortrait_Pose_Gen_Video", 285 | "pos": [ 286 | -1226, 287 | 797 288 | ], 289 | "size": { 290 | "0": 315, 291 | "1": 462 292 | }, 293 | "flags": {}, 294 | "order": 5, 295 | "mode": 0, 296 | "inputs": [ 297 | { 298 | "name": "ref_image", 299 | "type": "IMAGE", 300 | "link": 45, 301 | "label": "ref_image", 302 | "slot_index": 0 303 | }, 304 | { 305 | "name": "pose_images", 306 | "type": "IMAGE", 307 | "link": 41, 308 | "label": "pose_images" 309 | }, 310 | { 311 | "name": "frame_count", 312 | "type": "INT", 313 | "link": 42, 314 | "widget": { 315 | "name": "frame_count" 316 | }, 317 | "slot_index": 2, 318 | "label": "frame_count" 319 | } 320 | ], 321 | "outputs": [ 322 | { 323 | "name": "images", 324 | "type": "IMAGE", 325 | "links": [ 326 | 43, 327 | 46 328 | ], 329 | "shape": 3, 330 | "label": "images", 331 | "slot_index": 0 332 | } 333 | ], 334 | "properties": { 335 | "Node name for S&R": "AniPortrait_Pose_Gen_Video" 336 | }, 337 | "widgets_values": [ 338 | 0, 339 | 512, 340 | 512, 341 | 688, 342 | "randomize", 343 | 3.5, 344 | 25, 345 | "pretrained_model/sd-vae-ft-mse", 346 | "pretrained_model/stable-diffusion-v1-5", 347 | "fp16", 348 | true, 349 | 3, 350 | "pretrained_model/motion_module.pth", 351 | "pretrained_model/image_encoder", 352 | "pretrained_model/denoising_unet.pth", 353 | "pretrained_model/reference_unet.pth", 354 | "pretrained_model/pose_guider.pth" 355 | ] 356 | }, 357 | { 358 | "id": 44, 359 | "type": "VHS_VideoCombine", 360 | "pos": [ 361 | -877, 362 | 794 363 | ], 364 | "size": [ 365 | 315, 366 | 
599 367 | ], 368 | "flags": {}, 369 | "order": 6, 370 | "mode": 0, 371 | "inputs": [ 372 | { 373 | "name": "images", 374 | "type": "IMAGE", 375 | "link": 43, 376 | "label": "images", 377 | "slot_index": 0 378 | }, 379 | { 380 | "name": "audio", 381 | "type": "VHS_AUDIO", 382 | "link": 34, 383 | "label": "audio", 384 | "slot_index": 1 385 | }, 386 | { 387 | "name": "batch_manager", 388 | "type": "VHS_BatchManager", 389 | "link": null, 390 | "label": "batch_manager" 391 | } 392 | ], 393 | "outputs": [ 394 | { 395 | "name": "Filenames", 396 | "type": "VHS_FILENAMES", 397 | "links": null, 398 | "shape": 3, 399 | "label": "Filenames" 400 | } 401 | ], 402 | "properties": { 403 | "Node name for S&R": "VHS_VideoCombine" 404 | }, 405 | "widgets_values": { 406 | "frame_rate": 25, 407 | "loop_count": 0, 408 | "filename_prefix": "Aniportrait", 409 | "format": "video/h264-mp4", 410 | "pix_fmt": "yuv420p", 411 | "crf": 19, 412 | "save_metadata": true, 413 | "pingpong": false, 414 | "save_output": true, 415 | "videopreview": { 416 | "hidden": false, 417 | "paused": false, 418 | "params": { 419 | "filename": "Aniportrait_00004-audio.mp4", 420 | "subfolder": "", 421 | "type": "output", 422 | "format": "video/h264-mp4" 423 | } 424 | } 425 | } 426 | }, 427 | { 428 | "id": 59, 429 | "type": "PreviewImage", 430 | "pos": [ 431 | -1245, 432 | 504 433 | ], 434 | "size": { 435 | "0": 210, 436 | "1": 246 437 | }, 438 | "flags": {}, 439 | "order": 4, 440 | "mode": 0, 441 | "inputs": [ 442 | { 443 | "name": "images", 444 | "type": "IMAGE", 445 | "link": 47, 446 | "label": "images" 447 | } 448 | ], 449 | "properties": { 450 | "Node name for S&R": "PreviewImage" 451 | } 452 | }, 453 | { 454 | "id": 58, 455 | "type": "PreviewImage", 456 | "pos": [ 457 | -890, 458 | 503 459 | ], 460 | "size": { 461 | "0": 210, 462 | "1": 246 463 | }, 464 | "flags": {}, 465 | "order": 7, 466 | "mode": 0, 467 | "inputs": [ 468 | { 469 | "name": "images", 470 | "type": "IMAGE", 471 | "link": 46, 472 | "label": "images" 473 | } 474 | ], 475 | "properties": { 476 | "Node name for S&R": "PreviewImage" 477 | } 478 | } 479 | ], 480 | "links": [ 481 | [ 482 | 31, 483 | 42, 484 | 3, 485 | 29, 486 | 0, 487 | "VHS_VIDEOINFO" 488 | ], 489 | [ 490 | 34, 491 | 42, 492 | 2, 493 | 44, 494 | 1, 495 | "VHS_AUDIO" 496 | ], 497 | [ 498 | 38, 499 | 42, 500 | 0, 501 | 52, 502 | 0, 503 | "IMAGE" 504 | ], 505 | [ 506 | 41, 507 | 52, 508 | 0, 509 | 53, 510 | 1, 511 | "IMAGE" 512 | ], 513 | [ 514 | 42, 515 | 29, 516 | 1, 517 | 53, 518 | 2, 519 | "INT" 520 | ], 521 | [ 522 | 43, 523 | 53, 524 | 0, 525 | 44, 526 | 0, 527 | "IMAGE" 528 | ], 529 | [ 530 | 45, 531 | 56, 532 | 0, 533 | 53, 534 | 0, 535 | "IMAGE" 536 | ], 537 | [ 538 | 46, 539 | 53, 540 | 0, 541 | 58, 542 | 0, 543 | "IMAGE" 544 | ], 545 | [ 546 | 47, 547 | 52, 548 | 0, 549 | 59, 550 | 0, 551 | "IMAGE" 552 | ] 553 | ], 554 | "groups": [], 555 | "config": {}, 556 | "extra": {}, 557 | "version": 0.4 558 | } -------------------------------------------------------------------------------- /assets/pose_ref_video.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/assets/pose_ref_video.mp4 -------------------------------------------------------------------------------- /assets/solo.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/assets/solo.png -------------------------------------------------------------------------------- /assets/woman.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/assets/woman.jpg -------------------------------------------------------------------------------- /configs/inference/inference_audio.yaml: -------------------------------------------------------------------------------- 1 | a2m_model: 2 | out_dim: 1404 3 | latent_dim: 512 4 | model_path: pretrained_model/wav2vec2-base-960h 5 | only_last_fetures: True 6 | from_pretrained: True 7 | 8 | a2p_model: 9 | out_dim: 6 10 | latent_dim: 512 11 | model_path: pretrained_model/wav2vec2-base-960h 12 | only_last_fetures: True 13 | from_pretrained: True 14 | 15 | pretrained_model: 16 | a2m_ckpt: pretrained_model/audio2mesh.pt 17 | a2p_ckpt: pretrained_model/audio2pose.pt 18 | -------------------------------------------------------------------------------- /configs/inference/inference_v2.yaml: -------------------------------------------------------------------------------- 1 | unet_additional_kwargs: 2 | use_inflated_groupnorm: true 3 | unet_use_cross_frame_attention: false 4 | unet_use_temporal_attention: false 5 | use_motion_module: true 6 | motion_module_resolutions: 7 | - 1 8 | - 2 9 | - 4 10 | - 8 11 | motion_module_mid_block: true 12 | motion_module_decoder_only: false 13 | motion_module_type: Vanilla 14 | motion_module_kwargs: 15 | num_attention_heads: 8 16 | num_transformer_block: 1 17 | attention_block_types: 18 | - Temporal_Self 19 | - Temporal_Self 20 | temporal_position_encoding: true 21 | temporal_position_encoding_max_len: 32 22 | temporal_attention_dim_div: 1 23 | 24 | noise_scheduler_kwargs: 25 | beta_start: 0.00085 26 | beta_end: 0.012 27 | beta_schedule: "linear" 28 | clip_sample: false 29 | steps_offset: 1 30 | ### Zero-SNR params 31 | prediction_type: "v_prediction" 32 | rescale_betas_zero_snr: True 33 | timestep_spacing: "trailing" 34 | 35 | sampler: DDIM -------------------------------------------------------------------------------- /configs/prompts/animation.yaml: -------------------------------------------------------------------------------- 1 | pretrained_base_model_path: 'pretrained_model/stable-diffusion-v1-5' 2 | pretrained_vae_path: 'pretrained_model/sd-vae-ft-mse' 3 | image_encoder_path: 'pretrained_model/image_encoder' 4 | 5 | denoising_unet_path: "pretrained_model/denoising_unet.pth" 6 | reference_unet_path: "pretrained_model/reference_unet.pth" 7 | pose_guider_path: "pretrained_model/pose_guider.pth" 8 | motion_module_path: "pretrained_model/motion_module.pth" 9 | 10 | inference_config: "configs/inference/inference_v2.yaml" 11 | -------------------------------------------------------------------------------- /configs/prompts/animation_audio.yaml: -------------------------------------------------------------------------------- 1 | pretrained_base_model_path: 'pretrained_model/stable-diffusion-v1-5' 2 | pretrained_vae_path: 'pretrained_model/sd-vae-ft-mse' 3 | image_encoder_path: 'pretrained_model/image_encoder' 4 | 5 | denoising_unet_path: "pretrained_model/denoising_unet.pth" 6 | reference_unet_path: "pretrained_model/reference_unet.pth" 7 | pose_guider_path: "pretrained_model/pose_guider.pth" 8 | motion_module_path: "pretrained_model/motion_module.pth" 9 | 10 | 
audio_inference_config: "configs/inference/inference_audio.yaml" 11 | inference_config: "configs/inference/inference_v2.yaml" 12 | weight_dtype: 'fp16' 13 | 14 | #pose_temp: "./configs/inference/head_pose_temp/pose_temp.npy" 15 | 16 | test_cases: 17 | # "./configs/inference/ref_images/lyl.png": 18 | # - "./configs/inference/audio/lyl.wav" 19 | -------------------------------------------------------------------------------- /configs/prompts/animation_facereenac.yaml: -------------------------------------------------------------------------------- 1 | pretrained_base_model_path: './pretrained_model/stable-diffusion-v1-5' 2 | pretrained_vae_path: './pretrained_model/sd-vae-ft-mse' 3 | image_encoder_path: './pretrained_model/image_encoder' 4 | 5 | denoising_unet_path: "./pretrained_model/denoising_unet.pth" 6 | reference_unet_path: "./pretrained_model/reference_unet.pth" 7 | pose_guider_path: "./pretrained_model/pose_guider.pth" 8 | motion_module_path: "./pretrained_model/motion_module.pth" 9 | 10 | inference_config: "./configs/inference/inference_v2.yaml" 11 | weight_dtype: 'fp16' 12 | 13 | test_cases: 14 | # "./configs/inference/ref_images/Aragaki.png": 15 | # - "./configs/inference/video/Aragaki_song.mp4" 16 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "comfyui_aniportrait" 3 | description = "implementation of [a/AniPortrait](https://github.com/Zejun-Yang/AniPortrait) generating of videos, includes self driven, face reenacment and audio driven with a reference image" 4 | version = "1.0.0" 5 | license = "LICENSE" 6 | dependencies = ["mediapipe==0.10.11", "ffmpeg-python==0.2.0", "av==11.0.0", "librosa==0.9.2", "diffusers==0.26.2"] 7 | 8 | [project.urls] 9 | Repository = "https://github.com/frankchieng/ComfyUI_Aniportrait" 10 | # Used by Comfy Registry https://comfyregistry.org 11 | 12 | [tool.comfy] 13 | PublisherId = "frankchieng" 14 | DisplayName = "ComfyUI_Aniportrait" 15 | Icon = "" 16 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | mediapipe==0.10.11 2 | ffmpeg-python==0.2.0 3 | av==11.0.0 4 | librosa==0.9.2 5 | diffusers==0.26.2 6 | omegaconf==2.3.0 7 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | #Dummy file ensuring this package will be recognized 2 | -------------------------------------------------------------------------------- /src/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /src/audio_models/__pycache__/model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/audio_models/__pycache__/model.cpython-310.pyc -------------------------------------------------------------------------------- /src/audio_models/__pycache__/torch_utils.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/audio_models/__pycache__/torch_utils.cpython-310.pyc -------------------------------------------------------------------------------- /src/audio_models/__pycache__/wav2vec2.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/audio_models/__pycache__/wav2vec2.cpython-310.pyc -------------------------------------------------------------------------------- /src/audio_models/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from transformers import Wav2Vec2Config 6 | 7 | from .torch_utils import get_mask_from_lengths 8 | from .wav2vec2 import Wav2Vec2Model 9 | 10 | 11 | class Audio2MeshModel(nn.Module): 12 | def __init__( 13 | self, 14 | config 15 | ): 16 | super().__init__() 17 | out_dim = config['out_dim'] 18 | latent_dim = config['latent_dim'] 19 | model_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), config['model_path']) 20 | #model_path = config['model_path'] 21 | only_last_fetures = config['only_last_fetures'] 22 | from_pretrained = config['from_pretrained'] 23 | 24 | self._only_last_features = only_last_fetures 25 | 26 | self.audio_encoder_config = Wav2Vec2Config.from_pretrained(model_path, local_files_only=True) 27 | if from_pretrained: 28 | self.audio_encoder = Wav2Vec2Model.from_pretrained(model_path, local_files_only=True) 29 | else: 30 | self.audio_encoder = Wav2Vec2Model(self.audio_encoder_config) 31 | self.audio_encoder.feature_extractor._freeze_parameters() 32 | 33 | hidden_size = self.audio_encoder_config.hidden_size 34 | 35 | self.in_fn = nn.Linear(hidden_size, latent_dim) 36 | 37 | self.out_fn = nn.Linear(latent_dim, out_dim) 38 | nn.init.constant_(self.out_fn.weight, 0) 39 | nn.init.constant_(self.out_fn.bias, 0) 40 | 41 | def forward(self, audio, label, audio_len=None): 42 | attention_mask = ~get_mask_from_lengths(audio_len) if audio_len else None 43 | 44 | seq_len = label.shape[1] 45 | 46 | embeddings = self.audio_encoder(audio, seq_len=seq_len, output_hidden_states=True, 47 | attention_mask=attention_mask) 48 | 49 | if self._only_last_features: 50 | hidden_states = embeddings.last_hidden_state 51 | else: 52 | hidden_states = sum(embeddings.hidden_states) / len(embeddings.hidden_states) 53 | 54 | layer_in = self.in_fn(hidden_states) 55 | out = self.out_fn(layer_in) 56 | 57 | return out, None 58 | 59 | def infer(self, input_value, seq_len): 60 | embeddings = self.audio_encoder(input_value, seq_len=seq_len, output_hidden_states=True) 61 | 62 | if self._only_last_features: 63 | hidden_states = embeddings.last_hidden_state 64 | else: 65 | hidden_states = sum(embeddings.hidden_states) / len(embeddings.hidden_states) 66 | 67 | layer_in = self.in_fn(hidden_states) 68 | out = self.out_fn(layer_in) 69 | 70 | return out 71 | 72 | 73 | -------------------------------------------------------------------------------- /src/audio_models/pose_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | from transformers import Wav2Vec2Config 6 | 7 | from 
.torch_utils import get_mask_from_lengths 8 | from .wav2vec2 import Wav2Vec2Model 9 | 10 | 11 | def init_biased_mask(n_head, max_seq_len, period): 12 | def get_slopes(n): 13 | def get_slopes_power_of_2(n): 14 | start = (2**(-2**-(math.log2(n)-3))) 15 | ratio = start 16 | return [start*ratio**i for i in range(n)] 17 | if math.log2(n).is_integer(): 18 | return get_slopes_power_of_2(n) 19 | else: 20 | closest_power_of_2 = 2**math.floor(math.log2(n)) 21 | return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2*closest_power_of_2)[0::2][:n-closest_power_of_2] 22 | slopes = torch.Tensor(get_slopes(n_head)) 23 | bias = torch.arange(start=0, end=max_seq_len, step=period).unsqueeze(1).repeat(1,period).view(-1)//(period) 24 | bias = - torch.flip(bias,dims=[0]) 25 | alibi = torch.zeros(max_seq_len, max_seq_len) 26 | for i in range(max_seq_len): 27 | alibi[i, :i+1] = bias[-(i+1):] 28 | alibi = slopes.unsqueeze(1).unsqueeze(1) * alibi.unsqueeze(0) 29 | mask = (torch.triu(torch.ones(max_seq_len, max_seq_len)) == 1).transpose(0, 1) 30 | mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) 31 | mask = mask.unsqueeze(0) + alibi 32 | return mask 33 | 34 | 35 | def enc_dec_mask(device, T, S): 36 | mask = torch.ones(T, S) 37 | for i in range(T): 38 | mask[i, i] = 0 39 | return (mask==1).to(device=device) 40 | 41 | 42 | class PositionalEncoding(nn.Module): 43 | def __init__(self, d_model, max_len=600): 44 | super(PositionalEncoding, self).__init__() 45 | pe = torch.zeros(max_len, d_model) 46 | position = torch.arange(0, max_len).unsqueeze(1).float() 47 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)) 48 | pe[:, 0::2] = torch.sin(position * div_term) 49 | pe[:, 1::2] = torch.cos(position * div_term) 50 | pe = pe.unsqueeze(0) 51 | self.register_buffer('pe', pe) 52 | 53 | def forward(self, x): 54 | x = x + self.pe[:, :x.size(1)] 55 | return x 56 | 57 | 58 | class Audio2PoseModel(nn.Module): 59 | def __init__( 60 | self, 61 | config 62 | ): 63 | 64 | super().__init__() 65 | 66 | latent_dim = config['latent_dim'] 67 | model_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), config['model_path']) 68 | #model_path = config['model_path'] 69 | only_last_fetures = config['only_last_fetures'] 70 | from_pretrained = config['from_pretrained'] 71 | out_dim = config['out_dim'] 72 | 73 | self.out_dim = out_dim 74 | 75 | self._only_last_features = only_last_fetures 76 | 77 | self.audio_encoder_config = Wav2Vec2Config.from_pretrained(model_path, local_files_only=True) 78 | if from_pretrained: 79 | self.audio_encoder = Wav2Vec2Model.from_pretrained(model_path, local_files_only=True) 80 | else: 81 | self.audio_encoder = Wav2Vec2Model(self.audio_encoder_config) 82 | self.audio_encoder.feature_extractor._freeze_parameters() 83 | 84 | hidden_size = self.audio_encoder_config.hidden_size 85 | 86 | self.pose_map = nn.Linear(out_dim, latent_dim) 87 | self.in_fn = nn.Linear(hidden_size, latent_dim) 88 | 89 | self.PPE = PositionalEncoding(latent_dim) 90 | self.biased_mask = init_biased_mask(n_head = 8, max_seq_len = 600, period=1) 91 | decoder_layer = nn.TransformerDecoderLayer(d_model=latent_dim, nhead=8, dim_feedforward=2*latent_dim, batch_first=True) 92 | self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=8) 93 | self.pose_map_r = nn.Linear(latent_dim, out_dim) 94 | 95 | self.id_embed = nn.Embedding(100, latent_dim) # 100 ids 96 | 97 | 98 | def infer(self, input_value, 
seq_len, id_seed=None): 99 | embeddings = self.audio_encoder(input_value, seq_len=seq_len, output_hidden_states=True) 100 | 101 | if self._only_last_features: 102 | hidden_states = embeddings.last_hidden_state 103 | else: 104 | hidden_states = sum(embeddings.hidden_states) / len(embeddings.hidden_states) 105 | 106 | hidden_states = self.in_fn(hidden_states) 107 | 108 | id_embedding = self.id_embed(id_seed).unsqueeze(1) 109 | 110 | init_pose = torch.zeros([hidden_states.shape[0], 1, self.out_dim]).to(hidden_states.device) 111 | for i in range(seq_len): 112 | if i==0: 113 | pose_emb = self.pose_map(init_pose) 114 | pose_input = self.PPE(pose_emb) 115 | else: 116 | pose_input = self.PPE(pose_emb) 117 | 118 | pose_input = pose_input + id_embedding 119 | tgt_mask = self.biased_mask[:, :pose_input.shape[1], :pose_input.shape[1]].clone().detach().to(hidden_states.device) 120 | memory_mask = enc_dec_mask(hidden_states.device, pose_input.shape[1], hidden_states.shape[1]) 121 | pose_out = self.transformer_decoder(pose_input, hidden_states, tgt_mask=tgt_mask, memory_mask=memory_mask) 122 | pose_out = self.pose_map_r(pose_out) 123 | new_output = self.pose_map(pose_out[:,-1,:]).unsqueeze(1) 124 | pose_emb = torch.cat((pose_emb, new_output), 1) 125 | return pose_out 126 | 127 | -------------------------------------------------------------------------------- /src/audio_models/torch_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def get_mask_from_lengths(lengths, max_len=None): 6 | lengths = lengths.to(torch.long) 7 | if max_len is None: 8 | max_len = torch.max(lengths).item() 9 | 10 | ids = torch.arange(0, max_len).unsqueeze(0).expand(lengths.shape[0], -1).to(lengths.device) 11 | mask = ids < lengths.unsqueeze(1).expand(-1, max_len) 12 | 13 | return mask 14 | 15 | 16 | def linear_interpolation(features, seq_len): 17 | features = features.transpose(1, 2) 18 | output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear') 19 | return output_features.transpose(1, 2) 20 | 21 | 22 | if __name__ == "__main__": 23 | import numpy as np 24 | mask = ~get_mask_from_lengths(torch.from_numpy(np.array([4,6]))) 25 | import pdb; pdb.set_trace() -------------------------------------------------------------------------------- /src/audio_models/wav2vec2.py: -------------------------------------------------------------------------------- 1 | from transformers import Wav2Vec2Config, Wav2Vec2Model 2 | from transformers.modeling_outputs import BaseModelOutput 3 | 4 | from .torch_utils import linear_interpolation 5 | 6 | # the implementation of Wav2Vec2Model is borrowed from 7 | # https://github.com/huggingface/transformers/blob/HEAD/src/transformers/models/wav2vec2/modeling_wav2vec2.py 8 | # initialize our encoder with the pre-trained wav2vec 2.0 weights. 
9 | class Wav2Vec2Model(Wav2Vec2Model): 10 | def __init__(self, config: Wav2Vec2Config): 11 | super().__init__(config) 12 | 13 | def forward( 14 | self, 15 | input_values, 16 | seq_len, 17 | attention_mask=None, 18 | mask_time_indices=None, 19 | output_attentions=None, 20 | output_hidden_states=None, 21 | return_dict=None, 22 | ): 23 | self.config.output_attentions = True 24 | 25 | output_hidden_states = ( 26 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 27 | ) 28 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 29 | 30 | extract_features = self.feature_extractor(input_values) 31 | extract_features = extract_features.transpose(1, 2) 32 | extract_features = linear_interpolation(extract_features, seq_len=seq_len) 33 | 34 | if attention_mask is not None: 35 | # compute reduced attention_mask corresponding to feature vectors 36 | attention_mask = self._get_feature_vector_attention_mask( 37 | extract_features.shape[1], attention_mask, add_adapter=False 38 | ) 39 | 40 | hidden_states, extract_features = self.feature_projection(extract_features) 41 | hidden_states = self._mask_hidden_states( 42 | hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask 43 | ) 44 | 45 | encoder_outputs = self.encoder( 46 | hidden_states, 47 | attention_mask=attention_mask, 48 | output_attentions=output_attentions, 49 | output_hidden_states=output_hidden_states, 50 | return_dict=return_dict, 51 | ) 52 | 53 | hidden_states = encoder_outputs[0] 54 | 55 | if self.adapter is not None: 56 | hidden_states = self.adapter(hidden_states) 57 | 58 | if not return_dict: 59 | return (hidden_states, ) + encoder_outputs[1:] 60 | return BaseModelOutput( 61 | last_hidden_state=hidden_states, 62 | hidden_states=encoder_outputs.hidden_states, 63 | attentions=encoder_outputs.attentions, 64 | ) 65 | 66 | 67 | def feature_extract( 68 | self, 69 | input_values, 70 | seq_len, 71 | ): 72 | extract_features = self.feature_extractor(input_values) 73 | extract_features = extract_features.transpose(1, 2) 74 | extract_features = linear_interpolation(extract_features, seq_len=seq_len) 75 | 76 | return extract_features 77 | 78 | def encode( 79 | self, 80 | extract_features, 81 | attention_mask=None, 82 | mask_time_indices=None, 83 | output_attentions=None, 84 | output_hidden_states=None, 85 | return_dict=None, 86 | ): 87 | self.config.output_attentions = True 88 | 89 | output_hidden_states = ( 90 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 91 | ) 92 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 93 | 94 | if attention_mask is not None: 95 | # compute reduced attention_mask corresponding to feature vectors 96 | attention_mask = self._get_feature_vector_attention_mask( 97 | extract_features.shape[1], attention_mask, add_adapter=False 98 | ) 99 | 100 | 101 | hidden_states, extract_features = self.feature_projection(extract_features) 102 | hidden_states = self._mask_hidden_states( 103 | hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask 104 | ) 105 | 106 | encoder_outputs = self.encoder( 107 | hidden_states, 108 | attention_mask=attention_mask, 109 | output_attentions=output_attentions, 110 | output_hidden_states=output_hidden_states, 111 | return_dict=return_dict, 112 | ) 113 | 114 | hidden_states = encoder_outputs[0] 115 | 116 | if self.adapter is not None: 117 | hidden_states = self.adapter(hidden_states) 118 | 
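# encode() mirrors the tail of forward() above but starts from features produced by feature_extract(),
# so CNN feature extraction and transformer encoding can be run as separate steps.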
119 | if not return_dict: 120 | return (hidden_states, ) + encoder_outputs[1:] 121 | return BaseModelOutput( 122 | last_hidden_state=hidden_states, 123 | hidden_states=encoder_outputs.hidden_states, 124 | attentions=encoder_outputs.attentions, 125 | ) 126 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/models/__init__.py -------------------------------------------------------------------------------- /src/models/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/models/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/attention.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/models/__pycache__/attention.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/motion_module.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/models/__pycache__/motion_module.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/mutual_self_attention.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/models/__pycache__/mutual_self_attention.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/pose_guider.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/models/__pycache__/pose_guider.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/resnet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/models/__pycache__/resnet.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/transformer_2d.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/models/__pycache__/transformer_2d.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/transformer_3d.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/models/__pycache__/transformer_3d.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/unet_2d_blocks.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/models/__pycache__/unet_2d_blocks.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/unet_2d_condition.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/models/__pycache__/unet_2d_condition.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/unet_3d.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/models/__pycache__/unet_3d.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/unet_3d_blocks.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/models/__pycache__/unet_3d_blocks.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/motion_module.py: -------------------------------------------------------------------------------- 1 | # Adapt from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py 2 | import math 3 | from dataclasses import dataclass 4 | from typing import Callable, Optional 5 | 6 | import torch 7 | from diffusers.models.attention import FeedForward 8 | from diffusers.models.attention_processor import Attention, AttnProcessor 9 | from diffusers.utils import BaseOutput 10 | from diffusers.utils.import_utils import is_xformers_available 11 | from einops import rearrange, repeat 12 | from torch import nn 13 | 14 | 15 | def zero_module(module): 16 | # Zero out the parameters of a module and return it. 
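# When zero_initialize is enabled below, zeroing proj_out makes a freshly added motion module act as
# an identity at the start of training, since the temporal transformer's output then reduces to its
# residual path.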
17 | for p in module.parameters(): 18 | p.detach().zero_() 19 | return module 20 | 21 | 22 | @dataclass 23 | class TemporalTransformer3DModelOutput(BaseOutput): 24 | sample: torch.FloatTensor 25 | 26 | 27 | if is_xformers_available(): 28 | import xformers 29 | import xformers.ops 30 | else: 31 | xformers = None 32 | 33 | 34 | def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict): 35 | if motion_module_type == "Vanilla": 36 | return VanillaTemporalModule( 37 | in_channels=in_channels, 38 | **motion_module_kwargs, 39 | ) 40 | else: 41 | raise ValueError 42 | 43 | 44 | class VanillaTemporalModule(nn.Module): 45 | def __init__( 46 | self, 47 | in_channels, 48 | num_attention_heads=8, 49 | num_transformer_block=2, 50 | attention_block_types=("Temporal_Self", "Temporal_Self"), 51 | cross_frame_attention_mode=None, 52 | temporal_position_encoding=False, 53 | temporal_position_encoding_max_len=24, 54 | temporal_attention_dim_div=1, 55 | zero_initialize=True, 56 | ): 57 | super().__init__() 58 | 59 | self.temporal_transformer = TemporalTransformer3DModel( 60 | in_channels=in_channels, 61 | num_attention_heads=num_attention_heads, 62 | attention_head_dim=in_channels 63 | // num_attention_heads 64 | // temporal_attention_dim_div, 65 | num_layers=num_transformer_block, 66 | attention_block_types=attention_block_types, 67 | cross_frame_attention_mode=cross_frame_attention_mode, 68 | temporal_position_encoding=temporal_position_encoding, 69 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 70 | ) 71 | 72 | if zero_initialize: 73 | self.temporal_transformer.proj_out = zero_module( 74 | self.temporal_transformer.proj_out 75 | ) 76 | 77 | def forward( 78 | self, 79 | input_tensor, 80 | temb, 81 | encoder_hidden_states, 82 | attention_mask=None, 83 | anchor_frame_idx=None, 84 | ): 85 | hidden_states = input_tensor 86 | hidden_states = self.temporal_transformer( 87 | hidden_states, encoder_hidden_states, attention_mask 88 | ) 89 | 90 | output = hidden_states 91 | return output 92 | 93 | 94 | class TemporalTransformer3DModel(nn.Module): 95 | def __init__( 96 | self, 97 | in_channels, 98 | num_attention_heads, 99 | attention_head_dim, 100 | num_layers, 101 | attention_block_types=( 102 | "Temporal_Self", 103 | "Temporal_Self", 104 | ), 105 | dropout=0.0, 106 | norm_num_groups=32, 107 | cross_attention_dim=768, 108 | activation_fn="geglu", 109 | attention_bias=False, 110 | upcast_attention=False, 111 | cross_frame_attention_mode=None, 112 | temporal_position_encoding=False, 113 | temporal_position_encoding_max_len=24, 114 | ): 115 | super().__init__() 116 | 117 | inner_dim = num_attention_heads * attention_head_dim 118 | 119 | self.norm = torch.nn.GroupNorm( 120 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True 121 | ) 122 | self.proj_in = nn.Linear(in_channels, inner_dim) 123 | 124 | self.transformer_blocks = nn.ModuleList( 125 | [ 126 | TemporalTransformerBlock( 127 | dim=inner_dim, 128 | num_attention_heads=num_attention_heads, 129 | attention_head_dim=attention_head_dim, 130 | attention_block_types=attention_block_types, 131 | dropout=dropout, 132 | norm_num_groups=norm_num_groups, 133 | cross_attention_dim=cross_attention_dim, 134 | activation_fn=activation_fn, 135 | attention_bias=attention_bias, 136 | upcast_attention=upcast_attention, 137 | cross_frame_attention_mode=cross_frame_attention_mode, 138 | temporal_position_encoding=temporal_position_encoding, 139 | 
temporal_position_encoding_max_len=temporal_position_encoding_max_len, 140 | ) 141 | for d in range(num_layers) 142 | ] 143 | ) 144 | self.proj_out = nn.Linear(inner_dim, in_channels) 145 | 146 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): 147 | assert ( 148 | hidden_states.dim() == 5 149 | ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 150 | video_length = hidden_states.shape[2] 151 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 152 | 153 | batch, channel, height, weight = hidden_states.shape 154 | residual = hidden_states 155 | 156 | hidden_states = self.norm(hidden_states) 157 | inner_dim = hidden_states.shape[1] 158 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( 159 | batch, height * weight, inner_dim 160 | ) 161 | hidden_states = self.proj_in(hidden_states) 162 | 163 | # Transformer Blocks 164 | for block in self.transformer_blocks: 165 | hidden_states = block( 166 | hidden_states, 167 | encoder_hidden_states=encoder_hidden_states, 168 | video_length=video_length, 169 | ) 170 | 171 | # output 172 | hidden_states = self.proj_out(hidden_states) 173 | hidden_states = ( 174 | hidden_states.reshape(batch, height, weight, inner_dim) 175 | .permute(0, 3, 1, 2) 176 | .contiguous() 177 | ) 178 | 179 | output = hidden_states + residual 180 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 181 | 182 | return output 183 | 184 | 185 | class TemporalTransformerBlock(nn.Module): 186 | def __init__( 187 | self, 188 | dim, 189 | num_attention_heads, 190 | attention_head_dim, 191 | attention_block_types=( 192 | "Temporal_Self", 193 | "Temporal_Self", 194 | ), 195 | dropout=0.0, 196 | norm_num_groups=32, 197 | cross_attention_dim=768, 198 | activation_fn="geglu", 199 | attention_bias=False, 200 | upcast_attention=False, 201 | cross_frame_attention_mode=None, 202 | temporal_position_encoding=False, 203 | temporal_position_encoding_max_len=24, 204 | ): 205 | super().__init__() 206 | 207 | attention_blocks = [] 208 | norms = [] 209 | 210 | for block_name in attention_block_types: 211 | attention_blocks.append( 212 | VersatileAttention( 213 | attention_mode=block_name.split("_")[0], 214 | cross_attention_dim=cross_attention_dim 215 | if block_name.endswith("_Cross") 216 | else None, 217 | query_dim=dim, 218 | heads=num_attention_heads, 219 | dim_head=attention_head_dim, 220 | dropout=dropout, 221 | bias=attention_bias, 222 | upcast_attention=upcast_attention, 223 | cross_frame_attention_mode=cross_frame_attention_mode, 224 | temporal_position_encoding=temporal_position_encoding, 225 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 226 | ) 227 | ) 228 | norms.append(nn.LayerNorm(dim)) 229 | 230 | self.attention_blocks = nn.ModuleList(attention_blocks) 231 | self.norms = nn.ModuleList(norms) 232 | 233 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) 234 | self.ff_norm = nn.LayerNorm(dim) 235 | 236 | def forward( 237 | self, 238 | hidden_states, 239 | encoder_hidden_states=None, 240 | attention_mask=None, 241 | video_length=None, 242 | ): 243 | for attention_block, norm in zip(self.attention_blocks, self.norms): 244 | norm_hidden_states = norm(hidden_states) 245 | hidden_states = ( 246 | attention_block( 247 | norm_hidden_states, 248 | encoder_hidden_states=encoder_hidden_states 249 | if attention_block.is_cross_attention 250 | else None, 251 | video_length=video_length, 252 | ) 253 | + hidden_states 254 | ) 255 | 256 | 
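# Position-wise feed-forward with a residual connection, applied after the temporal self-attention
# blocks above.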
hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states 257 | 258 | output = hidden_states 259 | return output 260 | 261 | 262 | class PositionalEncoding(nn.Module): 263 | def __init__(self, d_model, dropout=0.0, max_len=24): 264 | super().__init__() 265 | self.dropout = nn.Dropout(p=dropout) 266 | position = torch.arange(max_len).unsqueeze(1) 267 | div_term = torch.exp( 268 | torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model) 269 | ) 270 | pe = torch.zeros(1, max_len, d_model) 271 | pe[0, :, 0::2] = torch.sin(position * div_term) 272 | pe[0, :, 1::2] = torch.cos(position * div_term) 273 | self.register_buffer("pe", pe) 274 | 275 | def forward(self, x): 276 | x = x + self.pe[:, : x.size(1)] 277 | return self.dropout(x) 278 | 279 | 280 | class VersatileAttention(Attention): 281 | def __init__( 282 | self, 283 | attention_mode=None, 284 | cross_frame_attention_mode=None, 285 | temporal_position_encoding=False, 286 | temporal_position_encoding_max_len=24, 287 | *args, 288 | **kwargs, 289 | ): 290 | super().__init__(*args, **kwargs) 291 | assert attention_mode == "Temporal" 292 | 293 | self.attention_mode = attention_mode 294 | self.is_cross_attention = kwargs["cross_attention_dim"] is not None 295 | 296 | self.pos_encoder = ( 297 | PositionalEncoding( 298 | kwargs["query_dim"], 299 | dropout=0.0, 300 | max_len=temporal_position_encoding_max_len, 301 | ) 302 | if (temporal_position_encoding and attention_mode == "Temporal") 303 | else None 304 | ) 305 | 306 | def extra_repr(self): 307 | return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}" 308 | 309 | def set_use_memory_efficient_attention_xformers( 310 | self, 311 | use_memory_efficient_attention_xformers: bool, 312 | attention_op: Optional[Callable] = None, 313 | ): 314 | if use_memory_efficient_attention_xformers: 315 | if not is_xformers_available(): 316 | raise ModuleNotFoundError( 317 | ( 318 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install" 319 | " xformers" 320 | ), 321 | name="xformers", 322 | ) 323 | elif not torch.cuda.is_available(): 324 | raise ValueError( 325 | "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" 326 | " only available for GPU " 327 | ) 328 | else: 329 | try: 330 | # Make sure we can run the memory efficient attention 331 | _ = xformers.ops.memory_efficient_attention( 332 | torch.randn((1, 2, 40), device="cuda"), 333 | torch.randn((1, 2, 40), device="cuda"), 334 | torch.randn((1, 2, 40), device="cuda"), 335 | ) 336 | except Exception as e: 337 | raise e 338 | 339 | # XFormersAttnProcessor corrupts video generation and work with Pytorch 1.13. 340 | # Pytorch 2.0.1 AttnProcessor works the same as XFormersAttnProcessor in Pytorch 1.13. 341 | # You don't need XFormersAttnProcessor here. 
342 | # processor = XFormersAttnProcessor( 343 | # attention_op=attention_op, 344 | # ) 345 | processor = AttnProcessor() 346 | else: 347 | processor = AttnProcessor() 348 | 349 | self.set_processor(processor) 350 | 351 | def forward( 352 | self, 353 | hidden_states, 354 | encoder_hidden_states=None, 355 | attention_mask=None, 356 | video_length=None, 357 | **cross_attention_kwargs, 358 | ): 359 | if self.attention_mode == "Temporal": 360 | d = hidden_states.shape[1] # d means HxW 361 | hidden_states = rearrange( 362 | hidden_states, "(b f) d c -> (b d) f c", f=video_length 363 | ) 364 | 365 | if self.pos_encoder is not None: 366 | hidden_states = self.pos_encoder(hidden_states) 367 | 368 | encoder_hidden_states = ( 369 | repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) 370 | if encoder_hidden_states is not None 371 | else encoder_hidden_states 372 | ) 373 | 374 | else: 375 | raise NotImplementedError 376 | 377 | hidden_states = self.processor( 378 | self, 379 | hidden_states, 380 | encoder_hidden_states=encoder_hidden_states, 381 | attention_mask=attention_mask, 382 | **cross_attention_kwargs, 383 | ) 384 | 385 | if self.attention_mode == "Temporal": 386 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) 387 | 388 | return hidden_states 389 | -------------------------------------------------------------------------------- /src/models/mutual_self_attention.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/models/mutual_self_attention.py 2 | from typing import Any, Dict, Optional 3 | 4 | import torch 5 | from einops import rearrange 6 | 7 | from .attention import TemporalBasicTransformerBlock 8 | 9 | from .attention import BasicTransformerBlock 10 | 11 | 12 | def torch_dfs(model: torch.nn.Module): 13 | result = [model] 14 | for child in model.children(): 15 | result += torch_dfs(child) 16 | return result 17 | 18 | 19 | class ReferenceAttentionControl: 20 | def __init__( 21 | self, 22 | unet, 23 | mode="write", 24 | do_classifier_free_guidance=False, 25 | attention_auto_machine_weight=float("inf"), 26 | gn_auto_machine_weight=1.0, 27 | style_fidelity=1.0, 28 | reference_attn=True, 29 | reference_adain=False, 30 | fusion_blocks="midup", 31 | batch_size=1, 32 | ) -> None: 33 | # 10. 
Modify self attention and group norm 34 | self.unet = unet 35 | assert mode in ["read", "write"] 36 | assert fusion_blocks in ["midup", "full"] 37 | self.reference_attn = reference_attn 38 | self.reference_adain = reference_adain 39 | self.fusion_blocks = fusion_blocks 40 | self.register_reference_hooks( 41 | mode, 42 | do_classifier_free_guidance, 43 | attention_auto_machine_weight, 44 | gn_auto_machine_weight, 45 | style_fidelity, 46 | reference_attn, 47 | reference_adain, 48 | fusion_blocks, 49 | batch_size=batch_size, 50 | ) 51 | 52 | def register_reference_hooks( 53 | self, 54 | mode, 55 | do_classifier_free_guidance, 56 | attention_auto_machine_weight, 57 | gn_auto_machine_weight, 58 | style_fidelity, 59 | reference_attn, 60 | reference_adain, 61 | dtype=torch.float16, 62 | batch_size=1, 63 | num_images_per_prompt=1, 64 | device=torch.device("cpu"), 65 | fusion_blocks="midup", 66 | ): 67 | MODE = mode 68 | do_classifier_free_guidance = do_classifier_free_guidance 69 | attention_auto_machine_weight = attention_auto_machine_weight 70 | gn_auto_machine_weight = gn_auto_machine_weight 71 | style_fidelity = style_fidelity 72 | reference_attn = reference_attn 73 | reference_adain = reference_adain 74 | fusion_blocks = fusion_blocks 75 | num_images_per_prompt = num_images_per_prompt 76 | dtype = dtype 77 | if do_classifier_free_guidance: 78 | uc_mask = ( 79 | torch.Tensor( 80 | [1] * batch_size * num_images_per_prompt * 16 81 | + [0] * batch_size * num_images_per_prompt * 16 82 | ) 83 | .to(device) 84 | .bool() 85 | ) 86 | else: 87 | uc_mask = ( 88 | torch.Tensor([0] * batch_size * num_images_per_prompt * 2) 89 | .to(device) 90 | .bool() 91 | ) 92 | 93 | def hacked_basic_transformer_inner_forward( 94 | self, 95 | hidden_states: torch.FloatTensor, 96 | attention_mask: Optional[torch.FloatTensor] = None, 97 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 98 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 99 | timestep: Optional[torch.LongTensor] = None, 100 | cross_attention_kwargs: Dict[str, Any] = None, 101 | class_labels: Optional[torch.LongTensor] = None, 102 | video_length=None, 103 | ): 104 | if self.use_ada_layer_norm: # False 105 | norm_hidden_states = self.norm1(hidden_states, timestep) 106 | elif self.use_ada_layer_norm_zero: 107 | ( 108 | norm_hidden_states, 109 | gate_msa, 110 | shift_mlp, 111 | scale_mlp, 112 | gate_mlp, 113 | ) = self.norm1( 114 | hidden_states, 115 | timestep, 116 | class_labels, 117 | hidden_dtype=hidden_states.dtype, 118 | ) 119 | else: 120 | norm_hidden_states = self.norm1(hidden_states) 121 | 122 | # 1. 
Self-Attention 123 | # self.only_cross_attention = False 124 | cross_attention_kwargs = ( 125 | cross_attention_kwargs if cross_attention_kwargs is not None else {} 126 | ) 127 | if self.only_cross_attention: 128 | attn_output = self.attn1( 129 | norm_hidden_states, 130 | encoder_hidden_states=encoder_hidden_states 131 | if self.only_cross_attention 132 | else None, 133 | attention_mask=attention_mask, 134 | **cross_attention_kwargs, 135 | ) 136 | else: 137 | if MODE == "write": 138 | self.bank.append(norm_hidden_states.clone()) 139 | attn_output = self.attn1( 140 | norm_hidden_states, 141 | encoder_hidden_states=encoder_hidden_states 142 | if self.only_cross_attention 143 | else None, 144 | attention_mask=attention_mask, 145 | **cross_attention_kwargs, 146 | ) 147 | if MODE == "read": 148 | bank_fea = [ 149 | rearrange( 150 | d.unsqueeze(1).repeat(1, video_length, 1, 1), 151 | "b t l c -> (b t) l c", 152 | ) 153 | for d in self.bank 154 | ] 155 | modify_norm_hidden_states = torch.cat( 156 | [norm_hidden_states] + bank_fea, dim=1 157 | ) 158 | hidden_states_uc = ( 159 | self.attn1( 160 | norm_hidden_states, 161 | encoder_hidden_states=modify_norm_hidden_states, 162 | attention_mask=attention_mask, 163 | ) 164 | + hidden_states 165 | ) 166 | if do_classifier_free_guidance: 167 | hidden_states_c = hidden_states_uc.clone() 168 | _uc_mask = uc_mask.clone() 169 | if hidden_states.shape[0] != _uc_mask.shape[0]: 170 | _uc_mask = ( 171 | torch.Tensor( 172 | [1] * (hidden_states.shape[0] // 2) 173 | + [0] * (hidden_states.shape[0] // 2) 174 | ) 175 | .to(device) 176 | .bool() 177 | ) 178 | hidden_states_c[_uc_mask] = ( 179 | self.attn1( 180 | norm_hidden_states[_uc_mask], 181 | encoder_hidden_states=norm_hidden_states[_uc_mask], 182 | attention_mask=attention_mask, 183 | ) 184 | + hidden_states[_uc_mask] 185 | ) 186 | hidden_states = hidden_states_c.clone() 187 | else: 188 | hidden_states = hidden_states_uc 189 | 190 | # self.bank.clear() 191 | if self.attn2 is not None: 192 | # Cross-Attention 193 | norm_hidden_states = ( 194 | self.norm2(hidden_states, timestep) 195 | if self.use_ada_layer_norm 196 | else self.norm2(hidden_states) 197 | ) 198 | hidden_states = ( 199 | self.attn2( 200 | norm_hidden_states, 201 | encoder_hidden_states=encoder_hidden_states, 202 | attention_mask=attention_mask, 203 | ) 204 | + hidden_states 205 | ) 206 | 207 | # Feed-forward 208 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states 209 | 210 | # Temporal-Attention 211 | if self.unet_use_temporal_attention: 212 | d = hidden_states.shape[1] 213 | hidden_states = rearrange( 214 | hidden_states, "(b f) d c -> (b d) f c", f=video_length 215 | ) 216 | norm_hidden_states = ( 217 | self.norm_temp(hidden_states, timestep) 218 | if self.use_ada_layer_norm 219 | else self.norm_temp(hidden_states) 220 | ) 221 | hidden_states = ( 222 | self.attn_temp(norm_hidden_states) + hidden_states 223 | ) 224 | hidden_states = rearrange( 225 | hidden_states, "(b d) f c -> (b f) d c", d=d 226 | ) 227 | 228 | return hidden_states 229 | 230 | if self.use_ada_layer_norm_zero: 231 | attn_output = gate_msa.unsqueeze(1) * attn_output 232 | hidden_states = attn_output + hidden_states 233 | 234 | if self.attn2 is not None: 235 | norm_hidden_states = ( 236 | self.norm2(hidden_states, timestep) 237 | if self.use_ada_layer_norm 238 | else self.norm2(hidden_states) 239 | ) 240 | 241 | # 2. 
Cross-Attention 242 | attn_output = self.attn2( 243 | norm_hidden_states, 244 | encoder_hidden_states=encoder_hidden_states, 245 | attention_mask=encoder_attention_mask, 246 | **cross_attention_kwargs, 247 | ) 248 | hidden_states = attn_output + hidden_states 249 | 250 | # 3. Feed-forward 251 | norm_hidden_states = self.norm3(hidden_states) 252 | 253 | if self.use_ada_layer_norm_zero: 254 | norm_hidden_states = ( 255 | norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 256 | ) 257 | 258 | ff_output = self.ff(norm_hidden_states) 259 | 260 | if self.use_ada_layer_norm_zero: 261 | ff_output = gate_mlp.unsqueeze(1) * ff_output 262 | 263 | hidden_states = ff_output + hidden_states 264 | 265 | return hidden_states 266 | 267 | if self.reference_attn: 268 | if self.fusion_blocks == "midup": 269 | attn_modules = [ 270 | module 271 | for module in ( 272 | torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks) 273 | ) 274 | if isinstance(module, BasicTransformerBlock) 275 | or isinstance(module, TemporalBasicTransformerBlock) 276 | ] 277 | elif self.fusion_blocks == "full": 278 | attn_modules = [ 279 | module 280 | for module in torch_dfs(self.unet) 281 | if isinstance(module, BasicTransformerBlock) 282 | or isinstance(module, TemporalBasicTransformerBlock) 283 | ] 284 | attn_modules = sorted( 285 | attn_modules, key=lambda x: -x.norm1.normalized_shape[0] 286 | ) 287 | 288 | for i, module in enumerate(attn_modules): 289 | module._original_inner_forward = module.forward 290 | if isinstance(module, BasicTransformerBlock): 291 | module.forward = hacked_basic_transformer_inner_forward.__get__( 292 | module, BasicTransformerBlock 293 | ) 294 | if isinstance(module, TemporalBasicTransformerBlock): 295 | module.forward = hacked_basic_transformer_inner_forward.__get__( 296 | module, TemporalBasicTransformerBlock 297 | ) 298 | 299 | module.bank = [] 300 | module.attn_weight = float(i) / float(len(attn_modules)) 301 | 302 | def update(self, writer, dtype=torch.float16): 303 | if self.reference_attn: 304 | if self.fusion_blocks == "midup": 305 | reader_attn_modules = [ 306 | module 307 | for module in ( 308 | torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks) 309 | ) 310 | if isinstance(module, TemporalBasicTransformerBlock) 311 | ] 312 | writer_attn_modules = [ 313 | module 314 | for module in ( 315 | torch_dfs(writer.unet.mid_block) 316 | + torch_dfs(writer.unet.up_blocks) 317 | ) 318 | if isinstance(module, BasicTransformerBlock) 319 | ] 320 | elif self.fusion_blocks == "full": 321 | reader_attn_modules = [ 322 | module 323 | for module in torch_dfs(self.unet) 324 | if isinstance(module, TemporalBasicTransformerBlock) 325 | ] 326 | writer_attn_modules = [ 327 | module 328 | for module in torch_dfs(writer.unet) 329 | if isinstance(module, BasicTransformerBlock) 330 | ] 331 | reader_attn_modules = sorted( 332 | reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0] 333 | ) 334 | writer_attn_modules = sorted( 335 | writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0] 336 | ) 337 | for r, w in zip(reader_attn_modules, writer_attn_modules): 338 | r.bank = [v.clone().to(dtype) for v in w.bank] 339 | # w.bank.clear() 340 | 341 | def clear(self): 342 | if self.reference_attn: 343 | if self.fusion_blocks == "midup": 344 | reader_attn_modules = [ 345 | module 346 | for module in ( 347 | torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks) 348 | ) 349 | if isinstance(module, BasicTransformerBlock) 350 | or isinstance(module, 
TemporalBasicTransformerBlock) 351 | ] 352 | elif self.fusion_blocks == "full": 353 | reader_attn_modules = [ 354 | module 355 | for module in torch_dfs(self.unet) 356 | if isinstance(module, BasicTransformerBlock) 357 | or isinstance(module, TemporalBasicTransformerBlock) 358 | ] 359 | reader_attn_modules = sorted( 360 | reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0] 361 | ) 362 | for r in reader_attn_modules: 363 | r.bank.clear() 364 | -------------------------------------------------------------------------------- /src/models/pose_guider.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.init as init 5 | from einops import rearrange 6 | import numpy as np 7 | from diffusers.models.modeling_utils import ModelMixin 8 | 9 | from typing import Any, Dict, Optional 10 | from .attention import BasicTransformerBlock 11 | 12 | 13 | class PoseGuider(ModelMixin): 14 | def __init__(self, noise_latent_channels=320, use_ca=True): 15 | super(PoseGuider, self).__init__() 16 | 17 | self.use_ca = use_ca 18 | 19 | self.conv_layers = nn.Sequential( 20 | nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1), 21 | nn.BatchNorm2d(3), 22 | nn.ReLU(), 23 | nn.Conv2d(in_channels=3, out_channels=16, kernel_size=4, stride=2, padding=1), 24 | nn.BatchNorm2d(16), 25 | nn.ReLU(), 26 | 27 | nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1), 28 | nn.BatchNorm2d(16), 29 | nn.ReLU(), 30 | nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=2, padding=1), 31 | nn.BatchNorm2d(32), 32 | nn.ReLU(), 33 | 34 | nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1), 35 | nn.BatchNorm2d(32), 36 | nn.ReLU(), 37 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1), 38 | nn.BatchNorm2d(64), 39 | nn.ReLU(), 40 | 41 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1), 42 | nn.BatchNorm2d(64), 43 | nn.ReLU(), 44 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1), 45 | nn.BatchNorm2d(128), 46 | nn.ReLU() 47 | ) 48 | 49 | # Final projection layer 50 | self.final_proj = nn.Conv2d(in_channels=128, out_channels=noise_latent_channels, kernel_size=1) 51 | 52 | self.conv_layers_1 = nn.Sequential( 53 | nn.Conv2d(in_channels=noise_latent_channels, out_channels=noise_latent_channels, kernel_size=3, padding=1), 54 | nn.BatchNorm2d(noise_latent_channels), 55 | nn.ReLU(), 56 | nn.Conv2d(in_channels=noise_latent_channels, out_channels=noise_latent_channels, kernel_size=3, stride=2, padding=1), 57 | nn.BatchNorm2d(noise_latent_channels), 58 | nn.ReLU(), 59 | ) 60 | 61 | self.conv_layers_2 = nn.Sequential( 62 | nn.Conv2d(in_channels=noise_latent_channels, out_channels=noise_latent_channels, kernel_size=3, padding=1), 63 | nn.BatchNorm2d(noise_latent_channels), 64 | nn.ReLU(), 65 | nn.Conv2d(in_channels=noise_latent_channels, out_channels=noise_latent_channels*2, kernel_size=3, stride=2, padding=1), 66 | nn.BatchNorm2d(noise_latent_channels*2), 67 | nn.ReLU(), 68 | ) 69 | 70 | self.conv_layers_3 = nn.Sequential( 71 | nn.Conv2d(in_channels=noise_latent_channels*2, out_channels=noise_latent_channels*2, kernel_size=3, padding=1), 72 | nn.BatchNorm2d(noise_latent_channels*2), 73 | nn.ReLU(), 74 | nn.Conv2d(in_channels=noise_latent_channels*2, out_channels=noise_latent_channels*4, kernel_size=3, stride=2, padding=1), 75 | nn.BatchNorm2d(noise_latent_channels*4), 76 | nn.ReLU(), 77 | ) 78 | 79 | 
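# conv_layers_4 refines the deepest features without further downsampling; together with the stages
# above it produces the multi-scale feature list that forward() returns.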
self.conv_layers_4 = nn.Sequential( 80 | nn.Conv2d(in_channels=noise_latent_channels*4, out_channels=noise_latent_channels*4, kernel_size=3, padding=1), 81 | nn.BatchNorm2d(noise_latent_channels*4), 82 | nn.ReLU(), 83 | ) 84 | 85 | if self.use_ca: 86 | self.cross_attn1 = Transformer2DModel(in_channels=noise_latent_channels) 87 | self.cross_attn2 = Transformer2DModel(in_channels=noise_latent_channels*2) 88 | self.cross_attn3 = Transformer2DModel(in_channels=noise_latent_channels*4) 89 | self.cross_attn4 = Transformer2DModel(in_channels=noise_latent_channels*4) 90 | 91 | # Initialize layers 92 | self._initialize_weights() 93 | 94 | self.scale = nn.Parameter(torch.ones(1) * 2) 95 | 96 | # def _initialize_weights(self): 97 | # # Initialize weights with Gaussian distribution and zero out the final layer 98 | # for m in self.conv_layers: 99 | # if isinstance(m, nn.Conv2d): 100 | # init.normal_(m.weight, mean=0.0, std=0.02) 101 | # if m.bias is not None: 102 | # init.zeros_(m.bias) 103 | 104 | # init.zeros_(self.final_proj.weight) 105 | # if self.final_proj.bias is not None: 106 | # init.zeros_(self.final_proj.bias) 107 | 108 | def _initialize_weights(self): 109 | # Initialize weights with He initialization and zero out the biases 110 | conv_blocks = [self.conv_layers, self.conv_layers_1, self.conv_layers_2, self.conv_layers_3, self.conv_layers_4] 111 | for block_item in conv_blocks: 112 | for m in block_item: 113 | if isinstance(m, nn.Conv2d): 114 | n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels 115 | init.normal_(m.weight, mean=0.0, std=np.sqrt(2. / n)) 116 | if m.bias is not None: 117 | init.zeros_(m.bias) 118 | 119 | # For the final projection layer, initialize weights to zero (or you may choose to use He initialization here as well) 120 | init.zeros_(self.final_proj.weight) 121 | if self.final_proj.bias is not None: 122 | init.zeros_(self.final_proj.bias) 123 | 124 | def forward(self, x, ref_x): 125 | fea = [] 126 | b = x.shape[0] 127 | 128 | x = rearrange(x, "b c f h w -> (b f) c h w") 129 | x = self.conv_layers(x) 130 | x = self.final_proj(x) 131 | x = x * self.scale 132 | # x = rearrange(x, "(b f) c h w -> b c f h w", b=b) 133 | fea.append(rearrange(x, "(b f) c h w -> b c f h w", b=b)) 134 | 135 | x = self.conv_layers_1(x) 136 | if self.use_ca: 137 | ref_x = self.conv_layers(ref_x) 138 | ref_x = self.final_proj(ref_x) 139 | ref_x = ref_x * self.scale 140 | ref_x = self.conv_layers_1(ref_x) 141 | x = self.cross_attn1(x, ref_x) 142 | fea.append(rearrange(x, "(b f) c h w -> b c f h w", b=b)) 143 | 144 | x = self.conv_layers_2(x) 145 | if self.use_ca: 146 | ref_x = self.conv_layers_2(ref_x) 147 | x = self.cross_attn2(x, ref_x) 148 | fea.append(rearrange(x, "(b f) c h w -> b c f h w", b=b)) 149 | 150 | x = self.conv_layers_3(x) 151 | if self.use_ca: 152 | ref_x = self.conv_layers_3(ref_x) 153 | x = self.cross_attn3(x, ref_x) 154 | fea.append(rearrange(x, "(b f) c h w -> b c f h w", b=b)) 155 | 156 | x = self.conv_layers_4(x) 157 | if self.use_ca: 158 | ref_x = self.conv_layers_4(ref_x) 159 | x = self.cross_attn4(x, ref_x) 160 | fea.append(rearrange(x, "(b f) c h w -> b c f h w", b=b)) 161 | 162 | return fea 163 | 164 | @classmethod 165 | def from_pretrained(cls,pretrained_model_path): 166 | if not os.path.exists(pretrained_model_path): 167 | print(f"There is no model file in {pretrained_model_path}") 168 | print(f"loaded PoseGuider's pretrained weights from {pretrained_model_path} ...") 169 | 170 | state_dict = torch.load(pretrained_model_path, map_location="cpu") 171 | model = 
PoseGuider(noise_latent_channels=320) 172 | 173 | m, u = model.load_state_dict(state_dict, strict=True) 174 | # print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") 175 | params = [p.numel() for n, p in model.named_parameters()] 176 | print(f"### PoseGuider's Parameters: {sum(params) / 1e6} M") 177 | 178 | return model 179 | 180 | 181 | class Transformer2DModel(ModelMixin): 182 | _supports_gradient_checkpointing = True 183 | def __init__( 184 | self, 185 | num_attention_heads: int = 16, 186 | attention_head_dim: int = 88, 187 | in_channels: Optional[int] = None, 188 | num_layers: int = 1, 189 | dropout: float = 0.0, 190 | norm_num_groups: int = 32, 191 | cross_attention_dim: Optional[int] = None, 192 | attention_bias: bool = False, 193 | activation_fn: str = "geglu", 194 | num_embeds_ada_norm: Optional[int] = None, 195 | use_linear_projection: bool = False, 196 | only_cross_attention: bool = False, 197 | double_self_attention: bool = False, 198 | upcast_attention: bool = False, 199 | norm_type: str = "layer_norm", 200 | norm_elementwise_affine: bool = True, 201 | norm_eps: float = 1e-5, 202 | attention_type: str = "default", 203 | ): 204 | super().__init__() 205 | self.use_linear_projection = use_linear_projection 206 | self.num_attention_heads = num_attention_heads 207 | self.attention_head_dim = attention_head_dim 208 | inner_dim = num_attention_heads * attention_head_dim 209 | 210 | self.in_channels = in_channels 211 | 212 | self.norm = torch.nn.GroupNorm( 213 | num_groups=norm_num_groups, 214 | num_channels=in_channels, 215 | eps=1e-6, 216 | affine=True, 217 | ) 218 | if use_linear_projection: 219 | self.proj_in = nn.Linear(in_channels, inner_dim) 220 | else: 221 | self.proj_in = nn.Conv2d( 222 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0 223 | ) 224 | 225 | # 3. 
Define transformers blocks 226 | self.transformer_blocks = nn.ModuleList( 227 | [ 228 | BasicTransformerBlock( 229 | inner_dim, 230 | num_attention_heads, 231 | attention_head_dim, 232 | dropout=dropout, 233 | cross_attention_dim=cross_attention_dim, 234 | activation_fn=activation_fn, 235 | num_embeds_ada_norm=num_embeds_ada_norm, 236 | attention_bias=attention_bias, 237 | only_cross_attention=only_cross_attention, 238 | double_self_attention=double_self_attention, 239 | upcast_attention=upcast_attention, 240 | norm_type=norm_type, 241 | norm_elementwise_affine=norm_elementwise_affine, 242 | norm_eps=norm_eps, 243 | attention_type=attention_type, 244 | ) 245 | for d in range(num_layers) 246 | ] 247 | ) 248 | 249 | if use_linear_projection: 250 | self.proj_out = nn.Linear(inner_dim, in_channels) 251 | else: 252 | self.proj_out = nn.Conv2d( 253 | inner_dim, in_channels, kernel_size=1, stride=1, padding=0 254 | ) 255 | 256 | self.gradient_checkpointing = False 257 | 258 | def _set_gradient_checkpointing(self, module, value=False): 259 | if hasattr(module, "gradient_checkpointing"): 260 | module.gradient_checkpointing = value 261 | 262 | def forward( 263 | self, 264 | hidden_states: torch.Tensor, 265 | encoder_hidden_states: Optional[torch.Tensor] = None, 266 | timestep: Optional[torch.LongTensor] = None, 267 | ): 268 | batch, _, height, width = hidden_states.shape 269 | residual = hidden_states 270 | 271 | hidden_states = self.norm(hidden_states) 272 | if not self.use_linear_projection: 273 | hidden_states = self.proj_in(hidden_states) 274 | inner_dim = hidden_states.shape[1] 275 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( 276 | batch, height * width, inner_dim 277 | ) 278 | else: 279 | inner_dim = hidden_states.shape[1] 280 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( 281 | batch, height * width, inner_dim 282 | ) 283 | hidden_states = self.proj_in(hidden_states) 284 | 285 | for block in self.transformer_blocks: 286 | hidden_states = block( 287 | hidden_states, 288 | encoder_hidden_states=encoder_hidden_states, 289 | timestep=timestep, 290 | ) 291 | 292 | if not self.use_linear_projection: 293 | hidden_states = ( 294 | hidden_states.reshape(batch, height, width, inner_dim) 295 | .permute(0, 3, 1, 2) 296 | .contiguous() 297 | ) 298 | hidden_states = self.proj_out(hidden_states) 299 | else: 300 | hidden_states = self.proj_out(hidden_states) 301 | hidden_states = ( 302 | hidden_states.reshape(batch, height, width, inner_dim) 303 | .permute(0, 3, 1, 2) 304 | .contiguous() 305 | ) 306 | 307 | output = hidden_states + residual 308 | return output 309 | 310 | 311 | if __name__ == '__main__': 312 | model = PoseGuider(noise_latent_channels=320).to(device="cuda") 313 | 314 | input_data = torch.randn(1,3,1,512,512).to(device="cuda") 315 | input_data1 = torch.randn(1,3,512,512).to(device="cuda") 316 | 317 | output = model(input_data, input_data1) 318 | for item in output: 319 | print(item.shape) 320 | 321 | # tf_model = Transformer2DModel( 322 | # in_channels=320 323 | # ).to('cuda') 324 | 325 | # input_data = torch.randn(4,320,32,32).to(device="cuda") 326 | # # input_emb = torch.randn(4,1,768).to(device="cuda") 327 | # input_emb = torch.randn(4,320,32,32).to(device="cuda") 328 | # o1 = tf_model(input_data, input_emb) 329 | # print(o1.shape) 330 | -------------------------------------------------------------------------------- /src/models/resnet.py: -------------------------------------------------------------------------------- 1 | # Adapted from 
https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from einops import rearrange 7 | from typing import Dict, Optional 8 | 9 | 10 | class InflatedConv3d(nn.Conv2d): 11 | def forward(self, x): 12 | video_length = x.shape[2] 13 | 14 | x = rearrange(x, "b c f h w -> (b f) c h w") 15 | x = super().forward(x) 16 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 17 | 18 | return x 19 | 20 | 21 | class InflatedGroupNorm(nn.GroupNorm): 22 | def forward(self, x): 23 | video_length = x.shape[2] 24 | 25 | x = rearrange(x, "b c f h w -> (b f) c h w") 26 | x = super().forward(x) 27 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 28 | 29 | return x 30 | 31 | 32 | class Upsample3D(nn.Module): 33 | def __init__( 34 | self, 35 | channels, 36 | use_conv=False, 37 | use_conv_transpose=False, 38 | out_channels=None, 39 | name="conv", 40 | ): 41 | super().__init__() 42 | self.channels = channels 43 | self.out_channels = out_channels or channels 44 | self.use_conv = use_conv 45 | self.use_conv_transpose = use_conv_transpose 46 | self.name = name 47 | 48 | conv = None 49 | if use_conv_transpose: 50 | raise NotImplementedError 51 | elif use_conv: 52 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) 53 | 54 | def forward(self, hidden_states, output_size=None): 55 | assert hidden_states.shape[1] == self.channels 56 | 57 | if self.use_conv_transpose: 58 | raise NotImplementedError 59 | 60 | # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 61 | dtype = hidden_states.dtype 62 | if dtype == torch.bfloat16: 63 | hidden_states = hidden_states.to(torch.float32) 64 | 65 | # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 66 | if hidden_states.shape[0] >= 64: 67 | hidden_states = hidden_states.contiguous() 68 | 69 | # if `output_size` is passed we force the interpolation output 70 | # size and do not make use of `scale_factor=2` 71 | if output_size is None: 72 | hidden_states = F.interpolate( 73 | hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest" 74 | ) 75 | else: 76 | hidden_states = F.interpolate( 77 | hidden_states, size=output_size, mode="nearest" 78 | ) 79 | 80 | # If the input is bfloat16, we cast back to bfloat16 81 | if dtype == torch.bfloat16: 82 | hidden_states = hidden_states.to(dtype) 83 | 84 | # if self.use_conv: 85 | # if self.name == "conv": 86 | # hidden_states = self.conv(hidden_states) 87 | # else: 88 | # hidden_states = self.Conv2d_0(hidden_states) 89 | hidden_states = self.conv(hidden_states) 90 | 91 | return hidden_states 92 | 93 | 94 | class Downsample3D(nn.Module): 95 | def __init__( 96 | self, channels, use_conv=False, out_channels=None, padding=1, name="conv" 97 | ): 98 | super().__init__() 99 | self.channels = channels 100 | self.out_channels = out_channels or channels 101 | self.use_conv = use_conv 102 | self.padding = padding 103 | stride = 2 104 | self.name = name 105 | 106 | if use_conv: 107 | self.conv = InflatedConv3d( 108 | self.channels, self.out_channels, 3, stride=stride, padding=padding 109 | ) 110 | else: 111 | raise NotImplementedError 112 | 113 | def forward(self, hidden_states): 114 | assert hidden_states.shape[1] == self.channels 115 | if self.use_conv and self.padding == 0: 116 | raise NotImplementedError 117 | 118 | assert hidden_states.shape[1] == self.channels 119 | hidden_states = self.conv(hidden_states) 120 | 121 | return hidden_states 122 | 123 | 124 | class ResnetBlock3D(nn.Module): 125 | def __init__( 126 | self, 127 | *, 128 | in_channels, 129 | out_channels=None, 130 | conv_shortcut=False, 131 | dropout=0.0, 132 | temb_channels=512, 133 | groups=32, 134 | groups_out=None, 135 | pre_norm=True, 136 | eps=1e-6, 137 | non_linearity="swish", 138 | time_embedding_norm="default", 139 | output_scale_factor=1.0, 140 | use_in_shortcut=None, 141 | use_inflated_groupnorm=None, 142 | ): 143 | super().__init__() 144 | self.pre_norm = pre_norm 145 | self.pre_norm = True 146 | self.in_channels = in_channels 147 | out_channels = in_channels if out_channels is None else out_channels 148 | self.out_channels = out_channels 149 | self.use_conv_shortcut = conv_shortcut 150 | self.time_embedding_norm = time_embedding_norm 151 | self.output_scale_factor = output_scale_factor 152 | 153 | if groups_out is None: 154 | groups_out = groups 155 | 156 | assert use_inflated_groupnorm != None 157 | if use_inflated_groupnorm: 158 | self.norm1 = InflatedGroupNorm( 159 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True 160 | ) 161 | else: 162 | self.norm1 = torch.nn.GroupNorm( 163 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True 164 | ) 165 | 166 | self.conv1 = InflatedConv3d( 167 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 168 | ) 169 | 170 | if temb_channels is not None: 171 | if self.time_embedding_norm == "default": 172 | time_emb_proj_out_channels = out_channels 173 | elif self.time_embedding_norm == "scale_shift": 174 | time_emb_proj_out_channels = out_channels * 2 175 | else: 176 | raise ValueError( 177 | f"unknown time_embedding_norm : {self.time_embedding_norm} " 178 | ) 179 | 180 | self.time_emb_proj = torch.nn.Linear( 181 | temb_channels, 
time_emb_proj_out_channels 182 | ) 183 | else: 184 | self.time_emb_proj = None 185 | 186 | if use_inflated_groupnorm: 187 | self.norm2 = InflatedGroupNorm( 188 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True 189 | ) 190 | else: 191 | self.norm2 = torch.nn.GroupNorm( 192 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True 193 | ) 194 | self.dropout = torch.nn.Dropout(dropout) 195 | self.conv2 = InflatedConv3d( 196 | out_channels, out_channels, kernel_size=3, stride=1, padding=1 197 | ) 198 | 199 | if non_linearity == "swish": 200 | self.nonlinearity = lambda x: F.silu(x) 201 | elif non_linearity == "mish": 202 | self.nonlinearity = Mish() 203 | elif non_linearity == "silu": 204 | self.nonlinearity = nn.SiLU() 205 | 206 | self.use_in_shortcut = ( 207 | self.in_channels != self.out_channels 208 | if use_in_shortcut is None 209 | else use_in_shortcut 210 | ) 211 | 212 | self.conv_shortcut = None 213 | if self.use_in_shortcut: 214 | self.conv_shortcut = InflatedConv3d( 215 | in_channels, out_channels, kernel_size=1, stride=1, padding=0 216 | ) 217 | 218 | def forward(self, input_tensor, temb): 219 | hidden_states = input_tensor 220 | 221 | hidden_states = self.norm1(hidden_states) 222 | hidden_states = self.nonlinearity(hidden_states) 223 | 224 | hidden_states = self.conv1(hidden_states) 225 | 226 | if temb is not None: 227 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] 228 | 229 | if temb is not None and self.time_embedding_norm == "default": 230 | hidden_states = hidden_states + temb 231 | 232 | hidden_states = self.norm2(hidden_states) 233 | 234 | if temb is not None and self.time_embedding_norm == "scale_shift": 235 | scale, shift = torch.chunk(temb, 2, dim=1) 236 | hidden_states = hidden_states * (1 + scale) + shift 237 | 238 | hidden_states = self.nonlinearity(hidden_states) 239 | 240 | hidden_states = self.dropout(hidden_states) 241 | hidden_states = self.conv2(hidden_states) 242 | 243 | if self.conv_shortcut is not None: 244 | input_tensor = self.conv_shortcut(input_tensor) 245 | 246 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 247 | 248 | return output_tensor 249 | 250 | class Mish(torch.nn.Module): 251 | def forward(self, hidden_states): 252 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) 253 | -------------------------------------------------------------------------------- /src/models/transformer_2d.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformer_2d.py 2 | from dataclasses import dataclass 3 | from typing import Any, Dict, Optional 4 | 5 | import torch 6 | from diffusers.configuration_utils import ConfigMixin, register_to_config 7 | from diffusers.models.embeddings import PixArtAlphaTextProjection 8 | from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear 9 | from diffusers.models.modeling_utils import ModelMixin 10 | from diffusers.models.normalization import AdaLayerNormSingle 11 | from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version 12 | from torch import nn 13 | 14 | from .attention import BasicTransformerBlock 15 | 16 | 17 | @dataclass 18 | class Transformer2DModelOutput(BaseOutput): 19 | """ 20 | The output of [`Transformer2DModel`]. 
21 | 22 | Args: 23 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): 24 | The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability 25 | distributions for the unnoised latent pixels. 26 | """ 27 | 28 | sample: torch.FloatTensor 29 | ref_feature: torch.FloatTensor 30 | 31 | 32 | class Transformer2DModel(ModelMixin, ConfigMixin): 33 | """ 34 | A 2D Transformer model for image-like data. 35 | 36 | Parameters: 37 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. 38 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. 39 | in_channels (`int`, *optional*): 40 | The number of channels in the input and output (specify if the input is **continuous**). 41 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. 42 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 43 | cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. 44 | sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). 45 | This is fixed during training since it is used to learn a number of position embeddings. 46 | num_vector_embeds (`int`, *optional*): 47 | The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). 48 | Includes the class for the masked latent pixel. 49 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. 50 | num_embeds_ada_norm ( `int`, *optional*): 51 | The number of diffusion steps used during training. Pass if at least one of the norm_layers is 52 | `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are 53 | added to the hidden states. 54 | 55 | During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. 56 | attention_bias (`bool`, *optional*): 57 | Configure if the `TransformerBlocks` attention should contain a bias parameter. 
58 | """ 59 | 60 | _supports_gradient_checkpointing = True 61 | 62 | @register_to_config 63 | def __init__( 64 | self, 65 | num_attention_heads: int = 16, 66 | attention_head_dim: int = 88, 67 | in_channels: Optional[int] = None, 68 | out_channels: Optional[int] = None, 69 | num_layers: int = 1, 70 | dropout: float = 0.0, 71 | norm_num_groups: int = 32, 72 | cross_attention_dim: Optional[int] = None, 73 | attention_bias: bool = False, 74 | sample_size: Optional[int] = None, 75 | num_vector_embeds: Optional[int] = None, 76 | patch_size: Optional[int] = None, 77 | activation_fn: str = "geglu", 78 | num_embeds_ada_norm: Optional[int] = None, 79 | use_linear_projection: bool = False, 80 | only_cross_attention: bool = False, 81 | double_self_attention: bool = False, 82 | upcast_attention: bool = False, 83 | norm_type: str = "layer_norm", 84 | norm_elementwise_affine: bool = True, 85 | norm_eps: float = 1e-5, 86 | attention_type: str = "default", 87 | caption_channels: int = None, 88 | ): 89 | super().__init__() 90 | self.use_linear_projection = use_linear_projection 91 | self.num_attention_heads = num_attention_heads 92 | self.attention_head_dim = attention_head_dim 93 | inner_dim = num_attention_heads * attention_head_dim 94 | 95 | conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv 96 | linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear 97 | 98 | # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` 99 | # Define whether input is continuous or discrete depending on configuration 100 | self.is_input_continuous = (in_channels is not None) and (patch_size is None) 101 | self.is_input_vectorized = num_vector_embeds is not None 102 | self.is_input_patches = in_channels is not None and patch_size is not None 103 | 104 | if norm_type == "layer_norm" and num_embeds_ada_norm is not None: 105 | deprecation_message = ( 106 | f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" 107 | " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." 108 | " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" 109 | " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" 110 | " would be very nice if you could open a Pull request for the `transformer/config.json` file" 111 | ) 112 | deprecate( 113 | "norm_type!=num_embeds_ada_norm", 114 | "1.0.0", 115 | deprecation_message, 116 | standard_warn=False, 117 | ) 118 | norm_type = "ada_norm" 119 | 120 | if self.is_input_continuous and self.is_input_vectorized: 121 | raise ValueError( 122 | f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" 123 | " sure that either `in_channels` or `num_vector_embeds` is None." 124 | ) 125 | elif self.is_input_vectorized and self.is_input_patches: 126 | raise ValueError( 127 | f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" 128 | " sure that either `num_vector_embeds` or `num_patches` is None." 129 | ) 130 | elif ( 131 | not self.is_input_continuous 132 | and not self.is_input_vectorized 133 | and not self.is_input_patches 134 | ): 135 | raise ValueError( 136 | f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" 137 | f" {patch_size}. 
Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." 138 | ) 139 | 140 | # 2. Define input layers 141 | self.in_channels = in_channels 142 | 143 | self.norm = torch.nn.GroupNorm( 144 | num_groups=norm_num_groups, 145 | num_channels=in_channels, 146 | eps=1e-6, 147 | affine=True, 148 | ) 149 | if use_linear_projection: 150 | self.proj_in = linear_cls(in_channels, inner_dim) 151 | else: 152 | self.proj_in = conv_cls( 153 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0 154 | ) 155 | 156 | # 3. Define transformers blocks 157 | self.transformer_blocks = nn.ModuleList( 158 | [ 159 | BasicTransformerBlock( 160 | inner_dim, 161 | num_attention_heads, 162 | attention_head_dim, 163 | dropout=dropout, 164 | cross_attention_dim=cross_attention_dim, 165 | activation_fn=activation_fn, 166 | num_embeds_ada_norm=num_embeds_ada_norm, 167 | attention_bias=attention_bias, 168 | only_cross_attention=only_cross_attention, 169 | double_self_attention=double_self_attention, 170 | upcast_attention=upcast_attention, 171 | norm_type=norm_type, 172 | norm_elementwise_affine=norm_elementwise_affine, 173 | norm_eps=norm_eps, 174 | attention_type=attention_type, 175 | ) 176 | for d in range(num_layers) 177 | ] 178 | ) 179 | 180 | # 4. Define output layers 181 | self.out_channels = in_channels if out_channels is None else out_channels 182 | # TODO: should use out_channels for continuous projections 183 | if use_linear_projection: 184 | self.proj_out = linear_cls(inner_dim, in_channels) 185 | else: 186 | self.proj_out = conv_cls( 187 | inner_dim, in_channels, kernel_size=1, stride=1, padding=0 188 | ) 189 | 190 | # 5. PixArt-Alpha blocks. 191 | self.adaln_single = None 192 | self.use_additional_conditions = False 193 | if norm_type == "ada_norm_single": 194 | self.use_additional_conditions = self.config.sample_size == 128 195 | # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use 196 | # additional conditions until we find better name 197 | self.adaln_single = AdaLayerNormSingle( 198 | inner_dim, use_additional_conditions=self.use_additional_conditions 199 | ) 200 | 201 | self.caption_projection = None 202 | if caption_channels is not None: 203 | self.caption_projection = PixArtAlphaTextProjection( 204 | in_features=caption_channels, hidden_size=inner_dim 205 | ) 206 | 207 | self.gradient_checkpointing = False 208 | 209 | def _set_gradient_checkpointing(self, module, value=False): 210 | if hasattr(module, "gradient_checkpointing"): 211 | module.gradient_checkpointing = value 212 | 213 | def forward( 214 | self, 215 | hidden_states: torch.Tensor, 216 | encoder_hidden_states: Optional[torch.Tensor] = None, 217 | timestep: Optional[torch.LongTensor] = None, 218 | added_cond_kwargs: Dict[str, torch.Tensor] = None, 219 | class_labels: Optional[torch.LongTensor] = None, 220 | cross_attention_kwargs: Dict[str, Any] = None, 221 | attention_mask: Optional[torch.Tensor] = None, 222 | encoder_attention_mask: Optional[torch.Tensor] = None, 223 | return_dict: bool = True, 224 | ): 225 | """ 226 | The [`Transformer2DModel`] forward method. 227 | 228 | Args: 229 | hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): 230 | Input `hidden_states`. 231 | encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): 232 | Conditional embeddings for cross attention layer. 
If not given, cross-attention defaults to 233 | self-attention. 234 | timestep ( `torch.LongTensor`, *optional*): 235 | Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. 236 | class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): 237 | Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in 238 | `AdaLayerZeroNorm`. 239 | cross_attention_kwargs ( `Dict[str, Any]`, *optional*): 240 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 241 | `self.processor` in 242 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 243 | attention_mask ( `torch.Tensor`, *optional*): 244 | An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask 245 | is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large 246 | negative values to the attention scores corresponding to "discard" tokens. 247 | encoder_attention_mask ( `torch.Tensor`, *optional*): 248 | Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: 249 | 250 | * Mask `(batch, sequence_length)` True = keep, False = discard. 251 | * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. 252 | 253 | If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format 254 | above. This bias will be added to the cross-attention scores. 255 | return_dict (`bool`, *optional*, defaults to `True`): 256 | Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain 257 | tuple. 258 | 259 | Returns: 260 | If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a 261 | `tuple` where the first element is the sample tensor. 262 | """ 263 | # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. 264 | # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. 265 | # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. 266 | # expects mask of shape: 267 | # [batch, key_tokens] 268 | # adds singleton query_tokens dimension: 269 | # [batch, 1, key_tokens] 270 | # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: 271 | # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) 272 | # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) 273 | if attention_mask is not None and attention_mask.ndim == 2: 274 | # assume that mask is expressed as: 275 | # (1 = keep, 0 = discard) 276 | # convert mask into a bias that can be added to attention scores: 277 | # (keep = +0, discard = -10000.0) 278 | attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 279 | attention_mask = attention_mask.unsqueeze(1) 280 | 281 | # convert encoder_attention_mask to a bias the same way we do for attention_mask 282 | if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: 283 | encoder_attention_mask = ( 284 | 1 - encoder_attention_mask.to(hidden_states.dtype) 285 | ) * -10000.0 286 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1) 287 | 288 | # Retrieve lora scale. 
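        # This scale is only forwarded to the LoRACompatible proj_in / proj_out layers
        # further down when the PEFT backend is not active; with USE_PEFT_BACKEND the
        # projections are plain nn.Linear / nn.Conv2d and the scale argument is not used.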
289 | lora_scale = ( 290 | cross_attention_kwargs.get("scale", 1.0) 291 | if cross_attention_kwargs is not None 292 | else 1.0 293 | ) 294 | 295 | # 1. Input 296 | batch, _, height, width = hidden_states.shape 297 | residual = hidden_states 298 | 299 | hidden_states = self.norm(hidden_states) 300 | if not self.use_linear_projection: 301 | hidden_states = ( 302 | self.proj_in(hidden_states, scale=lora_scale) 303 | if not USE_PEFT_BACKEND 304 | else self.proj_in(hidden_states) 305 | ) 306 | inner_dim = hidden_states.shape[1] 307 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( 308 | batch, height * width, inner_dim 309 | ) 310 | else: 311 | inner_dim = hidden_states.shape[1] 312 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( 313 | batch, height * width, inner_dim 314 | ) 315 | hidden_states = ( 316 | self.proj_in(hidden_states, scale=lora_scale) 317 | if not USE_PEFT_BACKEND 318 | else self.proj_in(hidden_states) 319 | ) 320 | 321 | # 2. Blocks 322 | if self.caption_projection is not None: 323 | batch_size = hidden_states.shape[0] 324 | encoder_hidden_states = self.caption_projection(encoder_hidden_states) 325 | encoder_hidden_states = encoder_hidden_states.view( 326 | batch_size, -1, hidden_states.shape[-1] 327 | ) 328 | 329 | ref_feature = hidden_states.reshape(batch, height, width, inner_dim) 330 | for block in self.transformer_blocks: 331 | if self.training and self.gradient_checkpointing: 332 | 333 | def create_custom_forward(module, return_dict=None): 334 | def custom_forward(*inputs): 335 | if return_dict is not None: 336 | return module(*inputs, return_dict=return_dict) 337 | else: 338 | return module(*inputs) 339 | 340 | return custom_forward 341 | 342 | ckpt_kwargs: Dict[str, Any] = ( 343 | {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 344 | ) 345 | hidden_states = torch.utils.checkpoint.checkpoint( 346 | create_custom_forward(block), 347 | hidden_states, 348 | attention_mask, 349 | encoder_hidden_states, 350 | encoder_attention_mask, 351 | timestep, 352 | cross_attention_kwargs, 353 | class_labels, 354 | **ckpt_kwargs, 355 | ) 356 | else: 357 | hidden_states = block( 358 | hidden_states, 359 | attention_mask=attention_mask, 360 | encoder_hidden_states=encoder_hidden_states, 361 | encoder_attention_mask=encoder_attention_mask, 362 | timestep=timestep, 363 | cross_attention_kwargs=cross_attention_kwargs, 364 | class_labels=class_labels, 365 | ) 366 | 367 | # 3. 
Output 368 | if self.is_input_continuous: 369 | if not self.use_linear_projection: 370 | hidden_states = ( 371 | hidden_states.reshape(batch, height, width, inner_dim) 372 | .permute(0, 3, 1, 2) 373 | .contiguous() 374 | ) 375 | hidden_states = ( 376 | self.proj_out(hidden_states, scale=lora_scale) 377 | if not USE_PEFT_BACKEND 378 | else self.proj_out(hidden_states) 379 | ) 380 | else: 381 | hidden_states = ( 382 | self.proj_out(hidden_states, scale=lora_scale) 383 | if not USE_PEFT_BACKEND 384 | else self.proj_out(hidden_states) 385 | ) 386 | hidden_states = ( 387 | hidden_states.reshape(batch, height, width, inner_dim) 388 | .permute(0, 3, 1, 2) 389 | .contiguous() 390 | ) 391 | 392 | output = hidden_states + residual 393 | if not return_dict: 394 | return (output, ref_feature) 395 | 396 | return Transformer2DModelOutput(sample=output, ref_feature=ref_feature) 397 | -------------------------------------------------------------------------------- /src/models/transformer_3d.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional, Dict 3 | 4 | import torch 5 | from diffusers.configuration_utils import ConfigMixin, register_to_config 6 | from diffusers.models import ModelMixin 7 | from diffusers.utils import BaseOutput 8 | from diffusers.utils.import_utils import is_xformers_available 9 | from einops import rearrange, repeat 10 | from torch import nn 11 | 12 | from .attention import TemporalBasicTransformerBlock, ResidualTemporalBasicTransformerBlock 13 | 14 | 15 | @dataclass 16 | class Transformer3DModelOutput(BaseOutput): 17 | sample: torch.FloatTensor 18 | 19 | 20 | if is_xformers_available(): 21 | import xformers 22 | import xformers.ops 23 | else: 24 | xformers = None 25 | 26 | 27 | class Transformer3DModel(ModelMixin, ConfigMixin): 28 | _supports_gradient_checkpointing = True 29 | 30 | @register_to_config 31 | def __init__( 32 | self, 33 | num_attention_heads: int = 16, 34 | attention_head_dim: int = 88, 35 | in_channels: Optional[int] = None, 36 | num_layers: int = 1, 37 | dropout: float = 0.0, 38 | norm_num_groups: int = 32, 39 | cross_attention_dim: Optional[int] = None, 40 | attention_bias: bool = False, 41 | activation_fn: str = "geglu", 42 | num_embeds_ada_norm: Optional[int] = None, 43 | use_linear_projection: bool = False, 44 | only_cross_attention: bool = False, 45 | upcast_attention: bool = False, 46 | unet_use_cross_frame_attention=None, 47 | unet_use_temporal_attention=None, 48 | ): 49 | super().__init__() 50 | self.use_linear_projection = use_linear_projection 51 | self.num_attention_heads = num_attention_heads 52 | self.attention_head_dim = attention_head_dim 53 | inner_dim = num_attention_heads * attention_head_dim 54 | 55 | # Define input layers 56 | self.in_channels = in_channels 57 | 58 | self.norm = torch.nn.GroupNorm( 59 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True 60 | ) 61 | if use_linear_projection: 62 | self.proj_in = nn.Linear(in_channels, inner_dim) 63 | else: 64 | self.proj_in = nn.Conv2d( 65 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0 66 | ) 67 | 68 | # Define transformers blocks 69 | self.transformer_blocks = nn.ModuleList( 70 | [ 71 | TemporalBasicTransformerBlock( 72 | inner_dim, 73 | num_attention_heads, 74 | attention_head_dim, 75 | dropout=dropout, 76 | cross_attention_dim=cross_attention_dim, 77 | activation_fn=activation_fn, 78 | num_embeds_ada_norm=num_embeds_ada_norm, 79 | 
attention_bias=attention_bias, 80 | only_cross_attention=only_cross_attention, 81 | upcast_attention=upcast_attention, 82 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 83 | unet_use_temporal_attention=unet_use_temporal_attention, 84 | ) 85 | for d in range(num_layers) 86 | ] 87 | ) 88 | 89 | # 4. Define output layers 90 | if use_linear_projection: 91 | self.proj_out = nn.Linear(in_channels, inner_dim) 92 | else: 93 | self.proj_out = nn.Conv2d( 94 | inner_dim, in_channels, kernel_size=1, stride=1, padding=0 95 | ) 96 | 97 | self.gradient_checkpointing = False 98 | 99 | def _set_gradient_checkpointing(self, module, value=False): 100 | if hasattr(module, "gradient_checkpointing"): 101 | module.gradient_checkpointing = value 102 | 103 | def forward( 104 | self, 105 | hidden_states, 106 | encoder_hidden_states=None, 107 | timestep=None, 108 | return_dict: bool = True, 109 | ): 110 | # Input 111 | assert ( 112 | hidden_states.dim() == 5 113 | ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 114 | video_length = hidden_states.shape[2] 115 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 116 | if encoder_hidden_states.shape[0] != hidden_states.shape[0]: 117 | encoder_hidden_states = repeat( 118 | encoder_hidden_states, "b n c -> (b f) n c", f=video_length 119 | ) 120 | 121 | batch, channel, height, weight = hidden_states.shape 122 | residual = hidden_states 123 | 124 | hidden_states = self.norm(hidden_states) 125 | if not self.use_linear_projection: 126 | hidden_states = self.proj_in(hidden_states) 127 | inner_dim = hidden_states.shape[1] 128 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( 129 | batch, height * weight, inner_dim 130 | ) 131 | else: 132 | inner_dim = hidden_states.shape[1] 133 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( 134 | batch, height * weight, inner_dim 135 | ) 136 | hidden_states = self.proj_in(hidden_states) 137 | 138 | # Blocks 139 | for i, block in enumerate(self.transformer_blocks): 140 | hidden_states = block( 141 | hidden_states, 142 | encoder_hidden_states=encoder_hidden_states, 143 | timestep=timestep, 144 | video_length=video_length, 145 | ) 146 | 147 | # Output 148 | if not self.use_linear_projection: 149 | hidden_states = ( 150 | hidden_states.reshape(batch, height, weight, inner_dim) 151 | .permute(0, 3, 1, 2) 152 | .contiguous() 153 | ) 154 | hidden_states = self.proj_out(hidden_states) 155 | else: 156 | hidden_states = self.proj_out(hidden_states) 157 | hidden_states = ( 158 | hidden_states.reshape(batch, height, weight, inner_dim) 159 | .permute(0, 3, 1, 2) 160 | .contiguous() 161 | ) 162 | 163 | output = hidden_states + residual 164 | 165 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 166 | if not return_dict: 167 | return (output,) 168 | 169 | return Transformer3DModelOutput(sample=output) 170 | -------------------------------------------------------------------------------- /src/pipelines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/pipelines/__init__.py -------------------------------------------------------------------------------- /src/pipelines/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/pipelines/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /src/pipelines/__pycache__/context.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/pipelines/__pycache__/context.cpython-310.pyc -------------------------------------------------------------------------------- /src/pipelines/__pycache__/pipeline_pose2vid_long.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/pipelines/__pycache__/pipeline_pose2vid_long.cpython-310.pyc -------------------------------------------------------------------------------- /src/pipelines/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/pipelines/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /src/pipelines/context.py: -------------------------------------------------------------------------------- 1 | # TODO: Adapted from cli 2 | from typing import Callable, List, Optional 3 | 4 | import numpy as np 5 | 6 | 7 | def ordered_halving(val): 8 | bin_str = f"{val:064b}" 9 | bin_flip = bin_str[::-1] 10 | as_int = int(bin_flip, 2) 11 | 12 | return as_int / (1 << 64) 13 | 14 | 15 | def uniform( 16 | step: int = ..., 17 | num_steps: Optional[int] = None, 18 | num_frames: int = ..., 19 | context_size: Optional[int] = None, 20 | context_stride: int = 3, 21 | context_overlap: int = 4, 22 | closed_loop: bool = True, 23 | ): 24 | if num_frames <= context_size: 25 | yield list(range(num_frames)) 26 | return 27 | 28 | context_stride = min( 29 | context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1 30 | ) 31 | 32 | for context_step in 1 << np.arange(context_stride): 33 | pad = int(round(num_frames * ordered_halving(step))) 34 | for j in range( 35 | int(ordered_halving(step) * context_step) + pad, 36 | num_frames + pad + (0 if closed_loop else -context_overlap), 37 | (context_size * context_step - context_overlap), 38 | ): 39 | yield [ 40 | e % num_frames 41 | for e in range(j, j + context_size * context_step, context_step) 42 | ] 43 | 44 | 45 | def get_context_scheduler(name: str) -> Callable: 46 | if name == "uniform": 47 | return uniform 48 | else: 49 | raise ValueError(f"Unknown context_overlap policy {name}") 50 | 51 | 52 | def get_total_steps( 53 | scheduler, 54 | timesteps: List[int], 55 | num_steps: Optional[int] = None, 56 | num_frames: int = ..., 57 | context_size: Optional[int] = None, 58 | context_stride: int = 3, 59 | context_overlap: int = 4, 60 | closed_loop: bool = True, 61 | ): 62 | return sum( 63 | len( 64 | list( 65 | scheduler( 66 | i, 67 | num_steps, 68 | num_frames, 69 | context_size, 70 | context_stride, 71 | context_overlap, 72 | ) 73 | ) 74 | ) 75 | for i in range(len(timesteps)) 76 | ) 77 | -------------------------------------------------------------------------------- /src/pipelines/pipeline_pose2vid.py: -------------------------------------------------------------------------------- 1 | import inspect 
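# Pose2VideoPipeline below follows the usual diffusers pipeline layout: the reference
# image is encoded by CLIP (as cross-attention conditioning) and by the VAE (for the
# reference UNet), the pose frames go through the PoseGuider, and the denoising UNet
# then denoises a (b, c, f, h, w) video latent. A minimal, hedged usage sketch
# (component construction omitted; the numeric argument values are illustrative only):
#
#   pipe = Pose2VideoPipeline(vae, image_encoder, reference_unet,
#                             denoising_unet, pose_guider, scheduler)
#   out = pipe(ref_image, pose_images, ref_pose_image,
#              width=512, height=512, video_length=len(pose_images),
#              num_inference_steps=25, guidance_scale=3.5)
#   video = out.videos          # (b, c, f, h, w) float tensor in [0, 1]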
2 | from dataclasses import dataclass 3 | from typing import Callable, List, Optional, Union 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | import torchvision.transforms as transforms 9 | from diffusers import DiffusionPipeline 10 | from diffusers.image_processor import VaeImageProcessor 11 | from diffusers.schedulers import (DDIMScheduler, DPMSolverMultistepScheduler, 12 | EulerAncestralDiscreteScheduler, 13 | EulerDiscreteScheduler, LMSDiscreteScheduler, 14 | PNDMScheduler) 15 | from diffusers.utils import BaseOutput, is_accelerate_available 16 | from diffusers.utils.torch_utils import randn_tensor 17 | from einops import rearrange 18 | from tqdm import tqdm 19 | from transformers import CLIPImageProcessor 20 | 21 | from ..models.mutual_self_attention import ReferenceAttentionControl 22 | 23 | 24 | @dataclass 25 | class Pose2VideoPipelineOutput(BaseOutput): 26 | videos: Union[torch.Tensor, np.ndarray] 27 | 28 | 29 | class Pose2VideoPipeline(DiffusionPipeline): 30 | _optional_components = [] 31 | 32 | def __init__( 33 | self, 34 | vae, 35 | image_encoder, 36 | reference_unet, 37 | denoising_unet, 38 | pose_guider, 39 | scheduler: Union[ 40 | DDIMScheduler, 41 | PNDMScheduler, 42 | LMSDiscreteScheduler, 43 | EulerDiscreteScheduler, 44 | EulerAncestralDiscreteScheduler, 45 | DPMSolverMultistepScheduler, 46 | ], 47 | image_proj_model=None, 48 | tokenizer=None, 49 | text_encoder=None, 50 | ): 51 | super().__init__() 52 | 53 | self.register_modules( 54 | vae=vae, 55 | image_encoder=image_encoder, 56 | reference_unet=reference_unet, 57 | denoising_unet=denoising_unet, 58 | pose_guider=pose_guider, 59 | scheduler=scheduler, 60 | image_proj_model=image_proj_model, 61 | tokenizer=tokenizer, 62 | text_encoder=text_encoder, 63 | ) 64 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 65 | self.clip_image_processor = CLIPImageProcessor() 66 | self.ref_image_processor = VaeImageProcessor( 67 | vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True 68 | ) 69 | self.cond_image_processor = VaeImageProcessor( 70 | vae_scale_factor=self.vae_scale_factor, 71 | do_convert_rgb=True, 72 | do_normalize=True, 73 | ) 74 | 75 | def enable_vae_slicing(self): 76 | self.vae.enable_slicing() 77 | 78 | def disable_vae_slicing(self): 79 | self.vae.disable_slicing() 80 | 81 | def enable_sequential_cpu_offload(self, gpu_id=0): 82 | if is_accelerate_available(): 83 | from accelerate import cpu_offload 84 | else: 85 | raise ImportError("Please install accelerate via `pip install accelerate`") 86 | 87 | device = torch.device(f"cuda:{gpu_id}") 88 | 89 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: 90 | if cpu_offloaded_model is not None: 91 | cpu_offload(cpu_offloaded_model, device) 92 | 93 | @property 94 | def _execution_device(self): 95 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): 96 | return self.device 97 | for module in self.unet.modules(): 98 | if ( 99 | hasattr(module, "_hf_hook") 100 | and hasattr(module._hf_hook, "execution_device") 101 | and module._hf_hook.execution_device is not None 102 | ): 103 | return torch.device(module._hf_hook.execution_device) 104 | return self.device 105 | 106 | def decode_latents(self, latents): 107 | video_length = latents.shape[2] 108 | latents = 1 / 0.18215 * latents 109 | latents = rearrange(latents, "b c f h w -> (b f) c h w") 110 | # video = self.vae.decode(latents).sample 111 | video = [] 112 | for frame_idx in tqdm(range(latents.shape[0])): 113 | 
video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample) 114 | video = torch.cat(video) 115 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) 116 | video = (video / 2 + 0.5).clamp(0, 1) 117 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 118 | video = video.cpu().float().numpy() 119 | return video 120 | 121 | def prepare_extra_step_kwargs(self, generator, eta): 122 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 123 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 124 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 125 | # and should be between [0, 1] 126 | 127 | accepts_eta = "eta" in set( 128 | inspect.signature(self.scheduler.step).parameters.keys() 129 | ) 130 | extra_step_kwargs = {} 131 | if accepts_eta: 132 | extra_step_kwargs["eta"] = eta 133 | 134 | # check if the scheduler accepts generator 135 | accepts_generator = "generator" in set( 136 | inspect.signature(self.scheduler.step).parameters.keys() 137 | ) 138 | if accepts_generator: 139 | extra_step_kwargs["generator"] = generator 140 | return extra_step_kwargs 141 | 142 | def prepare_latents( 143 | self, 144 | batch_size, 145 | num_channels_latents, 146 | width, 147 | height, 148 | video_length, 149 | dtype, 150 | device, 151 | generator, 152 | latents=None, 153 | ): 154 | shape = ( 155 | batch_size, 156 | num_channels_latents, 157 | video_length, 158 | height // self.vae_scale_factor, 159 | width // self.vae_scale_factor, 160 | ) 161 | if isinstance(generator, list) and len(generator) != batch_size: 162 | raise ValueError( 163 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 164 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
165 | ) 166 | 167 | if latents is None: 168 | latents = randn_tensor( 169 | shape, generator=generator, device=device, dtype=dtype 170 | ) 171 | else: 172 | latents = latents.to(device) 173 | 174 | # scale the initial noise by the standard deviation required by the scheduler 175 | latents = latents * self.scheduler.init_noise_sigma 176 | return latents 177 | 178 | def _encode_prompt( 179 | self, 180 | prompt, 181 | device, 182 | num_videos_per_prompt, 183 | do_classifier_free_guidance, 184 | negative_prompt, 185 | ): 186 | batch_size = len(prompt) if isinstance(prompt, list) else 1 187 | 188 | text_inputs = self.tokenizer( 189 | prompt, 190 | padding="max_length", 191 | max_length=self.tokenizer.model_max_length, 192 | truncation=True, 193 | return_tensors="pt", 194 | ) 195 | text_input_ids = text_inputs.input_ids 196 | untruncated_ids = self.tokenizer( 197 | prompt, padding="longest", return_tensors="pt" 198 | ).input_ids 199 | 200 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( 201 | text_input_ids, untruncated_ids 202 | ): 203 | removed_text = self.tokenizer.batch_decode( 204 | untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] 205 | ) 206 | 207 | if ( 208 | hasattr(self.text_encoder.config, "use_attention_mask") 209 | and self.text_encoder.config.use_attention_mask 210 | ): 211 | attention_mask = text_inputs.attention_mask.to(device) 212 | else: 213 | attention_mask = None 214 | 215 | text_embeddings = self.text_encoder( 216 | text_input_ids.to(device), 217 | attention_mask=attention_mask, 218 | ) 219 | text_embeddings = text_embeddings[0] 220 | 221 | # duplicate text embeddings for each generation per prompt, using mps friendly method 222 | bs_embed, seq_len, _ = text_embeddings.shape 223 | text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1) 224 | text_embeddings = text_embeddings.view( 225 | bs_embed * num_videos_per_prompt, seq_len, -1 226 | ) 227 | 228 | # get unconditional embeddings for classifier free guidance 229 | if do_classifier_free_guidance: 230 | uncond_tokens: List[str] 231 | if negative_prompt is None: 232 | uncond_tokens = [""] * batch_size 233 | elif type(prompt) is not type(negative_prompt): 234 | raise TypeError( 235 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 236 | f" {type(prompt)}." 237 | ) 238 | elif isinstance(negative_prompt, str): 239 | uncond_tokens = [negative_prompt] 240 | elif batch_size != len(negative_prompt): 241 | raise ValueError( 242 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 243 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 244 | " the batch size of `prompt`." 
245 | ) 246 | else: 247 | uncond_tokens = negative_prompt 248 | 249 | max_length = text_input_ids.shape[-1] 250 | uncond_input = self.tokenizer( 251 | uncond_tokens, 252 | padding="max_length", 253 | max_length=max_length, 254 | truncation=True, 255 | return_tensors="pt", 256 | ) 257 | 258 | if ( 259 | hasattr(self.text_encoder.config, "use_attention_mask") 260 | and self.text_encoder.config.use_attention_mask 261 | ): 262 | attention_mask = uncond_input.attention_mask.to(device) 263 | else: 264 | attention_mask = None 265 | 266 | uncond_embeddings = self.text_encoder( 267 | uncond_input.input_ids.to(device), 268 | attention_mask=attention_mask, 269 | ) 270 | uncond_embeddings = uncond_embeddings[0] 271 | 272 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 273 | seq_len = uncond_embeddings.shape[1] 274 | uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1) 275 | uncond_embeddings = uncond_embeddings.view( 276 | batch_size * num_videos_per_prompt, seq_len, -1 277 | ) 278 | 279 | # For classifier free guidance, we need to do two forward passes. 280 | # Here we concatenate the unconditional and text embeddings into a single batch 281 | # to avoid doing two forward passes 282 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 283 | 284 | return text_embeddings 285 | 286 | @torch.no_grad() 287 | def __call__( 288 | self, 289 | ref_image, 290 | pose_images, 291 | ref_pose_image, 292 | width, 293 | height, 294 | video_length, 295 | num_inference_steps, 296 | guidance_scale, 297 | num_images_per_prompt=1, 298 | eta: float = 0.0, 299 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 300 | output_type: Optional[str] = "tensor", 301 | return_dict: bool = True, 302 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 303 | callback_steps: Optional[int] = 1, 304 | **kwargs, 305 | ): 306 | # Default height and width to unet 307 | height = height or self.unet.config.sample_size * self.vae_scale_factor 308 | width = width or self.unet.config.sample_size * self.vae_scale_factor 309 | 310 | device = self._execution_device 311 | 312 | do_classifier_free_guidance = guidance_scale > 1.0 313 | 314 | # Prepare timesteps 315 | self.scheduler.set_timesteps(num_inference_steps, device=device) 316 | timesteps = self.scheduler.timesteps 317 | 318 | batch_size = 1 319 | 320 | # Prepare clip image embeds 321 | clip_image = self.clip_image_processor.preprocess( 322 | ref_image, return_tensors="pt" 323 | ).pixel_values 324 | clip_image_embeds = self.image_encoder( 325 | clip_image.to(device, dtype=self.image_encoder.dtype) 326 | ).image_embeds 327 | encoder_hidden_states = clip_image_embeds.unsqueeze(1) 328 | 329 | uncond_encoder_hidden_states = torch.zeros_like(encoder_hidden_states) 330 | 331 | if do_classifier_free_guidance: 332 | encoder_hidden_states = torch.cat( 333 | [uncond_encoder_hidden_states, encoder_hidden_states], dim=0 334 | ) 335 | reference_control_writer = ReferenceAttentionControl( 336 | self.reference_unet, 337 | do_classifier_free_guidance=do_classifier_free_guidance, 338 | mode="write", 339 | batch_size=batch_size, 340 | fusion_blocks="full", 341 | ) 342 | reference_control_reader = ReferenceAttentionControl( 343 | self.denoising_unet, 344 | do_classifier_free_guidance=do_classifier_free_guidance, 345 | mode="read", 346 | batch_size=batch_size, 347 | fusion_blocks="full", 348 | ) 349 | 350 | num_channels_latents = self.denoising_unet.in_channels 351 | latents = 
self.prepare_latents( 352 | batch_size * num_images_per_prompt, 353 | num_channels_latents, 354 | width, 355 | height, 356 | video_length, 357 | clip_image_embeds.dtype, 358 | device, 359 | generator, 360 | ) 361 | 362 | # Prepare extra step kwargs. 363 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 364 | 365 | # Prepare ref image latents 366 | ref_image_tensor = self.ref_image_processor.preprocess( 367 | ref_image, height=height, width=width 368 | ) # (bs, c, width, height) 369 | ref_image_tensor = ref_image_tensor.to( 370 | dtype=self.vae.dtype, device=self.vae.device 371 | ) 372 | ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean 373 | ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w) 374 | 375 | # Prepare a list of pose condition images 376 | pose_cond_tensor_list = [] 377 | for pose_image in pose_images: 378 | pose_cond_tensor = self.cond_image_processor.preprocess( 379 | pose_image, height=height, width=width 380 | ).transpose(0, 1) # (c, 1, h, w) 381 | pose_cond_tensor_list.append(pose_cond_tensor) 382 | pose_cond_tensor = torch.cat(pose_cond_tensor_list, dim=1) # (c, t, h, w) 383 | 384 | pose_cond_tensor = pose_cond_tensor.unsqueeze(0) # (1, c, t, h, w) 385 | pose_cond_tensor = pose_cond_tensor.to( 386 | device=device, dtype=self.pose_guider.dtype 387 | ) 388 | 389 | ref_pose_tensor = self.cond_image_processor.preprocess( 390 | ref_pose_image, height=height, width=width 391 | ) 392 | ref_pose_tensor = ref_pose_tensor.to( 393 | device=device, dtype=self.pose_guider.dtype 394 | ) 395 | 396 | pose_fea = self.pose_guider(pose_cond_tensor, ref_pose_tensor) 397 | if do_classifier_free_guidance: 398 | for idxx in range(len(pose_fea)): 399 | pose_fea[idxx] = torch.cat([pose_fea[idxx]] * 2) 400 | 401 | # denoising loop 402 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 403 | with self.progress_bar(total=num_inference_steps) as progress_bar: 404 | for i, t in enumerate(timesteps): 405 | # 1. 
Forward reference image 406 | if i == 0: 407 | self.reference_unet( 408 | ref_image_latents.repeat( 409 | (2 if do_classifier_free_guidance else 1), 1, 1, 1 410 | ), 411 | torch.zeros_like(t), 412 | # t, 413 | encoder_hidden_states=encoder_hidden_states, 414 | return_dict=False, 415 | ) 416 | reference_control_reader.update(reference_control_writer) 417 | 418 | # 3.1 expand the latents if we are doing classifier free guidance 419 | latent_model_input = ( 420 | torch.cat([latents] * 2) if do_classifier_free_guidance else latents 421 | ) 422 | latent_model_input = self.scheduler.scale_model_input( 423 | latent_model_input, t 424 | ) 425 | 426 | noise_pred = self.denoising_unet( 427 | latent_model_input, 428 | t, 429 | encoder_hidden_states=encoder_hidden_states, 430 | pose_cond_fea=pose_fea, 431 | return_dict=False, 432 | )[0] 433 | 434 | # perform guidance 435 | if do_classifier_free_guidance: 436 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 437 | noise_pred = noise_pred_uncond + guidance_scale * ( 438 | noise_pred_text - noise_pred_uncond 439 | ) 440 | 441 | # compute the previous noisy sample x_t -> x_t-1 442 | latents = self.scheduler.step( 443 | noise_pred, t, latents, **extra_step_kwargs, return_dict=False 444 | )[0] 445 | 446 | # call the callback, if provided 447 | if i == len(timesteps) - 1 or ( 448 | (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 449 | ): 450 | progress_bar.update() 451 | if callback is not None and i % callback_steps == 0: 452 | step_idx = i // getattr(self.scheduler, "order", 1) 453 | callback(step_idx, t, latents) 454 | 455 | reference_control_reader.clear() 456 | reference_control_writer.clear() 457 | 458 | # Post-processing 459 | images = self.decode_latents(latents) # (b, c, f, h, w) 460 | 461 | # Convert to tensor 462 | if output_type == "tensor": 463 | images = torch.from_numpy(images) 464 | 465 | if not return_dict: 466 | return images 467 | 468 | return Pose2VideoPipelineOutput(videos=images) 469 | -------------------------------------------------------------------------------- /src/pipelines/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | tensor_interpolation = None 4 | 5 | 6 | def get_tensor_interpolation_method(): 7 | return tensor_interpolation 8 | 9 | 10 | def set_tensor_interpolation_method(is_slerp): 11 | global tensor_interpolation 12 | tensor_interpolation = slerp if is_slerp else linear 13 | 14 | 15 | def linear(v1, v2, t): 16 | return (1.0 - t) * v1 + t * v2 17 | 18 | 19 | def slerp( 20 | v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995 21 | ) -> torch.Tensor: 22 | u0 = v0 / v0.norm() 23 | u1 = v1 / v1.norm() 24 | dot = (u0 * u1).sum() 25 | if dot.abs() > DOT_THRESHOLD: 26 | # logger.info(f'warning: v0 and v1 close to parallel, using linear interpolation instead.') 27 | return (1.0 - t) * v0 + t * v1 28 | omega = dot.acos() 29 | return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin() 30 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/__pycache__/__init__.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/utils/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/audio_util.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/utils/__pycache__/audio_util.cpython-310.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/draw_util.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/utils/__pycache__/draw_util.cpython-310.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/face_landmark.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/utils/__pycache__/face_landmark.cpython-310.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/logger.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/utils/__pycache__/logger.cpython-310.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/mp_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/utils/__pycache__/mp_utils.cpython-310.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/pose_util.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/utils/__pycache__/pose_util.cpython-310.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/util.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/utils/__pycache__/util.cpython-310.pyc -------------------------------------------------------------------------------- /src/utils/audio_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | 4 | import librosa 5 | import numpy as np 6 | from transformers import Wav2Vec2FeatureExtractor 7 | 8 | 9 | class DataProcessor: 10 | def __init__(self, sampling_rate, wav2vec_model_path): 11 | self._processor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_model_path, local_files_only=True) 12 | self._sampling_rate = sampling_rate 13 | 14 | def extract_feature(self, audio_path): 15 | speech_array, sampling_rate = librosa.load(audio_path, sr=self._sampling_rate) 16 | input_value = np.squeeze(self._processor(speech_array, 
sampling_rate=sampling_rate).input_values) 17 | return input_value 18 | 19 | 20 | def prepare_audio_feature(wav_file, fps=30, sampling_rate=16000, wav2vec_model_path=None): 21 | data_preprocessor = DataProcessor(sampling_rate, wav2vec_model_path) 22 | 23 | input_value = data_preprocessor.extract_feature(wav_file) 24 | seq_len = math.ceil(len(input_value)/sampling_rate*fps) 25 | return { 26 | "audio_feature": input_value, 27 | "seq_len": seq_len 28 | } 29 | 30 | 31 | -------------------------------------------------------------------------------- /src/utils/draw_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import mediapipe as mp 3 | import numpy as np 4 | from mediapipe.framework.formats import landmark_pb2 5 | 6 | class FaceMeshVisualizer: 7 | def __init__(self, forehead_edge=False): 8 | self.mp_drawing = mp.solutions.drawing_utils 9 | mp_face_mesh = mp.solutions.face_mesh 10 | self.mp_face_mesh = mp_face_mesh 11 | self.forehead_edge = forehead_edge 12 | 13 | DrawingSpec = mp.solutions.drawing_styles.DrawingSpec 14 | f_thick = 2 15 | f_rad = 1 16 | right_iris_draw = DrawingSpec(color=(10, 200, 250), thickness=f_thick, circle_radius=f_rad) 17 | right_eye_draw = DrawingSpec(color=(10, 200, 180), thickness=f_thick, circle_radius=f_rad) 18 | right_eyebrow_draw = DrawingSpec(color=(10, 220, 180), thickness=f_thick, circle_radius=f_rad) 19 | left_iris_draw = DrawingSpec(color=(250, 200, 10), thickness=f_thick, circle_radius=f_rad) 20 | left_eye_draw = DrawingSpec(color=(180, 200, 10), thickness=f_thick, circle_radius=f_rad) 21 | left_eyebrow_draw = DrawingSpec(color=(180, 220, 10), thickness=f_thick, circle_radius=f_rad) 22 | head_draw = DrawingSpec(color=(10, 200, 10), thickness=f_thick, circle_radius=f_rad) 23 | 24 | mouth_draw_obl = DrawingSpec(color=(10, 180, 20), thickness=f_thick, circle_radius=f_rad) 25 | mouth_draw_obr = DrawingSpec(color=(20, 10, 180), thickness=f_thick, circle_radius=f_rad) 26 | 27 | mouth_draw_ibl = DrawingSpec(color=(100, 100, 30), thickness=f_thick, circle_radius=f_rad) 28 | mouth_draw_ibr = DrawingSpec(color=(100, 150, 50), thickness=f_thick, circle_radius=f_rad) 29 | 30 | mouth_draw_otl = DrawingSpec(color=(20, 80, 100), thickness=f_thick, circle_radius=f_rad) 31 | mouth_draw_otr = DrawingSpec(color=(80, 100, 20), thickness=f_thick, circle_radius=f_rad) 32 | 33 | mouth_draw_itl = DrawingSpec(color=(120, 100, 200), thickness=f_thick, circle_radius=f_rad) 34 | mouth_draw_itr = DrawingSpec(color=(150 ,120, 100), thickness=f_thick, circle_radius=f_rad) 35 | 36 | FACEMESH_LIPS_OUTER_BOTTOM_LEFT = [(61,146),(146,91),(91,181),(181,84),(84,17)] 37 | FACEMESH_LIPS_OUTER_BOTTOM_RIGHT = [(17,314),(314,405),(405,321),(321,375),(375,291)] 38 | 39 | FACEMESH_LIPS_INNER_BOTTOM_LEFT = [(78,95),(95,88),(88,178),(178,87),(87,14)] 40 | FACEMESH_LIPS_INNER_BOTTOM_RIGHT = [(14,317),(317,402),(402,318),(318,324),(324,308)] 41 | 42 | FACEMESH_LIPS_OUTER_TOP_LEFT = [(61,185),(185,40),(40,39),(39,37),(37,0)] 43 | FACEMESH_LIPS_OUTER_TOP_RIGHT = [(0,267),(267,269),(269,270),(270,409),(409,291)] 44 | 45 | FACEMESH_LIPS_INNER_TOP_LEFT = [(78,191),(191,80),(80,81),(81,82),(82,13)] 46 | FACEMESH_LIPS_INNER_TOP_RIGHT = [(13,312),(312,311),(311,310),(310,415),(415,308)] 47 | 48 | FACEMESH_CUSTOM_FACE_OVAL = [(176, 149), (150, 136), (356, 454), (58, 132), (152, 148), (361, 288), (251, 389), (132, 93), (389, 356), (400, 377), (136, 172), (377, 152), (323, 361), (172, 58), (454, 323), (365, 379), (379, 378), (148, 176), (93, 234), 
(397, 365), (149, 150), (288, 397), (234, 127), (378, 400), (127, 162), (162, 21)] 49 | 50 | # mp_face_mesh.FACEMESH_CONTOURS has all the items we care about. 51 | face_connection_spec = {} 52 | if self.forehead_edge: 53 | for edge in mp_face_mesh.FACEMESH_FACE_OVAL: 54 | face_connection_spec[edge] = head_draw 55 | else: 56 | for edge in FACEMESH_CUSTOM_FACE_OVAL: 57 | face_connection_spec[edge] = head_draw 58 | for edge in mp_face_mesh.FACEMESH_LEFT_EYE: 59 | face_connection_spec[edge] = left_eye_draw 60 | for edge in mp_face_mesh.FACEMESH_LEFT_EYEBROW: 61 | face_connection_spec[edge] = left_eyebrow_draw 62 | # for edge in mp_face_mesh.FACEMESH_LEFT_IRIS: 63 | # face_connection_spec[edge] = left_iris_draw 64 | for edge in mp_face_mesh.FACEMESH_RIGHT_EYE: 65 | face_connection_spec[edge] = right_eye_draw 66 | for edge in mp_face_mesh.FACEMESH_RIGHT_EYEBROW: 67 | face_connection_spec[edge] = right_eyebrow_draw 68 | # for edge in mp_face_mesh.FACEMESH_RIGHT_IRIS: 69 | # face_connection_spec[edge] = right_iris_draw 70 | # for edge in mp_face_mesh.FACEMESH_LIPS: 71 | # face_connection_spec[edge] = mouth_draw 72 | 73 | for edge in FACEMESH_LIPS_OUTER_BOTTOM_LEFT: 74 | face_connection_spec[edge] = mouth_draw_obl 75 | for edge in FACEMESH_LIPS_OUTER_BOTTOM_RIGHT: 76 | face_connection_spec[edge] = mouth_draw_obr 77 | for edge in FACEMESH_LIPS_INNER_BOTTOM_LEFT: 78 | face_connection_spec[edge] = mouth_draw_ibl 79 | for edge in FACEMESH_LIPS_INNER_BOTTOM_RIGHT: 80 | face_connection_spec[edge] = mouth_draw_ibr 81 | for edge in FACEMESH_LIPS_OUTER_TOP_LEFT: 82 | face_connection_spec[edge] = mouth_draw_otl 83 | for edge in FACEMESH_LIPS_OUTER_TOP_RIGHT: 84 | face_connection_spec[edge] = mouth_draw_otr 85 | for edge in FACEMESH_LIPS_INNER_TOP_LEFT: 86 | face_connection_spec[edge] = mouth_draw_itl 87 | for edge in FACEMESH_LIPS_INNER_TOP_RIGHT: 88 | face_connection_spec[edge] = mouth_draw_itr 89 | 90 | 91 | iris_landmark_spec = {468: right_iris_draw, 473: left_iris_draw} 92 | 93 | self.face_connection_spec = face_connection_spec 94 | def draw_pupils(self, image, landmark_list, drawing_spec, halfwidth: int = 2): 95 | """We have a custom function to draw the pupils because the mp.draw_landmarks method requires a parameter for all 96 | landmarks. 
Until our PR is merged into mediapipe, we need this separate method.""" 97 | if len(image.shape) != 3: 98 | raise ValueError("Input image must be H,W,C.") 99 | image_rows, image_cols, image_channels = image.shape 100 | if image_channels != 3: # BGR channels 101 | raise ValueError('Input image must contain three channel bgr data.') 102 | for idx, landmark in enumerate(landmark_list.landmark): 103 | if ( 104 | (landmark.HasField('visibility') and landmark.visibility < 0.9) or 105 | (landmark.HasField('presence') and landmark.presence < 0.5) 106 | ): 107 | continue 108 | if landmark.x >= 1.0 or landmark.x < 0 or landmark.y >= 1.0 or landmark.y < 0: 109 | continue 110 | image_x = int(image_cols*landmark.x) 111 | image_y = int(image_rows*landmark.y) 112 | draw_color = None 113 | if isinstance(drawing_spec, Mapping): 114 | if drawing_spec.get(idx) is None: 115 | continue 116 | else: 117 | draw_color = drawing_spec[idx].color 118 | elif isinstance(drawing_spec, DrawingSpec): 119 | draw_color = drawing_spec.color 120 | image[image_y-halfwidth:image_y+halfwidth, image_x-halfwidth:image_x+halfwidth, :] = draw_color 121 | 122 | 123 | 124 | def draw_landmarks(self, image_size, keypoints, normed=False): 125 | ini_size = [512, 512] 126 | image = np.zeros([ini_size[1], ini_size[0], 3], dtype=np.uint8) 127 | new_landmarks = landmark_pb2.NormalizedLandmarkList() 128 | for i in range(keypoints.shape[0]): 129 | landmark = new_landmarks.landmark.add() 130 | if normed: 131 | landmark.x = keypoints[i, 0] 132 | landmark.y = keypoints[i, 1] 133 | else: 134 | landmark.x = keypoints[i, 0] / image_size[0] 135 | landmark.y = keypoints[i, 1] / image_size[1] 136 | landmark.z = 1.0 137 | 138 | self.mp_drawing.draw_landmarks( 139 | image=image, 140 | landmark_list=new_landmarks, 141 | connections=self.face_connection_spec.keys(), 142 | landmark_drawing_spec=None, 143 | connection_drawing_spec=self.face_connection_spec 144 | ) 145 | # draw_pupils(image, face_landmarks, iris_landmark_spec, 2) 146 | image = cv2.resize(image, (image_size[0], image_size[1])) 147 | 148 | return image 149 | 150 | -------------------------------------------------------------------------------- /src/utils/frame_interpolation.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/dajes/frame-interpolation-pytorch 2 | import os 3 | import cv2 4 | import numpy as np 5 | import torch 6 | import bisect 7 | import shutil 8 | import pdb 9 | from tqdm import tqdm 10 | 11 | def init_frame_interpolation_model(): 12 | print("Initializing frame interpolation model") 13 | checkpoint_name = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),"pretrained_model/film_net_fp16.pt") 14 | #checkpoint_name = os.path.join("./pretrained_model/film_net_fp16.pt") 15 | 16 | model = torch.jit.load(checkpoint_name, map_location='cpu') 17 | model.eval() 18 | model = model.half() 19 | model = model.to(device="cuda") 20 | return model 21 | 22 | 23 | def batch_images_interpolation_tool(input_tensor, model, inter_frames=1): 24 | 25 | video_tensor = [] 26 | frame_num = input_tensor.shape[2] # bs, channel, frame, height, width 27 | 28 | for idx in tqdm(range(frame_num-1)): 29 | image1 = input_tensor[:,:,idx] 30 | image2 = input_tensor[:,:,idx+1] 31 | 32 | results = [image1, image2] 33 | 34 | inter_frames = int(inter_frames) 35 | idxes = [0, inter_frames + 1] 36 | remains = list(range(1, inter_frames + 1)) 37 | 38 | splits = torch.linspace(0, 1, inter_frames + 2) 39 | 40 | for 
_ in range(len(remains)): 41 | starts = splits[idxes[:-1]] 42 | ends = splits[idxes[1:]] 43 | distances = ((splits[None, remains] - starts[:, None]) / (ends[:, None] - starts[:, None]) - .5).abs() 44 | matrix = torch.argmin(distances).item() 45 | start_i, step = np.unravel_index(matrix, distances.shape) 46 | end_i = start_i + 1 47 | 48 | x0 = results[start_i] 49 | x1 = results[end_i] 50 | 51 | x0 = x0.half() 52 | x1 = x1.half() 53 | x0 = x0.cuda() 54 | x1 = x1.cuda() 55 | 56 | dt = x0.new_full((1, 1), (splits[remains[step]] - splits[idxes[start_i]])) / (splits[idxes[end_i]] - splits[idxes[start_i]]) 57 | 58 | with torch.no_grad(): 59 | prediction = model(x0, x1, dt) 60 | insert_position = bisect.bisect_left(idxes, remains[step]) 61 | idxes.insert(insert_position, remains[step]) 62 | results.insert(insert_position, prediction.clamp(0, 1).cpu().float()) 63 | del remains[step] 64 | 65 | for sub_idx in range(len(results)-1): 66 | video_tensor.append(results[sub_idx].unsqueeze(2)) 67 | 68 | video_tensor.append(input_tensor[:,:,-1].unsqueeze(2)) 69 | video_tensor = torch.cat(video_tensor, dim=2) 70 | return video_tensor 71 | -------------------------------------------------------------------------------- /src/utils/logger.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import copy 3 | import logging 4 | 5 | class ColoredFormatter(logging.Formatter): 6 | COLORS = { 7 | "DEBUG": "\033[0;36m", # CYAN 8 | "INFO": "\033[0;32m", # GREEN 9 | "WARNING": "\033[0;33m", # YELLOW 10 | "ERROR": "\033[0;31m", # RED 11 | "CRITICAL": "\033[0;37;41m", # WHITE ON RED 12 | "RESET": "\033[0m", # RESET COLOR 13 | } 14 | 15 | def format(self, record): 16 | colored_record = copy.copy(record) 17 | levelname = colored_record.levelname 18 | seq = self.COLORS.get(levelname, self.COLORS["RESET"]) 19 | colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}" 20 | return super().format(colored_record) 21 | 22 | 23 | # Create a new logger 24 | logger = logging.getLogger("AniPortrait") 25 | logger.propagate = False 26 | 27 | # Add handler if we don't have one. 
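# (checking logger.handlers first avoids attaching a second handler and printing
# duplicate log lines when this module is imported again, e.g. on a ComfyUI
# custom-node reload)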
28 | if not logger.handlers: 29 | handler = logging.StreamHandler(sys.stdout) 30 | handler.setFormatter(ColoredFormatter("[%(name)s] - %(levelname)s - %(message)s")) 31 | logger.addHandler(handler) 32 | 33 | # Configure logger 34 | loglevel = logging.INFO 35 | logger.setLevel(loglevel) 36 | -------------------------------------------------------------------------------- /src/utils/mp_models/blaze_face_short_range.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/utils/mp_models/blaze_face_short_range.tflite -------------------------------------------------------------------------------- /src/utils/mp_models/face_landmarker_v2_with_blendshapes.task: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/utils/mp_models/face_landmarker_v2_with_blendshapes.task -------------------------------------------------------------------------------- /src/utils/mp_models/pose_landmarker_heavy.task: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/utils/mp_models/pose_landmarker_heavy.task -------------------------------------------------------------------------------- /src/utils/mp_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import time 5 | from tqdm import tqdm 6 | import multiprocessing 7 | import glob 8 | 9 | import mediapipe as mp 10 | from mediapipe import solutions 11 | from mediapipe.framework.formats import landmark_pb2 12 | from mediapipe.tasks import python 13 | from mediapipe.tasks.python import vision 14 | from . import face_landmark 15 | 16 | CUR_DIR = os.path.dirname(__file__) 17 | 18 | 19 | class LMKExtractor(): 20 | def __init__(self, FPS=25): 21 | # Create an FaceLandmarker object. 
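        # __call__ below takes a BGR frame and returns either None (no single face
        # detected) or a dict with normalized landmarks ("lmks"), 3D mesh vertices
        # ("lmks3d"), the 4x4 facial transformation matrix ("trans_mat"), the mesh
        # triangles ("faces") and blendshape scores with the neutral entry removed
        # ("bs"). A hedged usage sketch, with the image path assumed:
        #
        #   extractor = LMKExtractor()
        #   result = extractor(cv2.imread("face.jpg"))
        #   if result is not None:
        #       trans_mat = result["trans_mat"]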
22 | self.mode = mp.tasks.vision.FaceDetectorOptions.running_mode.IMAGE 23 | base_options = python.BaseOptions(model_asset_path=os.path.join(CUR_DIR, 'mp_models/face_landmarker_v2_with_blendshapes.task')) 24 | base_options.delegate = mp.tasks.BaseOptions.Delegate.CPU 25 | options = vision.FaceLandmarkerOptions(base_options=base_options, 26 | running_mode=self.mode, 27 | output_face_blendshapes=True, 28 | output_facial_transformation_matrixes=True, 29 | num_faces=1) 30 | self.detector = face_landmark.FaceLandmarker.create_from_options(options) 31 | self.last_ts = 0 32 | self.frame_ms = int(1000 / FPS) 33 | 34 | det_base_options = python.BaseOptions(model_asset_path=os.path.join(CUR_DIR, 'mp_models/blaze_face_short_range.tflite')) 35 | det_options = vision.FaceDetectorOptions(base_options=det_base_options) 36 | self.det_detector = vision.FaceDetector.create_from_options(det_options) 37 | 38 | 39 | def __call__(self, img): 40 | frame = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 41 | image = mp.Image(image_format=mp.ImageFormat.SRGB, data=frame) 42 | t0 = time.time() 43 | if self.mode == mp.tasks.vision.FaceDetectorOptions.running_mode.VIDEO: 44 | det_result = self.det_detector.detect(image) 45 | if len(det_result.detections) != 1: 46 | return None 47 | self.last_ts += self.frame_ms 48 | try: 49 | detection_result, mesh3d = self.detector.detect_for_video(image, timestamp_ms=self.last_ts) 50 | except: 51 | return None 52 | elif self.mode == mp.tasks.vision.FaceDetectorOptions.running_mode.IMAGE: 53 | # det_result = self.det_detector.detect(image) 54 | 55 | # if len(det_result.detections) != 1: 56 | # return None 57 | try: 58 | detection_result, mesh3d = self.detector.detect(image) 59 | except: 60 | return None 61 | 62 | 63 | bs_list = detection_result.face_blendshapes 64 | if len(bs_list) == 1: 65 | bs = bs_list[0] 66 | bs_values = [] 67 | for index in range(len(bs)): 68 | bs_values.append(bs[index].score) 69 | bs_values = bs_values[1:] # remove neutral 70 | trans_mat = detection_result.facial_transformation_matrixes[0] 71 | face_landmarks_list = detection_result.face_landmarks 72 | face_landmarks = face_landmarks_list[0] 73 | lmks = [] 74 | for index in range(len(face_landmarks)): 75 | x = face_landmarks[index].x 76 | y = face_landmarks[index].y 77 | z = face_landmarks[index].z 78 | lmks.append([x, y, z]) 79 | lmks = np.array(lmks) 80 | 81 | lmks3d = np.array(mesh3d.vertex_buffer) 82 | lmks3d = lmks3d.reshape(-1, 5)[:, :3] 83 | mp_tris = np.array(mesh3d.index_buffer).reshape(-1, 3) + 1 84 | 85 | return { 86 | "lmks": lmks, 87 | 'lmks3d': lmks3d, 88 | "trans_mat": trans_mat, 89 | 'faces': mp_tris, 90 | "bs": bs_values 91 | } 92 | else: 93 | # print('multiple faces in the image: {}'.format(img_path)) 94 | return None 95 | 96 | -------------------------------------------------------------------------------- /src/utils/pose_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | from scipy.spatial.transform import Rotation as R 5 | 6 | 7 | def create_perspective_matrix(aspect_ratio): 8 | kDegreesToRadians = np.pi / 180. 9 | near = 1 10 | far = 10000 11 | perspective_matrix = np.zeros(16, dtype=np.float32) 12 | 13 | # Standard perspective projection matrix calculations. 14 | f = 1.0 / np.tan(kDegreesToRadians * 63 / 2.) 15 | 16 | denom = 1.0 / (near - far) 17 | perspective_matrix[0] = f / aspect_ratio 18 | perspective_matrix[5] = f 19 | perspective_matrix[10] = (near + far) * denom 20 | perspective_matrix[11] = -1. 
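    # (column-major entries of a perspective projection matrix with a 63 degree
    # vertical field of view; the flat 16-vector is reshaped to 4x4 and transposed
    # by the project_points helpers below)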
/src/utils/pose_util.py:
--------------------------------------------------------------------------------
1 | import math
2 | 
3 | import numpy as np
4 | from scipy.spatial.transform import Rotation as R
5 | 
6 | 
7 | def create_perspective_matrix(aspect_ratio):
8 |     kDegreesToRadians = np.pi / 180.
9 |     near = 1
10 |     far = 10000
11 |     perspective_matrix = np.zeros(16, dtype=np.float32)
12 | 
13 |     # Standard perspective projection matrix calculations.
14 |     f = 1.0 / np.tan(kDegreesToRadians * 63 / 2.)
15 | 
16 |     denom = 1.0 / (near - far)
17 |     perspective_matrix[0] = f / aspect_ratio
18 |     perspective_matrix[5] = f
19 |     perspective_matrix[10] = (near + far) * denom
20 |     perspective_matrix[11] = -1.
21 |     perspective_matrix[14] = 1. * far * near * denom
22 | 
23 |     # The origin is assumed to be in the top-left corner, so an additional
24 |     # flip along the Y-axis is applied to render correctly.
25 | 
26 |     perspective_matrix[5] *= -1.
27 |     return perspective_matrix
28 | 
29 | 
30 | def project_points(points_3d, transformation_matrix, pose_vectors, image_shape):
31 |     P = create_perspective_matrix(image_shape[1] / image_shape[0]).reshape(4, 4).T
32 |     L, N, _ = points_3d.shape
33 |     projected_points = np.zeros((L, N, 2))
34 |     for i in range(L):
35 |         points_3d_frame = points_3d[i]
36 |         ones = np.ones((points_3d_frame.shape[0], 1))
37 |         points_3d_homogeneous = np.hstack([points_3d_frame, ones])
38 |         transformed_points = points_3d_homogeneous @ (transformation_matrix @ euler_and_translation_to_matrix(pose_vectors[i][:3], pose_vectors[i][3:])).T @ P
39 |         projected_points_frame = transformed_points[:, :2] / transformed_points[:, 3, np.newaxis]  # -1 ~ 1
40 |         projected_points_frame[:, 0] = (projected_points_frame[:, 0] + 1) * 0.5 * image_shape[1]
41 |         projected_points_frame[:, 1] = (projected_points_frame[:, 1] + 1) * 0.5 * image_shape[0]
42 |         projected_points[i] = projected_points_frame
43 |     return projected_points
44 | 
45 | 
46 | def project_points_with_trans(points_3d, transformation_matrix, image_shape):
47 |     P = create_perspective_matrix(image_shape[1] / image_shape[0]).reshape(4, 4).T
48 |     L, N, _ = points_3d.shape
49 |     projected_points = np.zeros((L, N, 2))
50 |     for i in range(L):
51 |         points_3d_frame = points_3d[i]
52 |         ones = np.ones((points_3d_frame.shape[0], 1))
53 |         points_3d_homogeneous = np.hstack([points_3d_frame, ones])
54 |         transformed_points = points_3d_homogeneous @ transformation_matrix[i].T @ P
55 |         projected_points_frame = transformed_points[:, :2] / transformed_points[:, 3, np.newaxis]  # -1 ~ 1
56 |         projected_points_frame[:, 0] = (projected_points_frame[:, 0] + 1) * 0.5 * image_shape[1]
57 |         projected_points_frame[:, 1] = (projected_points_frame[:, 1] + 1) * 0.5 * image_shape[0]
58 |         projected_points[i] = projected_points_frame
59 |     return projected_points
60 | 
61 | 
62 | def euler_and_translation_to_matrix(euler_angles, translation_vector):
63 |     rotation = R.from_euler('xyz', euler_angles, degrees=True)
64 |     rotation_matrix = rotation.as_matrix()
65 | 
66 |     matrix = np.eye(4)
67 |     matrix[:3, :3] = rotation_matrix
68 |     matrix[:3, 3] = translation_vector
69 | 
70 |     return matrix
71 | 
72 | 
73 | def matrix_to_euler_and_translation(matrix):
74 |     rotation_matrix = matrix[:3, :3]
75 |     translation_vector = matrix[:3, 3]
76 |     rotation = R.from_matrix(rotation_matrix)
77 |     euler_angles = rotation.as_euler('xyz', degrees=True)
78 |     return euler_angles, translation_vector
79 | 
80 | 
81 | def smooth_pose_seq(pose_seq, window_size=5):
82 |     smoothed_pose_seq = np.zeros_like(pose_seq)
83 | 
84 |     for i in range(len(pose_seq)):
85 |         start = max(0, i - window_size // 2)
86 |         end = min(len(pose_seq), i + window_size // 2 + 1)
87 |         smoothed_pose_seq[i] = np.mean(pose_seq[start:end], axis=0)
88 | 
89 |     return smoothed_pose_seq
90 | 
--------------------------------------------------------------------------------
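pose_util.py converts between 6-DoF pose vectors (Euler angles in degrees plus a translation) and 4x4 rigid transforms, projects mesh points through a fixed 63-degree perspective matrix, and smooths pose sequences with a centered moving average. A quick sanity-check sketch with made-up values (not taken from the repository):

```python
import numpy as np
from src.utils.pose_util import (  # adjust to your package layout
    euler_and_translation_to_matrix,
    matrix_to_euler_and_translation,
    smooth_pose_seq,
)

# Round trip: (Euler angles, translation) -> 4x4 matrix -> (Euler angles, translation).
euler = np.array([10.0, -5.0, 2.0])  # made-up head rotation in degrees
trans = np.array([0.1, 0.0, -0.3])   # made-up translation
mat = euler_and_translation_to_matrix(euler, trans)
euler_back, trans_back = matrix_to_euler_and_translation(mat)
assert np.allclose(euler, euler_back) and np.allclose(trans, trans_back)

# Centered moving average over a noisy (frames x 6) pose sequence.
pose_seq = np.random.randn(50, 6) * 0.1
print(smooth_pose_seq(pose_seq, window_size=5).shape)  # (50, 6)
```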
/src/utils/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 | import os.path as osp
4 | import shutil
5 | import sys
6 | from pathlib import Path
7 | import hashlib
8 | 
9 | from typing import Iterable
10 | import subprocess
11 | import re
12 | 
13 | from .logger import logger
14 | 
15 | import av
16 | import numpy as np
17 | import torch
18 | import torchvision
19 | from einops import rearrange
20 | from PIL import Image
21 | 
22 | def seed_everything(seed):
23 |     import random
24 | 
25 |     import numpy as np
26 | 
27 |     torch.manual_seed(seed)
28 |     torch.cuda.manual_seed_all(seed)
29 |     np.random.seed(seed % (2**32))
30 |     random.seed(seed)
31 | 
32 | 
33 | def import_filename(filename):
34 |     spec = importlib.util.spec_from_file_location("mymodule", filename)
35 |     module = importlib.util.module_from_spec(spec)
36 |     sys.modules[spec.name] = module
37 |     spec.loader.exec_module(module)
38 |     return module
39 | 
40 | 
41 | def delete_additional_ckpt(base_path, num_keep):
42 |     dirs = []
43 |     for d in os.listdir(base_path):
44 |         if d.startswith("checkpoint-"):
45 |             dirs.append(d)
46 |     num_tot = len(dirs)
47 |     if num_tot <= num_keep:
48 |         return
49 |     # ensure ckpts are sorted and delete the earlier ones
50 |     del_dirs = sorted(dirs, key=lambda x: int(x.split("-")[-1]))[: num_tot - num_keep]
51 |     for d in del_dirs:
52 |         path_to_dir = osp.join(base_path, d)
53 |         if osp.exists(path_to_dir):
54 |             shutil.rmtree(path_to_dir)
55 | 
56 | 
57 | def save_videos_from_pil(pil_images, path, fps=8):
58 |     import av
59 | 
60 |     save_fmt = Path(path).suffix
61 |     os.makedirs(os.path.dirname(path), exist_ok=True)
62 |     width, height = pil_images[0].size
63 | 
64 |     if save_fmt == ".mp4":
65 |         codec = "libx264"
66 |         container = av.open(path, "w")
67 |         stream = container.add_stream(codec, rate=fps)
68 | 
69 |         stream.width = width
70 |         stream.height = height
71 | 
72 |         for pil_image in pil_images:
73 |             # pil_image = Image.fromarray(image_arr).convert("RGB")
74 |             av_frame = av.VideoFrame.from_image(pil_image)
75 |             container.mux(stream.encode(av_frame))
76 |         container.mux(stream.encode())
77 |         container.close()
78 | 
79 |     elif save_fmt == ".gif":
80 |         pil_images[0].save(
81 |             fp=path,
82 |             format="GIF",
83 |             append_images=pil_images[1:],
84 |             save_all=True,
85 |             duration=(1 / fps * 1000),
86 |             loop=0,
87 |         )
88 |     else:
89 |         raise ValueError("Unsupported file type. Use .mp4 or .gif.")
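`save_videos_from_pil` picks its writer purely from the output suffix: `.mp4` goes through PyAV with libx264, `.gif` through Pillow's animated-GIF writer, and anything else raises. A minimal sketch of calling it directly — the output paths and solid-grey frames below are illustrative only:

```python
from PIL import Image
from src.utils.util import save_videos_from_pil  # adjust to your package layout

# 24 dummy frames fading from dark to light grey, 256x256 each.
frames = [Image.new("RGB", (256, 256), (10 * i, 10 * i, 10 * i)) for i in range(24)]

save_videos_from_pil(frames, "output/demo.mp4", fps=8)  # PyAV + libx264
save_videos_from_pil(frames, "output/demo.gif", fps=8)  # Pillow GIF writer
```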
90 | 
91 | 
92 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
93 |     videos = rearrange(videos, "b c t h w -> t b c h w")
94 |     height, width = videos.shape[-2:]
95 |     outputs = []
96 | 
97 |     for x in videos:
98 |         x = torchvision.utils.make_grid(x, nrow=n_rows)  # (c h w)
99 |         x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)  # (h w c)
100 |         if rescale:
101 |             x = (x + 1.0) / 2.0  # -1,1 -> 0,1
102 |         x = (x * 255).numpy().astype(np.uint8)
103 |         x = Image.fromarray(x)
104 | 
105 |         outputs.append(x)
106 | 
107 |     os.makedirs(os.path.dirname(path), exist_ok=True)
108 | 
109 |     save_videos_from_pil(outputs, path, fps)
110 | 
111 | 
112 | def read_frames(video_path):
113 |     container = av.open(video_path)
114 | 
115 |     video_stream = next(s for s in container.streams if s.type == "video")
116 |     frames = []
117 |     for packet in container.demux(video_stream):
118 |         for frame in packet.decode():
119 |             image = Image.frombytes(
120 |                 "RGB",
121 |                 (frame.width, frame.height),
122 |                 frame.to_rgb().to_ndarray(),
123 |             )
124 |             frames.append(image)
125 | 
126 |     return frames
127 | 
128 | 
129 | def get_fps(video_path):
130 |     container = av.open(video_path)
131 |     video_stream = next(s for s in container.streams if s.type == "video")
132 |     fps = video_stream.average_rate
133 |     container.close()
134 |     return fps
135 | 
136 | def ffmpeg_suitability(path):
137 |     try:
138 |         version = subprocess.run([path, "-version"], check=True,
139 |                                  capture_output=True).stdout.decode("utf-8")
140 |     except Exception:
141 |         return 0
142 |     score = 0
143 |     # rough layout of the importance of various features
144 |     simple_criterion = [("libvpx", 20), ("264", 10), ("265", 3),
145 |                         ("svtav1", 5), ("libopus", 1)]
146 |     for criterion in simple_criterion:
147 |         if version.find(criterion[0]) >= 0:
148 |             score += criterion[1]
149 |     # obtain rough compile year from copyright information
150 |     copyright_index = version.find('2000-2')
151 |     if copyright_index >= 0:
152 |         copyright_year = version[copyright_index+6:copyright_index+9]
153 |         if copyright_year.isnumeric():
154 |             score += int(copyright_year)
155 |     return score
156 | 
157 | 
158 | if "VHS_FORCE_FFMPEG_PATH" in os.environ:
159 |     ffmpeg_path = os.environ["VHS_FORCE_FFMPEG_PATH"]
160 | else:
161 |     ffmpeg_paths = []
162 |     try:
163 |         from imageio_ffmpeg import get_ffmpeg_exe
164 |         imageio_ffmpeg_path = get_ffmpeg_exe()
165 |         ffmpeg_paths.append(imageio_ffmpeg_path)
166 |     except Exception:
167 |         if "VHS_USE_IMAGEIO_FFMPEG" in os.environ:
168 |             raise
169 |         logger.warning("Failed to import imageio_ffmpeg")
170 |     if "VHS_USE_IMAGEIO_FFMPEG" in os.environ:
171 |         ffmpeg_path = imageio_ffmpeg_path
172 |     else:
173 |         system_ffmpeg = shutil.which("ffmpeg")
174 |         if system_ffmpeg is not None:
175 |             ffmpeg_paths.append(system_ffmpeg)
176 |         if len(ffmpeg_paths) == 0:
177 |             logger.error("No valid ffmpeg found.")
178 |             ffmpeg_path = None
179 |         else:
180 |             ffmpeg_path = max(ffmpeg_paths, key=ffmpeg_suitability)
181 | 
182 | 
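The module-level block above resolves `ffmpeg_path` once at import time: `VHS_FORCE_FFMPEG_PATH` wins outright, otherwise the imageio-ffmpeg binary and any ffmpeg on PATH are scored by `ffmpeg_suitability` (advertised codecs plus a rough copyright year) and the highest-scoring candidate is kept. A small sketch of inspecting that choice — the import path is an assumption and may need adjusting to how the package is installed:

```python
import shutil
from src.utils.util import ffmpeg_path, ffmpeg_suitability  # adjust to your package layout

print("selected ffmpeg:", ffmpeg_path)

system_ffmpeg = shutil.which("ffmpeg")
if system_ffmpeg is not None:
    # Builds advertising libvpx / x264 / x265 / svt-av1 / libopus and a recent
    # copyright year score higher.
    print(system_ffmpeg, "scores", ffmpeg_suitability(system_ffmpeg))
```

Setting `VHS_FORCE_FFMPEG_PATH` before launching ComfyUI skips the scoring entirely.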
183 | def get_sorted_dir_files_from_directory(directory: str, skip_first_images: int=0, select_every_nth: int=1, extensions: Iterable=None):
184 |     directory = directory.strip()
185 |     dir_files = os.listdir(directory)
186 |     dir_files = sorted(dir_files)
187 |     dir_files = [os.path.join(directory, x) for x in dir_files]
188 |     dir_files = list(filter(lambda filepath: os.path.isfile(filepath), dir_files))
189 |     # filter by extension, if needed
190 |     if extensions is not None:
191 |         extensions = list(extensions)
192 |         new_dir_files = []
193 |         for filepath in dir_files:
194 |             ext = "." + filepath.split(".")[-1]
195 |             if ext.lower() in extensions:
196 |                 new_dir_files.append(filepath)
197 |         dir_files = new_dir_files
198 |     # start at skip_first_images
199 |     dir_files = dir_files[skip_first_images:]
200 |     dir_files = dir_files[0::select_every_nth]
201 |     return dir_files
202 | 
203 | 
204 | # modified from https://stackoverflow.com/questions/22058048/hashing-a-file-in-python
205 | def calculate_file_hash(filename: str, hash_every_n: int = 1):
206 |     h = hashlib.sha256()
207 |     b = bytearray(10*1024*1024)  # read 10 megabytes at a time
208 |     mv = memoryview(b)
209 |     with open(filename, 'rb', buffering=0) as f:
210 |         i = 0
211 |         # don't hash entire file, only portions of it if requested
212 |         while n := f.readinto(mv):
213 |             if i%hash_every_n == 0:
214 |                 h.update(mv[:n])
215 |             i += 1
216 |     return h.hexdigest()
217 | 
218 | 
219 | def get_audio(file, start_time=0, duration=0):
220 |     args = [ffmpeg_path, "-v", "error", "-i", file]
221 |     if start_time > 0:
222 |         args += ["-ss", str(start_time)]
223 |     if duration > 0:
224 |         args += ["-t", str(duration)]
225 |     return subprocess.run(args + ["-f", "wav", "-"],
226 |                           stdout=subprocess.PIPE, check=True).stdout
227 | 
228 | 
229 | def lazy_eval(func):
230 |     class Cache:
231 |         def __init__(self, func):
232 |             self.res = None
233 |             self.func = func
234 |         def get(self):
235 |             if self.res is None:
236 |                 self.res = self.func()
237 |             return self.res
238 |     cache = Cache(func)
239 |     return lambda: cache.get()
240 | 
241 | 
242 | def is_url(url):
243 |     return url.split("://")[0] in ["http", "https"]
244 | 
245 | def validate_sequence(path):
246 |     # Check if path is a valid ffmpeg sequence that points to at least one file
247 |     (path, file) = os.path.split(path)
248 |     if not os.path.isdir(path):
249 |         return False
250 |     match = re.search(r'%0?\d+d', file)
251 |     if not match:
252 |         return False
253 |     seq = match.group()
254 |     if seq == '%d':
255 |         seq = '\\\\d+'
256 |     else:
257 |         seq = '\\\\d{%s}' % seq[1:-1]
258 |     file_matcher = re.compile(re.sub(r'%0?\d+d', seq, file))
259 |     for file in os.listdir(path):
260 |         if file_matcher.fullmatch(file):
261 |             return True
262 |     return False
263 | 
264 | def hash_path(path):
265 |     if path is None:
266 |         return "input"
267 |     if is_url(path):
268 |         return "url"
269 |     return calculate_file_hash(path.strip("\""))
270 | 
271 | 
272 | def validate_path(path, allow_none=False, allow_url=True):
273 |     if path is None:
274 |         return allow_none
275 |     if is_url(path):
276 |         # Probably not feasible to check if url resolves here
277 |         return True if allow_url else "URLs are unsupported for this path"
278 |     if not os.path.isfile(path.strip("\"")):
279 |         return "Invalid file path: {}".format(path)
280 |     return True
281 | 
--------------------------------------------------------------------------------
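The remaining helpers are the hooks a VHS-style loader node typically needs: `validate_path` checks that an input exists (or is an allowed URL) and returns an error string otherwise, `hash_path` turns it into a cache key via `calculate_file_hash`, `validate_sequence` recognizes printf-style frame patterns such as `%05d`, and `get_audio` shells out to the detected ffmpeg for a WAV track. A rough end-to-end sketch with a hypothetical input file (the import path may need adjusting to your package layout):

```python
from src.utils.util import (
    validate_path, hash_path, validate_sequence, get_audio, lazy_eval,
)

video = "inputs/talking_head.mp4"  # hypothetical upload

ok = validate_path(video, allow_none=False, allow_url=True)
if ok is not True:
    raise ValueError(ok)  # validate_path returns an error message string on failure

cache_key = hash_path(video)  # sha256 over sampled chunks, or "url" / "input" for special cases

# True only if the directory exists and at least one file matches the %05d pattern.
print(validate_sequence("frames/%05d.png"))

# Defer the ffmpeg call until the audio is actually consumed downstream.
audio_lazy = lazy_eval(lambda: get_audio(video, start_time=0, duration=10))
wav_bytes = audio_lazy()
```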