├── README.md
├── __init__.py
├── assets
│   ├── Aragaki.png
│   ├── audio2video_workflow.json
│   ├── face_reenacment_workflow.json
│   ├── lyl.wav
│   ├── pose2video_workflow.json
│   ├── pose_ref_video.mp4
│   ├── solo.png
│   └── woman.jpg
├── configs
│   ├── inference
│   │   ├── inference_audio.yaml
│   │   └── inference_v2.yaml
│   └── prompts
│       ├── animation.yaml
│       ├── animation_audio.yaml
│       └── animation_facereenac.yaml
├── nodes.py
├── pyproject.toml
├── requirements.txt
└── src
    ├── __init__.py
    ├── __pycache__
    │   └── __init__.cpython-310.pyc
    ├── audio_models
    │   ├── __pycache__
    │   │   ├── model.cpython-310.pyc
    │   │   ├── torch_utils.cpython-310.pyc
    │   │   └── wav2vec2.cpython-310.pyc
    │   ├── model.py
    │   ├── pose_model.py
    │   ├── torch_utils.py
    │   └── wav2vec2.py
    ├── models
    │   ├── __init__.py
    │   ├── __pycache__
    │   │   ├── __init__.cpython-310.pyc
    │   │   ├── attention.cpython-310.pyc
    │   │   ├── motion_module.cpython-310.pyc
    │   │   ├── mutual_self_attention.cpython-310.pyc
    │   │   ├── pose_guider.cpython-310.pyc
    │   │   ├── resnet.cpython-310.pyc
    │   │   ├── transformer_2d.cpython-310.pyc
    │   │   ├── transformer_3d.cpython-310.pyc
    │   │   ├── unet_2d_blocks.cpython-310.pyc
    │   │   ├── unet_2d_condition.cpython-310.pyc
    │   │   ├── unet_3d.cpython-310.pyc
    │   │   └── unet_3d_blocks.cpython-310.pyc
    │   ├── attention.py
    │   ├── motion_module.py
    │   ├── mutual_self_attention.py
    │   ├── pose_guider.py
    │   ├── resnet.py
    │   ├── transformer_2d.py
    │   ├── transformer_3d.py
    │   ├── unet_2d_blocks.py
    │   ├── unet_2d_condition.py
    │   ├── unet_3d.py
    │   └── unet_3d_blocks.py
    ├── pipelines
    │   ├── __init__.py
    │   ├── __pycache__
    │   │   ├── __init__.cpython-310.pyc
    │   │   ├── context.cpython-310.pyc
    │   │   ├── pipeline_pose2vid_long.cpython-310.pyc
    │   │   └── utils.cpython-310.pyc
    │   ├── context.py
    │   ├── pipeline_pose2vid.py
    │   ├── pipeline_pose2vid_long.py
    │   └── utils.py
    └── utils
        ├── __init__.py
        ├── __pycache__
        │   ├── __init__.cpython-310.pyc
        │   ├── audio_util.cpython-310.pyc
        │   ├── draw_util.cpython-310.pyc
        │   ├── face_landmark.cpython-310.pyc
        │   ├── logger.cpython-310.pyc
        │   ├── mp_utils.cpython-310.pyc
        │   ├── pose_util.cpython-310.pyc
        │   └── util.cpython-310.pyc
        ├── audio_util.py
        ├── draw_util.py
        ├── face_landmark.py
        ├── frame_interpolation.py
        ├── logger.py
        ├── mp_models
        │   ├── blaze_face_short_range.tflite
        │   ├── face_landmarker_v2_with_blendshapes.task
        │   └── pose_landmarker_heavy.task
        ├── mp_utils.py
        ├── pose_util.py
        └── util.py
/README.md:
--------------------------------------------------------------------------------
1 | #### Updates:
2 | ① Implemented frame interpolation to speed up generation.
3 |
4 | ② Modified the code to support chaining with the [VHS nodes](https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite). The ComfyUI IMAGE type expects torch float32 tensors, while AniPortrait works heavily with uint8 numpy images, so I replaced my own image/video upload and generation nodes with the prevalent VHS image/video upload and video-combine nodes; they interoperate well, are WYSIWYG, and render the result instantly (see the conversion sketch below).
5 | - ✅ [2024/04/09] raw video to pose video with a reference image (aka self-driven)
6 | - ✅ [2024/04/09] audio driven
7 | - ✅ [2024/04/09] face reenactment
8 | - ✅ [2024/04/22] implemented the audio2pose model and its [pre-trained weight](https://huggingface.co/ZJYang/AniPortrait/tree/main) for audio2video; the face reenactment and audio2video workflows have been updated. Inference up to a maximum length of 10 seconds is currently supported; you can experiment with the length hyperparameter.
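
Since the nodes now hand frames back and forth between these two conventions, here is a minimal sketch of the conversion (assuming the usual ComfyUI IMAGE layout: a torch.float32 tensor of shape (B, H, W, C) with values in [0, 1]); the helper names are illustrative and not part of this repo:

```python
import numpy as np
import torch

def comfy_to_uint8(images: torch.Tensor) -> np.ndarray:
    # ComfyUI IMAGE -> uint8 numpy frames, the format the AniPortrait/OpenCV code consumes
    return (images.cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8)

def uint8_to_comfy(frames: np.ndarray) -> torch.Tensor:
    # uint8 numpy frames -> ComfyUI IMAGE, so results can be chained into the VHS nodes
    return torch.from_numpy(frames.astype(np.float32) / 255.0)
```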
9 |
10 | You can contact me via [twitter](https://twitter.com/kurtqian) or Weixin: GalaticKing
11 |
12 |
13 | ### audio driven combined with reference image and reference video
14 | 
15 | [audio2video workflow](https://github.com/frankchieng/ComfyUI_Aniportrait/blob/main/assets/audio2video_workflow.json)
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 | ### raw video to pose video with reference image
25 | 
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 | ### face reenactment
35 | 
36 | [video2video workflow](https://github.com/frankchieng/ComfyUI_Aniportrait/blob/main/assets/face_reenacment_workflow.json)
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 | This is an unofficial implementation of AniPortrait as a ComfyUI custom node. Since I have a regular job, I will update this project as time permits.
46 | > [Aniportrait_pose2video.json](https://github.com/frankchieng/ComfyUI_Aniportrait/blob/main/assets/pose2video_workflow.json)
47 |
48 | > [Audio driven](https://github.com/frankchieng/ComfyUI_Aniportrait/blob/main/assets/audio2video_workflow.json)
49 |
50 | > [face reenactment](https://github.com/frankchieng/ComfyUI_Aniportrait/blob/main/assets/face_reenacment_workflow.json)
51 |
52 | Clone this repository into your ComfyUI `custom_nodes` directory:
53 | ```shell
54 | git clone https://github.com/frankchieng/ComfyUI_Aniportrait.git
55 | ```
56 | then install the dependencies:
57 | ```shell
58 | pip install -r requirements.txt
59 | ```
60 | Download the pretrained models:
61 | > [StableDiffusion V1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5)
62 |
63 | > [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse)
64 |
65 | > [image_encoder](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/tree/main/image_encoder)
66 |
67 | > [wav2vec2-base-960h](https://huggingface.co/facebook/wav2vec2-base-960h)
68 |
69 | Download the weights:
70 | > [denoising_unet.pth](https://huggingface.co/ZJYang/AniPortrait/tree/main)
71 | > [reference_unet.pth](https://huggingface.co/ZJYang/AniPortrait/tree/main)
72 | > [pose_guider.pth](https://huggingface.co/ZJYang/AniPortrait/tree/main)
73 | > [motion_module.pth](https://huggingface.co/ZJYang/AniPortrait/tree/main)
74 | > [audio2mesh.pt](https://huggingface.co/ZJYang/AniPortrait/tree/main)
75 | > [audio2pose.pt](https://huggingface.co/ZJYang/AniPortrait/tree/main)
76 | > [film_net_fp16.pt](https://huggingface.co/ZJYang/AniPortrait/tree/main)
77 | ```text
78 | ./pretrained_model/
79 | |-- image_encoder
80 | | |-- config.json
81 | | `-- pytorch_model.bin
82 | |-- sd-vae-ft-mse
83 | | |-- config.json
84 | | |-- diffusion_pytorch_model.bin
85 | | `-- diffusion_pytorch_model.safetensors
86 | |-- stable-diffusion-v1-5
87 | | |-- feature_extractor
88 | | | `-- preprocessor_config.json
89 | | |-- model_index.json
90 | | |-- unet
91 | | | |-- config.json
92 | | | `-- diffusion_pytorch_model.bin
93 | | `-- v1-inference.yaml
94 | |-- wav2vec2-base-960h
95 | | |-- config.json
96 | | |-- feature_extractor_config.json
97 | | |-- preprocessor_config.json
98 | | |-- pytorch_model.bin
99 | | |-- README.md
100 | | |-- special_tokens_map.json
101 | | |-- tokenizer_config.json
102 | | `-- vocab.json
103 | |-- audio2mesh.pt
104 | |-- audio2pose.pt
105 | |-- denoising_unet.pth
106 | |-- motion_module.pth
107 | |-- pose_guider.pth
108 | |-- reference_unet.pth
109 | |-- film_net_fp16.pt
110 | ```
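
For reference, one possible way to fetch these into `./pretrained_model/` is via `huggingface_hub` (an assumption on my part, not a script shipped with this repo; any download method that reproduces the layout above works):

```python
# Sketch only: pulls the AniPortrait weights and two of the auxiliary models.
# The remaining folders (stable-diffusion-v1-5, image_encoder) follow the same pattern.
from huggingface_hub import snapshot_download

snapshot_download(repo_id="ZJYang/AniPortrait", local_dir="pretrained_model")
snapshot_download(repo_id="stabilityai/sd-vae-ft-mse", local_dir="pretrained_model/sd-vae-ft-mse")
snapshot_download(repo_id="facebook/wav2vec2-base-960h", local_dir="pretrained_model/wav2vec2-base-960h")
```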
111 |
112 | Tips:
113 | The intermediate audio file is generated and then deleted automatically; the resulting mp4 files (the raw-video-to-pose video with audio, and the pose2video result) are written to the ComfyUI output directory.
114 | The uploaded mp4 video should be square (e.g. 512x512); otherwise the result will look distorted. An illustrative pre-processing sketch follows.
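
If your source video is not square, one way to fix it up before feeding the frames in is to center-crop and resize them. The snippet below only illustrates that pre-processing (done outside the nodes), with a hypothetical helper name:

```python
import torch
import torch.nn.functional as F

def square_512(images: torch.Tensor, size: int = 512) -> torch.Tensor:
    # images: ComfyUI IMAGE batch, float32 of shape (B, H, W, C) with values in [0, 1]
    _, h, w, _ = images.shape
    side = min(h, w)
    top, left = (h - side) // 2, (w - side) // 2
    cropped = images[:, top:top + side, left:left + side, :]    # center square crop
    chw = cropped.permute(0, 3, 1, 2)                           # (B, C, H, W) for interpolate
    resized = F.interpolate(chw, size=(size, size), mode="bilinear", align_corners=False)
    return resized.permute(0, 2, 3, 1).contiguous()             # back to (B, H, W, C)
```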
115 | #### I've updated diffusers from 0.24.x to 0.26.2. In diffusers/models/embeddings.py the class PositionNet was renamed to GLIGENTextBoundingboxProjection and CaptionProjection was renamed to PixArtAlphaTextProjection, so if you have an older diffusers installed you need to modify the corresponding files, such as src/models/transformer_2d.py, accordingly (see the sketch below).
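
A hedged sketch of what such a modification could look like, as a try/except import shim; whether this drops into src/models/transformer_2d.py unchanged depends on your diffusers version:

```python
# Compatibility shim for the renamed classes in diffusers/models/embeddings.py.
try:
    # diffusers >= 0.26
    from diffusers.models.embeddings import (
        GLIGENTextBoundingboxProjection as PositionNet,
        PixArtAlphaTextProjection as CaptionProjection,
    )
except ImportError:
    # older diffusers (0.24.x and earlier)
    from diffusers.models.embeddings import PositionNet, CaptionProjection
```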
116 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | import folder_paths
2 | import os
3 | import ffmpeg
4 | from PIL import Image
5 | import cv2
6 | from tqdm import tqdm
7 | import re
8 | import torch
9 | from .nodes import PoseGenVideo, RefImagePath, Audio2Video, AudioPath #,GenerateRefPose
10 |
11 | from .src.utils.util import get_fps, read_frames, save_videos_from_pil, calculate_file_hash, get_sorted_dir_files_from_directory, get_audio, lazy_eval, hash_path, validate_path
12 | import numpy as np
13 | from .src.utils.draw_util import FaceMeshVisualizer
14 | from .src.utils.mp_utils import LMKExtractor
15 |
16 | video_extensions = ['webm', 'mp4', 'mkv', 'gif']
17 |
18 | class VideoGenPose:
19 | @classmethod
20 | def INPUT_TYPES(s):
21 | return {
22 | "required": {
23 | "image": ("IMAGE",),
24 | "filename_prefix": ("STRING", {"default": "AniPortrait"}),
25 | "height": ("INT", {"default": 512, "min": 0, "max": 1024, "step": 1}),
26 | "width": ("INT", {"default": 512, "min": 0, "max": 1024, "step": 1}),
27 | },
28 | }
29 |
30 | RETURN_TYPES = ("IMAGE",)
31 | RETURN_NAMES = ("pose_images",)
32 | OUTPUT_NODE = True
33 | CATEGORY = "AniPortrait 🎥Video"
34 | FUNCTION = "generate_pose_video"
35 |
36 | def generate_pose_video(self, image, filename_prefix, height, width):
37 |
38 | frames = (image.numpy() * 255).astype(np.uint8)
39 | lmk_extractor = LMKExtractor()
40 | vis = FaceMeshVisualizer(forehead_edge=False)
41 |
42 | kps_results = []
43 | for i, frame_pil in enumerate(tqdm(frames)):
44 | image_np = cv2.cvtColor(np.array(frame_pil), cv2.COLOR_RGB2BGR)
45 | image_np = cv2.resize(image_np, (width, height))  # cv2.resize takes dsize as (width, height), matching the (height, width, 3) frames built below
46 | face_result = lmk_extractor(image_np)
47 | try:
48 | lmks = face_result['lmks'].astype(np.float32)
49 | pose_img = vis.draw_landmarks((image_np.shape[1], image_np.shape[0]), lmks, normed=True)
50 | pose_img = Image.fromarray(cv2.cvtColor(pose_img, cv2.COLOR_BGR2RGB))
51 | except Exception:  # no face detected in this frame; reuse the previous pose image
52 | pose_img = kps_results[-1]
53 |
54 | kps_results.append(pose_img)
55 |
56 | iterable = (x for x in kps_results)
57 | images = torch.from_numpy(np.fromiter(iterable, np.dtype((np.float32, (height, width, 3)))))
58 | return (images,)
59 |
60 |
61 | class LoadVideoPath:
62 | @classmethod
63 | def INPUT_TYPES(s):
64 | return {
65 | "required": {
66 | "video": ("STRING", {"default": "X://insert/path/here.mp4", "aniportrait_path_extensions": video_extensions}),
67 | },
68 | }
69 |
70 | CATEGORY = "AniPortrait 🎥Video"
71 |
72 | RETURN_TYPES = ("AniPortrait_Video", "IMAGE", "Frame_per_second", "AniPortrait_Audio", )
73 | RETURN_NAMES = ("video", "frames", "frame_per_second", "audio",)
74 | FUNCTION = "load_video"
75 |
76 | def load_video(self, **kwargs):
77 | if kwargs['video'] is None or validate_path(kwargs['video']) != True:
78 | raise Exception("video is not a valid path: " + kwargs['video'])
79 | return load_video_av(**kwargs)
80 |
81 | @classmethod
82 | def IS_CHANGED(s, video, **kwargs):
83 | return hash_path(video)
84 |
85 | @classmethod
86 | def VALIDATE_INPUTS(s, video, **kwargs):
87 | return validate_path(video, allow_none=True)
88 |
89 |
90 | def load_video_av(video: str):
91 | fps = get_fps(video)
92 | frames = read_frames(video)
93 | input_dir = folder_paths.get_output_directory()
94 | audio_output = os.path.join(input_dir, 'audio_from_video.aac')
95 |
96 | return (video, frames, fps, audio_output)
97 |
98 | NODE_CLASS_MAPPINGS = {
99 | "AniPortrait_Video_Gen_Pose": VideoGenPose,
100 | "AniPortrait_LoadVideoPath": LoadVideoPath,
101 | "AniPortrait_Pose_Gen_Video": PoseGenVideo,
102 | "AniPortrait_Ref_Image_Path": RefImagePath,
103 | # "AniPortrait_Generate_Ref_Pose": GenerateRefPose,
104 | "AniPortrait_Audio2Video": Audio2Video,
105 | "AniPortrait_Audio_Path": AudioPath,
106 | }
107 |
108 | NODE_DISPLAY_NAME_MAPPINGS = {
109 | "AniPortrait_Video_Gen_Pose": "Video MediaPipe Face Detection🎥AniPortrait",
110 | "AniPortrait_LoadVideoPath": "Load Video (Path) 🎥AniPortrait",
111 | "AniPortrait_Pose_Gen_Video": "Pose Generate Video 🎥AniPortrait",
112 | "AniPortrait_Ref_Image_Path": "Ref Image Path 🎥AniPortrait",
113 | # "AniPortrait_Generate_Ref_Pose": "Generate Ref Pose 🎥AniPortrait",
114 | "AniPortrait_Audio2Video": "Audio Gen Video 🎥AniPortrait",
115 | "AniPortrait_Audio_Path": "Audio Path 🎥AniPortrait",
116 | }
117 |
--------------------------------------------------------------------------------
/assets/Aragaki.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/assets/Aragaki.png
--------------------------------------------------------------------------------
/assets/audio2video_workflow.json:
--------------------------------------------------------------------------------
1 | {
2 | "last_node_id": 17,
3 | "last_link_id": 31,
4 | "nodes": [
5 | {
6 | "id": 10,
7 | "type": "AniPortrait_Audio2Video",
8 | "pos": [
9 | -1712,
10 | -63
11 | ],
12 | "size": {
13 | "0": 315,
14 | "1": 506
15 | },
16 | "flags": {},
17 | "order": 3,
18 | "mode": 0,
19 | "inputs": [
20 | {
21 | "name": "ref_image",
22 | "type": "IMAGE",
23 | "link": 12,
24 | "label": "ref_image"
25 | },
26 | {
27 | "name": "images",
28 | "type": "IMAGE",
29 | "link": 14,
30 | "label": "images"
31 | },
32 | {
33 | "name": "audio_path",
34 | "type": "Audio_Path",
35 | "link": 15,
36 | "label": "audio_path"
37 | },
38 | {
39 | "name": "fps",
40 | "type": "INT",
41 | "link": null,
42 | "widget": {
43 | "name": "fps"
44 | },
45 | "label": "fps"
46 | }
47 | ],
48 | "outputs": [
49 | {
50 | "name": "images",
51 | "type": "IMAGE",
52 | "links": [
53 | 31
54 | ],
55 | "shape": 3,
56 | "label": "images"
57 | }
58 | ],
59 | "properties": {
60 | "Node name for S&R": "AniPortrait_Audio2Video"
61 | },
62 | "widgets_values": [
63 | 512,
64 | 512,
65 | 1748,
66 | "randomize",
67 | 3.5,
68 | 25,
69 | "pretrained_model/sd-vae-ft-mse",
70 | "pretrained_model/stable-diffusion-v1-5",
71 | "fp16",
72 | true,
73 | 0,
74 | 3,
75 | "pretrained_model/motion_module.pth",
76 | "pretrained_model/image_encoder",
77 | "pretrained_model/denoising_unet.pth",
78 | "pretrained_model/reference_unet.pth",
79 | "pretrained_model/pose_guider.pth",
80 | 0
81 | ]
82 | },
83 | {
84 | "id": 1,
85 | "type": "VHS_LoadVideo",
86 | "pos": [
87 | -2366,
88 | -60
89 | ],
90 | "size": [
91 | 251.52520751953125,
92 | 507.52520751953125
93 | ],
94 | "flags": {},
95 | "order": 0,
96 | "mode": 0,
97 | "inputs": [
98 | {
99 | "name": "meta_batch",
100 | "type": "VHS_BatchManager",
101 | "link": null,
102 | "label": "meta_batch"
103 | },
104 | {
105 | "name": "vae",
106 | "type": "VAE",
107 | "link": null,
108 | "label": "vae"
109 | }
110 | ],
111 | "outputs": [
112 | {
113 | "name": "IMAGE",
114 | "type": "IMAGE",
115 | "links": [
116 | 14
117 | ],
118 | "slot_index": 0,
119 | "shape": 3,
120 | "label": "IMAGE"
121 | },
122 | {
123 | "name": "frame_count",
124 | "type": "INT",
125 | "links": null,
126 | "shape": 3,
127 | "label": "frame_count"
128 | },
129 | {
130 | "name": "audio",
131 | "type": "VHS_AUDIO",
132 | "links": null,
133 | "shape": 3,
134 | "label": "audio"
135 | },
136 | {
137 | "name": "video_info",
138 | "type": "VHS_VIDEOINFO",
139 | "links": [],
140 | "slot_index": 3,
141 | "shape": 3,
142 | "label": "video_info"
143 | }
144 | ],
145 | "properties": {
146 | "Node name for S&R": "VHS_LoadVideo"
147 | },
148 | "widgets_values": {
149 | "video": "pose_ref_video.mp4",
150 | "force_rate": 0,
151 | "force_size": "Disabled",
152 | "custom_width": 512,
153 | "custom_height": 512,
154 | "frame_load_cap": 0,
155 | "skip_first_frames": 0,
156 | "select_every_nth": 1,
157 | "choose video to upload": "image",
158 | "videopreview": {
159 | "hidden": false,
160 | "paused": false,
161 | "params": {
162 | "frame_load_cap": 0,
163 | "skip_first_frames": 0,
164 | "force_rate": 0,
165 | "filename": "pose_ref_video.mp4",
166 | "type": "input",
167 | "format": "video/mp4",
168 | "select_every_nth": 1
169 | }
170 | }
171 | }
172 | },
173 | {
174 | "id": 7,
175 | "type": "LoadImage",
176 | "pos": [
177 | -2082,
178 | 122
179 | ],
180 | "size": {
181 | "0": 315,
182 | "1": 314
183 | },
184 | "flags": {},
185 | "order": 1,
186 | "mode": 0,
187 | "outputs": [
188 | {
189 | "name": "IMAGE",
190 | "type": "IMAGE",
191 | "links": [
192 | 12
193 | ],
194 | "slot_index": 0,
195 | "shape": 3,
196 | "label": "IMAGE"
197 | },
198 | {
199 | "name": "MASK",
200 | "type": "MASK",
201 | "links": null,
202 | "shape": 3,
203 | "label": "MASK"
204 | }
205 | ],
206 | "properties": {
207 | "Node name for S&R": "LoadImage"
208 | },
209 | "widgets_values": [
210 | "man.jpg",
211 | "image"
212 | ]
213 | },
214 | {
215 | "id": 8,
216 | "type": "AniPortrait_Audio_Path",
217 | "pos": [
218 | -2114,
219 | -238
220 | ],
221 | "size": {
222 | "0": 315,
223 | "1": 102
224 | },
225 | "flags": {},
226 | "order": 2,
227 | "mode": 0,
228 | "outputs": [
229 | {
230 | "name": "audio_path",
231 | "type": "Audio_Path",
232 | "links": [
233 | 15
234 | ],
235 | "slot_index": 0,
236 | "shape": 3,
237 | "label": "audio_path"
238 | },
239 | {
240 | "name": "audio",
241 | "type": "VHS_AUDIO",
242 | "links": [
243 | 30
244 | ],
245 | "shape": 3,
246 | "label": "audio",
247 | "slot_index": 1
248 | }
249 | ],
250 | "properties": {
251 | "Node name for S&R": "AniPortrait_Audio_Path"
252 | },
253 | "widgets_values": [
254 | "/home/qm/test.wav",
255 | 0
256 | ]
257 | },
258 | {
259 | "id": 17,
260 | "type": "VHS_VideoCombine",
261 | "pos": [
262 | -1297,
263 | -16
264 | ],
265 | "size": [
266 | 218.82891845703125,
267 | 523.3297469669985
268 | ],
269 | "flags": {},
270 | "order": 5,
271 | "mode": 0,
272 | "inputs": [
273 | {
274 | "name": "images",
275 | "type": "IMAGE",
276 | "link": 31,
277 | "label": "images"
278 | },
279 | {
280 | "name": "audio",
281 | "type": "AUDIO",
282 | "link": 29,
283 | "label": "audio"
284 | },
285 | {
286 | "name": "meta_batch",
287 | "type": "VHS_BatchManager",
288 | "link": null,
289 | "label": "meta_batch"
290 | },
291 | {
292 | "name": "vae",
293 | "type": "VAE",
294 | "link": null,
295 | "label": "vae"
296 | }
297 | ],
298 | "outputs": [
299 | {
300 | "name": "Filenames",
301 | "type": "VHS_FILENAMES",
302 | "links": null,
303 | "shape": 3,
304 | "label": "Filenames"
305 | }
306 | ],
307 | "properties": {
308 | "Node name for S&R": "VHS_VideoCombine"
309 | },
310 | "widgets_values": {
311 | "frame_rate": 30,
312 | "loop_count": 0,
313 | "filename_prefix": "Aniportrait",
314 | "format": "video/h264-mp4",
315 | "pix_fmt": "yuv420p",
316 | "crf": 19,
317 | "save_metadata": true,
318 | "pingpong": false,
319 | "save_output": true,
320 | "videopreview": {
321 | "hidden": false,
322 | "paused": false,
323 | "params": {
324 | "filename": "Aniportrait_00002-audio.mp4",
325 | "subfolder": "",
326 | "type": "output",
327 | "format": "video/h264-mp4",
328 | "frame_rate": 30
329 | },
330 | "muted": false
331 | }
332 | }
333 | },
334 | {
335 | "id": 13,
336 | "type": "VHS_VHSAudioToAudio",
337 | "pos": [
338 | -1682,
339 | -200
340 | ],
341 | "size": {
342 | "0": 304.67462158203125,
343 | "1": 26
344 | },
345 | "flags": {},
346 | "order": 4,
347 | "mode": 0,
348 | "inputs": [
349 | {
350 | "name": "vhs_audio",
351 | "type": "VHS_AUDIO",
352 | "link": 30,
353 | "label": "vhs_audio"
354 | }
355 | ],
356 | "outputs": [
357 | {
358 | "name": "audio",
359 | "type": "AUDIO",
360 | "links": [
361 | 29
362 | ],
363 | "shape": 3,
364 | "label": "audio",
365 | "slot_index": 0
366 | }
367 | ],
368 | "properties": {
369 | "Node name for S&R": "VHS_VHSAudioToAudio"
370 | },
371 | "widgets_values": {}
372 | }
373 | ],
374 | "links": [
375 | [
376 | 12,
377 | 7,
378 | 0,
379 | 10,
380 | 0,
381 | "IMAGE"
382 | ],
383 | [
384 | 14,
385 | 1,
386 | 0,
387 | 10,
388 | 1,
389 | "IMAGE"
390 | ],
391 | [
392 | 15,
393 | 8,
394 | 0,
395 | 10,
396 | 2,
397 | "Audio_Path"
398 | ],
399 | [
400 | 29,
401 | 13,
402 | 0,
403 | 17,
404 | 1,
405 | "AUDIO"
406 | ],
407 | [
408 | 30,
409 | 8,
410 | 1,
411 | 13,
412 | 0,
413 | "VHS_AUDIO"
414 | ],
415 | [
416 | 31,
417 | 10,
418 | 0,
419 | 17,
420 | 0,
421 | "IMAGE"
422 | ]
423 | ],
424 | "groups": [],
425 | "config": {},
426 | "extra": {
427 | "ds": {
428 | "scale": 1.061076460950001,
429 | "offset": [
430 | 2608.218390205725,
431 | 532.0839393915426
432 | ]
433 | }
434 | },
435 | "version": 0.4
436 | }
437 |
--------------------------------------------------------------------------------
/assets/face_reenacment_workflow.json:
--------------------------------------------------------------------------------
1 | {
2 | "last_node_id": 12,
3 | "last_link_id": 16,
4 | "nodes": [
5 | {
6 | "id": 7,
7 | "type": "VHS_LoadVideo",
8 | "pos": [
9 | -2279,
10 | 347
11 | ],
12 | "size": [
13 | 235.1999969482422,
14 | 491.1999969482422
15 | ],
16 | "flags": {},
17 | "order": 0,
18 | "mode": 0,
19 | "inputs": [
20 | {
21 | "name": "meta_batch",
22 | "type": "VHS_BatchManager",
23 | "link": null,
24 | "label": "meta_batch"
25 | }
26 | ],
27 | "outputs": [
28 | {
29 | "name": "IMAGE",
30 | "type": "IMAGE",
31 | "links": [
32 | 15
33 | ],
34 | "shape": 3,
35 | "label": "IMAGE",
36 | "slot_index": 0
37 | },
38 | {
39 | "name": "frame_count",
40 | "type": "INT",
41 | "links": null,
42 | "shape": 3,
43 | "label": "frame_count"
44 | },
45 | {
46 | "name": "audio",
47 | "type": "VHS_AUDIO",
48 | "links": [
49 | 8
50 | ],
51 | "shape": 3,
52 | "label": "audio",
53 | "slot_index": 2
54 | },
55 | {
56 | "name": "video_info",
57 | "type": "VHS_VIDEOINFO",
58 | "links": [
59 | 9
60 | ],
61 | "shape": 3,
62 | "label": "video_info",
63 | "slot_index": 3
64 | }
65 | ],
66 | "properties": {
67 | "Node name for S&R": "VHS_LoadVideo"
68 | },
69 | "widgets_values": {
70 | "video": "pose_ref_video.mp4",
71 | "force_rate": 0,
72 | "force_size": "Disabled",
73 | "custom_width": 512,
74 | "custom_height": 512,
75 | "frame_load_cap": 0,
76 | "skip_first_frames": 0,
77 | "select_every_nth": 1,
78 | "choose video to upload": "image",
79 | "videopreview": {
80 | "hidden": false,
81 | "paused": false,
82 | "params": {
83 | "frame_load_cap": 0,
84 | "skip_first_frames": 0,
85 | "force_rate": 0,
86 | "filename": "pose_ref_video.mp4",
87 | "type": "input",
88 | "format": "video/mp4",
89 | "select_every_nth": 1
90 | }
91 | }
92 | }
93 | },
94 | {
95 | "id": 10,
96 | "type": "VHS_VideoInfo",
97 | "pos": [
98 | -1987,
99 | 728
100 | ],
101 | "size": {
102 | "0": 393,
103 | "1": 206
104 | },
105 | "flags": {},
106 | "order": 2,
107 | "mode": 0,
108 | "inputs": [
109 | {
110 | "name": "video_info",
111 | "type": "VHS_VIDEOINFO",
112 | "link": 9,
113 | "label": "video_info"
114 | }
115 | ],
116 | "outputs": [
117 | {
118 | "name": "source_fps🟨",
119 | "type": "FLOAT",
120 | "links": [
121 | 10
122 | ],
123 | "shape": 3,
124 | "label": "source_fps🟨",
125 | "slot_index": 0
126 | },
127 | {
128 | "name": "source_frame_count🟨",
129 | "type": "INT",
130 | "links": null,
131 | "shape": 3,
132 | "label": "source_frame_count🟨"
133 | },
134 | {
135 | "name": "source_duration🟨",
136 | "type": "FLOAT",
137 | "links": null,
138 | "shape": 3,
139 | "label": "source_duration🟨"
140 | },
141 | {
142 | "name": "source_width🟨",
143 | "type": "INT",
144 | "links": null,
145 | "shape": 3,
146 | "label": "source_width🟨"
147 | },
148 | {
149 | "name": "source_height🟨",
150 | "type": "INT",
151 | "links": null,
152 | "shape": 3,
153 | "label": "source_height🟨"
154 | },
155 | {
156 | "name": "loaded_fps🟦",
157 | "type": "FLOAT",
158 | "links": null,
159 | "shape": 3,
160 | "label": "loaded_fps🟦"
161 | },
162 | {
163 | "name": "loaded_frame_count🟦",
164 | "type": "INT",
165 | "links": null,
166 | "shape": 3,
167 | "label": "loaded_frame_count🟦"
168 | },
169 | {
170 | "name": "loaded_duration🟦",
171 | "type": "FLOAT",
172 | "links": null,
173 | "shape": 3,
174 | "label": "loaded_duration🟦"
175 | },
176 | {
177 | "name": "loaded_width🟦",
178 | "type": "INT",
179 | "links": null,
180 | "shape": 3,
181 | "label": "loaded_width🟦"
182 | },
183 | {
184 | "name": "loaded_height🟦",
185 | "type": "INT",
186 | "links": null,
187 | "shape": 3,
188 | "label": "loaded_height🟦"
189 | }
190 | ],
191 | "properties": {
192 | "Node name for S&R": "VHS_VideoInfo"
193 | },
194 | "widgets_values": {}
195 | },
196 | {
197 | "id": 11,
198 | "type": "CR Float To Integer",
199 | "pos": [
200 | -1545,
201 | 854
202 | ],
203 | "size": {
204 | "0": 315,
205 | "1": 78
206 | },
207 | "flags": {},
208 | "order": 3,
209 | "mode": 0,
210 | "inputs": [
211 | {
212 | "name": "_float",
213 | "type": "FLOAT",
214 | "link": 10,
215 | "widget": {
216 | "name": "_float"
217 | },
218 | "label": "_float"
219 | }
220 | ],
221 | "outputs": [
222 | {
223 | "name": "INT",
224 | "type": "INT",
225 | "links": [
226 | 13
227 | ],
228 | "shape": 3,
229 | "label": "INT",
230 | "slot_index": 0
231 | },
232 | {
233 | "name": "show_help",
234 | "type": "STRING",
235 | "links": null,
236 | "shape": 3,
237 | "label": "show_help"
238 | }
239 | ],
240 | "properties": {
241 | "Node name for S&R": "CR Float To Integer"
242 | },
243 | "widgets_values": [
244 | 0
245 | ]
246 | },
247 | {
248 | "id": 8,
249 | "type": "LoadImage",
250 | "pos": [
251 | -1964,
252 | 326
253 | ],
254 | "size": {
255 | "0": 315,
256 | "1": 314
257 | },
258 | "flags": {},
259 | "order": 1,
260 | "mode": 0,
261 | "outputs": [
262 | {
263 | "name": "IMAGE",
264 | "type": "IMAGE",
265 | "links": [
266 | 14
267 | ],
268 | "shape": 3,
269 | "label": "IMAGE",
270 | "slot_index": 0
271 | },
272 | {
273 | "name": "MASK",
274 | "type": "MASK",
275 | "links": null,
276 | "shape": 3,
277 | "label": "MASK"
278 | }
279 | ],
280 | "properties": {
281 | "Node name for S&R": "LoadImage"
282 | },
283 | "widgets_values": [
284 | "solo (2).png",
285 | "image"
286 | ]
287 | },
288 | {
289 | "id": 9,
290 | "type": "VHS_VideoCombine",
291 | "pos": [
292 | -1166,
293 | 282
294 | ],
295 | "size": [
296 | 315,
297 | 599
298 | ],
299 | "flags": {},
300 | "order": 5,
301 | "mode": 0,
302 | "inputs": [
303 | {
304 | "name": "images",
305 | "type": "IMAGE",
306 | "link": 16,
307 | "label": "images"
308 | },
309 | {
310 | "name": "audio",
311 | "type": "VHS_AUDIO",
312 | "link": 8,
313 | "label": "audio"
314 | },
315 | {
316 | "name": "meta_batch",
317 | "type": "VHS_BatchManager",
318 | "link": null,
319 | "label": "meta_batch"
320 | }
321 | ],
322 | "outputs": [
323 | {
324 | "name": "Filenames",
325 | "type": "VHS_FILENAMES",
326 | "links": null,
327 | "shape": 3,
328 | "label": "Filenames"
329 | }
330 | ],
331 | "properties": {
332 | "Node name for S&R": "VHS_VideoCombine"
333 | },
334 | "widgets_values": {
335 | "frame_rate": 30,
336 | "loop_count": 0,
337 | "filename_prefix": "Aniportrait",
338 | "format": "video/h264-mp4",
339 | "pix_fmt": "yuv420p",
340 | "crf": 19,
341 | "save_metadata": true,
342 | "pingpong": false,
343 | "save_output": true,
344 | "videopreview": {
345 | "hidden": false,
346 | "paused": false,
347 | "params": {
348 | "filename": "Aniportrait_00003-audio.mp4",
349 | "subfolder": "",
350 | "type": "output",
351 | "format": "video/h264-mp4"
352 | }
353 | }
354 | }
355 | },
356 | {
357 | "id": 12,
358 | "type": "AniPortrait_Audio2Video",
359 | "pos": [
360 | -1545,
361 | 280
362 | ],
363 | "size": {
364 | "0": 315,
365 | "1": 506
366 | },
367 | "flags": {},
368 | "order": 4,
369 | "mode": 0,
370 | "inputs": [
371 | {
372 | "name": "ref_image",
373 | "type": "IMAGE",
374 | "link": 14,
375 | "label": "ref_image"
376 | },
377 | {
378 | "name": "images",
379 | "type": "IMAGE",
380 | "link": 15,
381 | "label": "images",
382 | "slot_index": 1
383 | },
384 | {
385 | "name": "audio_path",
386 | "type": "Audio_Path",
387 | "link": null,
388 | "label": "audio_path"
389 | },
390 | {
391 | "name": "fps",
392 | "type": "INT",
393 | "link": 13,
394 | "widget": {
395 | "name": "fps"
396 | },
397 | "label": "fps"
398 | }
399 | ],
400 | "outputs": [
401 | {
402 | "name": "images",
403 | "type": "IMAGE",
404 | "links": [
405 | 16
406 | ],
407 | "shape": 3,
408 | "label": "images",
409 | "slot_index": 0
410 | }
411 | ],
412 | "properties": {
413 | "Node name for S&R": "AniPortrait_Audio2Video"
414 | },
415 | "widgets_values": [
416 | 512,
417 | 512,
418 | 713,
419 | "randomize",
420 | 3.5,
421 | 25,
422 | "pretrained_model/sd-vae-ft-mse",
423 | "pretrained_model/stable-diffusion-v1-5",
424 | "fp16",
425 | true,
426 | 60,
427 | 3,
428 | "pretrained_model/motion_module.pth",
429 | "pretrained_model/image_encoder",
430 | "pretrained_model/denoising_unet.pth",
431 | "pretrained_model/reference_unet.pth",
432 | "pretrained_model/pose_guider.pth",
433 | 0
434 | ]
435 | }
436 | ],
437 | "links": [
438 | [
439 | 8,
440 | 7,
441 | 2,
442 | 9,
443 | 1,
444 | "VHS_AUDIO"
445 | ],
446 | [
447 | 9,
448 | 7,
449 | 3,
450 | 10,
451 | 0,
452 | "VHS_VIDEOINFO"
453 | ],
454 | [
455 | 10,
456 | 10,
457 | 0,
458 | 11,
459 | 0,
460 | "FLOAT"
461 | ],
462 | [
463 | 13,
464 | 11,
465 | 0,
466 | 12,
467 | 3,
468 | "INT"
469 | ],
470 | [
471 | 14,
472 | 8,
473 | 0,
474 | 12,
475 | 0,
476 | "IMAGE"
477 | ],
478 | [
479 | 15,
480 | 7,
481 | 0,
482 | 12,
483 | 1,
484 | "IMAGE"
485 | ],
486 | [
487 | 16,
488 | 12,
489 | 0,
490 | 9,
491 | 0,
492 | "IMAGE"
493 | ]
494 | ],
495 | "groups": [],
496 | "config": {},
497 | "extra": {},
498 | "version": 0.4
499 | }
--------------------------------------------------------------------------------
/assets/lyl.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/assets/lyl.wav
--------------------------------------------------------------------------------
/assets/pose2video_workflow.json:
--------------------------------------------------------------------------------
1 | {
2 | "last_node_id": 59,
3 | "last_link_id": 47,
4 | "nodes": [
5 | {
6 | "id": 42,
7 | "type": "VHS_LoadVideo",
8 | "pos": [
9 | -1964,
10 | 875
11 | ],
12 | "size": [
13 | 235.1999969482422,
14 | 491.1999969482422
15 | ],
16 | "flags": {},
17 | "order": 0,
18 | "mode": 0,
19 | "inputs": [
20 | {
21 | "name": "batch_manager",
22 | "type": "VHS_BatchManager",
23 | "link": null,
24 | "label": "batch_manager"
25 | }
26 | ],
27 | "outputs": [
28 | {
29 | "name": "IMAGE",
30 | "type": "IMAGE",
31 | "links": [
32 | 38
33 | ],
34 | "shape": 3,
35 | "label": "IMAGE",
36 | "slot_index": 0
37 | },
38 | {
39 | "name": "frame_count",
40 | "type": "INT",
41 | "links": null,
42 | "shape": 3,
43 | "label": "frame_count"
44 | },
45 | {
46 | "name": "audio",
47 | "type": "VHS_AUDIO",
48 | "links": [
49 | 34
50 | ],
51 | "shape": 3,
52 | "label": "audio"
53 | },
54 | {
55 | "name": "video_info",
56 | "type": "VHS_VIDEOINFO",
57 | "links": [
58 | 31
59 | ],
60 | "shape": 3,
61 | "label": "video_info",
62 | "slot_index": 3
63 | }
64 | ],
65 | "properties": {
66 | "Node name for S&R": "VHS_LoadVideo"
67 | },
68 | "widgets_values": {
69 | "video": "pose_ref_video.mp4",
70 | "force_rate": 0,
71 | "force_size": "Disabled",
72 | "custom_width": 512,
73 | "custom_height": 512,
74 | "frame_load_cap": 0,
75 | "skip_first_frames": 0,
76 | "select_every_nth": 1,
77 | "choose video to upload": "image",
78 | "videopreview": {
79 | "hidden": false,
80 | "paused": false,
81 | "params": {
82 | "frame_load_cap": 0,
83 | "skip_first_frames": 0,
84 | "force_rate": 0,
85 | "filename": "pose_ref_video.mp4",
86 | "type": "input",
87 | "format": "video/mp4",
88 | "select_every_nth": 1
89 | }
90 | }
91 | }
92 | },
93 | {
94 | "id": 29,
95 | "type": "VHS_VideoInfo",
96 | "pos": [
97 | -1642,
98 | 1154
99 | ],
100 | "size": {
101 | "0": 393,
102 | "1": 206
103 | },
104 | "flags": {},
105 | "order": 3,
106 | "mode": 0,
107 | "inputs": [
108 | {
109 | "name": "video_info",
110 | "type": "VHS_VIDEOINFO",
111 | "link": 31,
112 | "label": "video_info"
113 | }
114 | ],
115 | "outputs": [
116 | {
117 | "name": "source_fps🟨",
118 | "type": "FLOAT",
119 | "links": [],
120 | "shape": 3,
121 | "label": "source_fps🟨",
122 | "slot_index": 0
123 | },
124 | {
125 | "name": "source_frame_count🟨",
126 | "type": "INT",
127 | "links": [
128 | 42
129 | ],
130 | "shape": 3,
131 | "label": "source_frame_count🟨",
132 | "slot_index": 1
133 | },
134 | {
135 | "name": "source_duration🟨",
136 | "type": "FLOAT",
137 | "links": null,
138 | "shape": 3,
139 | "label": "source_duration🟨"
140 | },
141 | {
142 | "name": "source_width🟨",
143 | "type": "INT",
144 | "links": null,
145 | "shape": 3,
146 | "label": "source_width🟨",
147 | "slot_index": 3
148 | },
149 | {
150 | "name": "source_height🟨",
151 | "type": "INT",
152 | "links": null,
153 | "shape": 3,
154 | "label": "source_height🟨"
155 | },
156 | {
157 | "name": "loaded_fps🟦",
158 | "type": "FLOAT",
159 | "links": null,
160 | "shape": 3,
161 | "label": "loaded_fps🟦"
162 | },
163 | {
164 | "name": "loaded_frame_count🟦",
165 | "type": "INT",
166 | "links": null,
167 | "shape": 3,
168 | "label": "loaded_frame_count🟦"
169 | },
170 | {
171 | "name": "loaded_duration🟦",
172 | "type": "FLOAT",
173 | "links": null,
174 | "shape": 3,
175 | "label": "loaded_duration🟦"
176 | },
177 | {
178 | "name": "loaded_width🟦",
179 | "type": "INT",
180 | "links": null,
181 | "shape": 3,
182 | "label": "loaded_width🟦"
183 | },
184 | {
185 | "name": "loaded_height🟦",
186 | "type": "INT",
187 | "links": null,
188 | "shape": 3,
189 | "label": "loaded_height🟦"
190 | }
191 | ],
192 | "properties": {
193 | "Node name for S&R": "VHS_VideoInfo"
194 | },
195 | "widgets_values": {}
196 | },
197 | {
198 | "id": 52,
199 | "type": "AniPortrait_Video_Gen_Pose",
200 | "pos": [
201 | -1642,
202 | 1005
203 | ],
204 | "size": {
205 | "0": 361.20001220703125,
206 | "1": 106
207 | },
208 | "flags": {},
209 | "order": 2,
210 | "mode": 0,
211 | "inputs": [
212 | {
213 | "name": "image",
214 | "type": "IMAGE",
215 | "link": 38,
216 | "label": "image"
217 | }
218 | ],
219 | "outputs": [
220 | {
221 | "name": "pose_images",
222 | "type": "IMAGE",
223 | "links": [
224 | 41,
225 | 47
226 | ],
227 | "shape": 3,
228 | "label": "pose_images",
229 | "slot_index": 0
230 | }
231 | ],
232 | "properties": {
233 | "Node name for S&R": "AniPortrait_Video_Gen_Pose"
234 | },
235 | "widgets_values": [
236 | "AniPortrait",
237 | 512,
238 | 512
239 | ]
240 | },
241 | {
242 | "id": 56,
243 | "type": "LoadImage",
244 | "pos": [
245 | -1649,
246 | 633
247 | ],
248 | "size": {
249 | "0": 315,
250 | "1": 314
251 | },
252 | "flags": {},
253 | "order": 1,
254 | "mode": 0,
255 | "outputs": [
256 | {
257 | "name": "IMAGE",
258 | "type": "IMAGE",
259 | "links": [
260 | 45
261 | ],
262 | "shape": 3,
263 | "label": "IMAGE",
264 | "slot_index": 0
265 | },
266 | {
267 | "name": "MASK",
268 | "type": "MASK",
269 | "links": null,
270 | "shape": 3,
271 | "label": "MASK"
272 | }
273 | ],
274 | "properties": {
275 | "Node name for S&R": "LoadImage"
276 | },
277 | "widgets_values": [
278 | "solo.png",
279 | "image"
280 | ]
281 | },
282 | {
283 | "id": 53,
284 | "type": "AniPortrait_Pose_Gen_Video",
285 | "pos": [
286 | -1226,
287 | 797
288 | ],
289 | "size": {
290 | "0": 315,
291 | "1": 462
292 | },
293 | "flags": {},
294 | "order": 5,
295 | "mode": 0,
296 | "inputs": [
297 | {
298 | "name": "ref_image",
299 | "type": "IMAGE",
300 | "link": 45,
301 | "label": "ref_image",
302 | "slot_index": 0
303 | },
304 | {
305 | "name": "pose_images",
306 | "type": "IMAGE",
307 | "link": 41,
308 | "label": "pose_images"
309 | },
310 | {
311 | "name": "frame_count",
312 | "type": "INT",
313 | "link": 42,
314 | "widget": {
315 | "name": "frame_count"
316 | },
317 | "slot_index": 2,
318 | "label": "frame_count"
319 | }
320 | ],
321 | "outputs": [
322 | {
323 | "name": "images",
324 | "type": "IMAGE",
325 | "links": [
326 | 43,
327 | 46
328 | ],
329 | "shape": 3,
330 | "label": "images",
331 | "slot_index": 0
332 | }
333 | ],
334 | "properties": {
335 | "Node name for S&R": "AniPortrait_Pose_Gen_Video"
336 | },
337 | "widgets_values": [
338 | 0,
339 | 512,
340 | 512,
341 | 688,
342 | "randomize",
343 | 3.5,
344 | 25,
345 | "pretrained_model/sd-vae-ft-mse",
346 | "pretrained_model/stable-diffusion-v1-5",
347 | "fp16",
348 | true,
349 | 3,
350 | "pretrained_model/motion_module.pth",
351 | "pretrained_model/image_encoder",
352 | "pretrained_model/denoising_unet.pth",
353 | "pretrained_model/reference_unet.pth",
354 | "pretrained_model/pose_guider.pth"
355 | ]
356 | },
357 | {
358 | "id": 44,
359 | "type": "VHS_VideoCombine",
360 | "pos": [
361 | -877,
362 | 794
363 | ],
364 | "size": [
365 | 315,
366 | 599
367 | ],
368 | "flags": {},
369 | "order": 6,
370 | "mode": 0,
371 | "inputs": [
372 | {
373 | "name": "images",
374 | "type": "IMAGE",
375 | "link": 43,
376 | "label": "images",
377 | "slot_index": 0
378 | },
379 | {
380 | "name": "audio",
381 | "type": "VHS_AUDIO",
382 | "link": 34,
383 | "label": "audio",
384 | "slot_index": 1
385 | },
386 | {
387 | "name": "batch_manager",
388 | "type": "VHS_BatchManager",
389 | "link": null,
390 | "label": "batch_manager"
391 | }
392 | ],
393 | "outputs": [
394 | {
395 | "name": "Filenames",
396 | "type": "VHS_FILENAMES",
397 | "links": null,
398 | "shape": 3,
399 | "label": "Filenames"
400 | }
401 | ],
402 | "properties": {
403 | "Node name for S&R": "VHS_VideoCombine"
404 | },
405 | "widgets_values": {
406 | "frame_rate": 25,
407 | "loop_count": 0,
408 | "filename_prefix": "Aniportrait",
409 | "format": "video/h264-mp4",
410 | "pix_fmt": "yuv420p",
411 | "crf": 19,
412 | "save_metadata": true,
413 | "pingpong": false,
414 | "save_output": true,
415 | "videopreview": {
416 | "hidden": false,
417 | "paused": false,
418 | "params": {
419 | "filename": "Aniportrait_00004-audio.mp4",
420 | "subfolder": "",
421 | "type": "output",
422 | "format": "video/h264-mp4"
423 | }
424 | }
425 | }
426 | },
427 | {
428 | "id": 59,
429 | "type": "PreviewImage",
430 | "pos": [
431 | -1245,
432 | 504
433 | ],
434 | "size": {
435 | "0": 210,
436 | "1": 246
437 | },
438 | "flags": {},
439 | "order": 4,
440 | "mode": 0,
441 | "inputs": [
442 | {
443 | "name": "images",
444 | "type": "IMAGE",
445 | "link": 47,
446 | "label": "images"
447 | }
448 | ],
449 | "properties": {
450 | "Node name for S&R": "PreviewImage"
451 | }
452 | },
453 | {
454 | "id": 58,
455 | "type": "PreviewImage",
456 | "pos": [
457 | -890,
458 | 503
459 | ],
460 | "size": {
461 | "0": 210,
462 | "1": 246
463 | },
464 | "flags": {},
465 | "order": 7,
466 | "mode": 0,
467 | "inputs": [
468 | {
469 | "name": "images",
470 | "type": "IMAGE",
471 | "link": 46,
472 | "label": "images"
473 | }
474 | ],
475 | "properties": {
476 | "Node name for S&R": "PreviewImage"
477 | }
478 | }
479 | ],
480 | "links": [
481 | [
482 | 31,
483 | 42,
484 | 3,
485 | 29,
486 | 0,
487 | "VHS_VIDEOINFO"
488 | ],
489 | [
490 | 34,
491 | 42,
492 | 2,
493 | 44,
494 | 1,
495 | "VHS_AUDIO"
496 | ],
497 | [
498 | 38,
499 | 42,
500 | 0,
501 | 52,
502 | 0,
503 | "IMAGE"
504 | ],
505 | [
506 | 41,
507 | 52,
508 | 0,
509 | 53,
510 | 1,
511 | "IMAGE"
512 | ],
513 | [
514 | 42,
515 | 29,
516 | 1,
517 | 53,
518 | 2,
519 | "INT"
520 | ],
521 | [
522 | 43,
523 | 53,
524 | 0,
525 | 44,
526 | 0,
527 | "IMAGE"
528 | ],
529 | [
530 | 45,
531 | 56,
532 | 0,
533 | 53,
534 | 0,
535 | "IMAGE"
536 | ],
537 | [
538 | 46,
539 | 53,
540 | 0,
541 | 58,
542 | 0,
543 | "IMAGE"
544 | ],
545 | [
546 | 47,
547 | 52,
548 | 0,
549 | 59,
550 | 0,
551 | "IMAGE"
552 | ]
553 | ],
554 | "groups": [],
555 | "config": {},
556 | "extra": {},
557 | "version": 0.4
558 | }
--------------------------------------------------------------------------------
/assets/pose_ref_video.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/assets/pose_ref_video.mp4
--------------------------------------------------------------------------------
/assets/solo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/assets/solo.png
--------------------------------------------------------------------------------
/assets/woman.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/assets/woman.jpg
--------------------------------------------------------------------------------
/configs/inference/inference_audio.yaml:
--------------------------------------------------------------------------------
1 | a2m_model:
2 | out_dim: 1404
3 | latent_dim: 512
4 | model_path: pretrained_model/wav2vec2-base-960h
5 | only_last_fetures: True
6 | from_pretrained: True
7 |
8 | a2p_model:
9 | out_dim: 6
10 | latent_dim: 512
11 | model_path: pretrained_model/wav2vec2-base-960h
12 | only_last_fetures: True
13 | from_pretrained: True
14 |
15 | pretrained_model:
16 | a2m_ckpt: pretrained_model/audio2mesh.pt
17 | a2p_ckpt: pretrained_model/audio2pose.pt
18 |
--------------------------------------------------------------------------------
/configs/inference/inference_v2.yaml:
--------------------------------------------------------------------------------
1 | unet_additional_kwargs:
2 | use_inflated_groupnorm: true
3 | unet_use_cross_frame_attention: false
4 | unet_use_temporal_attention: false
5 | use_motion_module: true
6 | motion_module_resolutions:
7 | - 1
8 | - 2
9 | - 4
10 | - 8
11 | motion_module_mid_block: true
12 | motion_module_decoder_only: false
13 | motion_module_type: Vanilla
14 | motion_module_kwargs:
15 | num_attention_heads: 8
16 | num_transformer_block: 1
17 | attention_block_types:
18 | - Temporal_Self
19 | - Temporal_Self
20 | temporal_position_encoding: true
21 | temporal_position_encoding_max_len: 32
22 | temporal_attention_dim_div: 1
23 |
24 | noise_scheduler_kwargs:
25 | beta_start: 0.00085
26 | beta_end: 0.012
27 | beta_schedule: "linear"
28 | clip_sample: false
29 | steps_offset: 1
30 | ### Zero-SNR params
31 | prediction_type: "v_prediction"
32 | rescale_betas_zero_snr: True
33 | timestep_spacing: "trailing"
34 |
35 | sampler: DDIM
--------------------------------------------------------------------------------
/configs/prompts/animation.yaml:
--------------------------------------------------------------------------------
1 | pretrained_base_model_path: 'pretrained_model/stable-diffusion-v1-5'
2 | pretrained_vae_path: 'pretrained_model/sd-vae-ft-mse'
3 | image_encoder_path: 'pretrained_model/image_encoder'
4 |
5 | denoising_unet_path: "pretrained_model/denoising_unet.pth"
6 | reference_unet_path: "pretrained_model/reference_unet.pth"
7 | pose_guider_path: "pretrained_model/pose_guider.pth"
8 | motion_module_path: "pretrained_model/motion_module.pth"
9 |
10 | inference_config: "configs/inference/inference_v2.yaml"
11 |
--------------------------------------------------------------------------------
/configs/prompts/animation_audio.yaml:
--------------------------------------------------------------------------------
1 | pretrained_base_model_path: 'pretrained_model/stable-diffusion-v1-5'
2 | pretrained_vae_path: 'pretrained_model/sd-vae-ft-mse'
3 | image_encoder_path: 'pretrained_model/image_encoder'
4 |
5 | denoising_unet_path: "pretrained_model/denoising_unet.pth"
6 | reference_unet_path: "pretrained_model/reference_unet.pth"
7 | pose_guider_path: "pretrained_model/pose_guider.pth"
8 | motion_module_path: "pretrained_model/motion_module.pth"
9 |
10 | audio_inference_config: "configs/inference/inference_audio.yaml"
11 | inference_config: "configs/inference/inference_v2.yaml"
12 | weight_dtype: 'fp16'
13 |
14 | #pose_temp: "./configs/inference/head_pose_temp/pose_temp.npy"
15 |
16 | test_cases:
17 | # "./configs/inference/ref_images/lyl.png":
18 | # - "./configs/inference/audio/lyl.wav"
19 |
--------------------------------------------------------------------------------
/configs/prompts/animation_facereenac.yaml:
--------------------------------------------------------------------------------
1 | pretrained_base_model_path: './pretrained_model/stable-diffusion-v1-5'
2 | pretrained_vae_path: './pretrained_model/sd-vae-ft-mse'
3 | image_encoder_path: './pretrained_model/image_encoder'
4 |
5 | denoising_unet_path: "./pretrained_model/denoising_unet.pth"
6 | reference_unet_path: "./pretrained_model/reference_unet.pth"
7 | pose_guider_path: "./pretrained_model/pose_guider.pth"
8 | motion_module_path: "./pretrained_model/motion_module.pth"
9 |
10 | inference_config: "./configs/inference/inference_v2.yaml"
11 | weight_dtype: 'fp16'
12 |
13 | test_cases:
14 | # "./configs/inference/ref_images/Aragaki.png":
15 | # - "./configs/inference/video/Aragaki_song.mp4"
16 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "comfyui_aniportrait"
3 | description = "Unofficial implementation of [a/AniPortrait](https://github.com/Zejun-Yang/AniPortrait) for video generation, including self-driven, face reenactment and audio-driven animation from a reference image"
4 | version = "1.0.0"
5 | license = { file = "LICENSE" }
6 | dependencies = ["mediapipe==0.10.11", "ffmpeg-python==0.2.0", "av==11.0.0", "librosa==0.9.2", "diffusers==0.26.2"]
7 |
8 | [project.urls]
9 | Repository = "https://github.com/frankchieng/ComfyUI_Aniportrait"
10 | # Used by Comfy Registry https://comfyregistry.org
11 |
12 | [tool.comfy]
13 | PublisherId = "frankchieng"
14 | DisplayName = "ComfyUI_Aniportrait"
15 | Icon = ""
16 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | mediapipe==0.10.11
2 | ffmpeg-python==0.2.0
3 | av==11.0.0
4 | librosa==0.9.2
5 | diffusers==0.26.2
6 | omegaconf==2.3.0
7 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | #Dummy file ensuring this package will be recognized
2 |
--------------------------------------------------------------------------------
/src/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/src/audio_models/__pycache__/model.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/audio_models/__pycache__/model.cpython-310.pyc
--------------------------------------------------------------------------------
/src/audio_models/__pycache__/torch_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/audio_models/__pycache__/torch_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/src/audio_models/__pycache__/wav2vec2.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/audio_models/__pycache__/wav2vec2.cpython-310.pyc
--------------------------------------------------------------------------------
/src/audio_models/model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from transformers import Wav2Vec2Config
6 |
7 | from .torch_utils import get_mask_from_lengths
8 | from .wav2vec2 import Wav2Vec2Model
9 |
10 |
11 | class Audio2MeshModel(nn.Module):
12 | def __init__(
13 | self,
14 | config
15 | ):
16 | super().__init__()
17 | out_dim = config['out_dim']
18 | latent_dim = config['latent_dim']
19 | model_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), config['model_path'])
20 | #model_path = config['model_path']
21 | only_last_fetures = config['only_last_fetures']
22 | from_pretrained = config['from_pretrained']
23 |
24 | self._only_last_features = only_last_fetures
25 |
26 | self.audio_encoder_config = Wav2Vec2Config.from_pretrained(model_path, local_files_only=True)
27 | if from_pretrained:
28 | self.audio_encoder = Wav2Vec2Model.from_pretrained(model_path, local_files_only=True)
29 | else:
30 | self.audio_encoder = Wav2Vec2Model(self.audio_encoder_config)
31 | self.audio_encoder.feature_extractor._freeze_parameters()
32 |
33 | hidden_size = self.audio_encoder_config.hidden_size
34 |
35 | self.in_fn = nn.Linear(hidden_size, latent_dim)
36 |
37 | self.out_fn = nn.Linear(latent_dim, out_dim)
38 | nn.init.constant_(self.out_fn.weight, 0)
39 | nn.init.constant_(self.out_fn.bias, 0)
40 |
41 | def forward(self, audio, label, audio_len=None):
42 | attention_mask = ~get_mask_from_lengths(audio_len) if audio_len is not None else None  # avoid truth-testing a multi-element tensor
43 |
44 | seq_len = label.shape[1]
45 |
46 | embeddings = self.audio_encoder(audio, seq_len=seq_len, output_hidden_states=True,
47 | attention_mask=attention_mask)
48 |
49 | if self._only_last_features:
50 | hidden_states = embeddings.last_hidden_state
51 | else:
52 | hidden_states = sum(embeddings.hidden_states) / len(embeddings.hidden_states)
53 |
54 | layer_in = self.in_fn(hidden_states)
55 | out = self.out_fn(layer_in)
56 |
57 | return out, None
58 |
59 | def infer(self, input_value, seq_len):
60 | embeddings = self.audio_encoder(input_value, seq_len=seq_len, output_hidden_states=True)
61 |
62 | if self._only_last_features:
63 | hidden_states = embeddings.last_hidden_state
64 | else:
65 | hidden_states = sum(embeddings.hidden_states) / len(embeddings.hidden_states)
66 |
67 | layer_in = self.in_fn(hidden_states)
68 | out = self.out_fn(layer_in)
69 |
70 | return out
71 |
72 |
73 |
--------------------------------------------------------------------------------
/src/audio_models/pose_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | from transformers import Wav2Vec2Config
6 |
7 | from .torch_utils import get_mask_from_lengths
8 | from .wav2vec2 import Wav2Vec2Model
9 |
10 |
11 | def init_biased_mask(n_head, max_seq_len, period):
12 | def get_slopes(n):
13 | def get_slopes_power_of_2(n):
14 | start = (2**(-2**-(math.log2(n)-3)))
15 | ratio = start
16 | return [start*ratio**i for i in range(n)]
17 | if math.log2(n).is_integer():
18 | return get_slopes_power_of_2(n)
19 | else:
20 | closest_power_of_2 = 2**math.floor(math.log2(n))
21 | return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2*closest_power_of_2)[0::2][:n-closest_power_of_2]
22 | slopes = torch.Tensor(get_slopes(n_head))
23 | bias = torch.arange(start=0, end=max_seq_len, step=period).unsqueeze(1).repeat(1,period).view(-1)//(period)
24 | bias = - torch.flip(bias,dims=[0])
25 | alibi = torch.zeros(max_seq_len, max_seq_len)
26 | for i in range(max_seq_len):
27 | alibi[i, :i+1] = bias[-(i+1):]
28 | alibi = slopes.unsqueeze(1).unsqueeze(1) * alibi.unsqueeze(0)
29 | mask = (torch.triu(torch.ones(max_seq_len, max_seq_len)) == 1).transpose(0, 1)
30 | mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
31 | mask = mask.unsqueeze(0) + alibi
32 | return mask
33 |
34 |
35 | def enc_dec_mask(device, T, S):
36 | mask = torch.ones(T, S)
37 | for i in range(T):
38 | mask[i, i] = 0
39 | return (mask==1).to(device=device)
40 |
41 |
42 | class PositionalEncoding(nn.Module):
43 | def __init__(self, d_model, max_len=600):
44 | super(PositionalEncoding, self).__init__()
45 | pe = torch.zeros(max_len, d_model)
46 | position = torch.arange(0, max_len).unsqueeze(1).float()
47 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
48 | pe[:, 0::2] = torch.sin(position * div_term)
49 | pe[:, 1::2] = torch.cos(position * div_term)
50 | pe = pe.unsqueeze(0)
51 | self.register_buffer('pe', pe)
52 |
53 | def forward(self, x):
54 | x = x + self.pe[:, :x.size(1)]
55 | return x
56 |
57 |
58 | class Audio2PoseModel(nn.Module):
59 | def __init__(
60 | self,
61 | config
62 | ):
63 |
64 | super().__init__()
65 |
66 | latent_dim = config['latent_dim']
67 | model_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), config['model_path'])
68 | #model_path = config['model_path']
69 | only_last_fetures = config['only_last_fetures']
70 | from_pretrained = config['from_pretrained']
71 | out_dim = config['out_dim']
72 |
73 | self.out_dim = out_dim
74 |
75 | self._only_last_features = only_last_fetures
76 |
77 | self.audio_encoder_config = Wav2Vec2Config.from_pretrained(model_path, local_files_only=True)
78 | if from_pretrained:
79 | self.audio_encoder = Wav2Vec2Model.from_pretrained(model_path, local_files_only=True)
80 | else:
81 | self.audio_encoder = Wav2Vec2Model(self.audio_encoder_config)
82 | self.audio_encoder.feature_extractor._freeze_parameters()
83 |
84 | hidden_size = self.audio_encoder_config.hidden_size
85 |
86 | self.pose_map = nn.Linear(out_dim, latent_dim)
87 | self.in_fn = nn.Linear(hidden_size, latent_dim)
88 |
89 | self.PPE = PositionalEncoding(latent_dim)
90 | self.biased_mask = init_biased_mask(n_head = 8, max_seq_len = 600, period=1)
91 | decoder_layer = nn.TransformerDecoderLayer(d_model=latent_dim, nhead=8, dim_feedforward=2*latent_dim, batch_first=True)
92 | self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=8)
93 | self.pose_map_r = nn.Linear(latent_dim, out_dim)
94 |
95 | self.id_embed = nn.Embedding(100, latent_dim) # 100 ids
96 |
97 |
98 | def infer(self, input_value, seq_len, id_seed=None):
99 | embeddings = self.audio_encoder(input_value, seq_len=seq_len, output_hidden_states=True)
100 |
101 | if self._only_last_features:
102 | hidden_states = embeddings.last_hidden_state
103 | else:
104 | hidden_states = sum(embeddings.hidden_states) / len(embeddings.hidden_states)
105 |
106 | hidden_states = self.in_fn(hidden_states)
107 |
108 | id_embedding = self.id_embed(id_seed).unsqueeze(1)
109 |
110 | init_pose = torch.zeros([hidden_states.shape[0], 1, self.out_dim]).to(hidden_states.device)
111 | for i in range(seq_len):
112 | if i==0:
113 | pose_emb = self.pose_map(init_pose)
114 | pose_input = self.PPE(pose_emb)
115 | else:
116 | pose_input = self.PPE(pose_emb)
117 |
118 | pose_input = pose_input + id_embedding
119 | tgt_mask = self.biased_mask[:, :pose_input.shape[1], :pose_input.shape[1]].clone().detach().to(hidden_states.device)
120 | memory_mask = enc_dec_mask(hidden_states.device, pose_input.shape[1], hidden_states.shape[1])
121 | pose_out = self.transformer_decoder(pose_input, hidden_states, tgt_mask=tgt_mask, memory_mask=memory_mask)
122 | pose_out = self.pose_map_r(pose_out)
123 | new_output = self.pose_map(pose_out[:,-1,:]).unsqueeze(1)
124 | pose_emb = torch.cat((pose_emb, new_output), 1)
125 | return pose_out
126 |
127 |
--------------------------------------------------------------------------------
/src/audio_models/torch_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def get_mask_from_lengths(lengths, max_len=None):
6 | lengths = lengths.to(torch.long)
7 | if max_len is None:
8 | max_len = torch.max(lengths).item()
9 |
10 | ids = torch.arange(0, max_len).unsqueeze(0).expand(lengths.shape[0], -1).to(lengths.device)
11 | mask = ids < lengths.unsqueeze(1).expand(-1, max_len)
12 |
13 | return mask
14 |
15 |
16 | def linear_interpolation(features, seq_len):
17 | features = features.transpose(1, 2)
18 | output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear')
19 | return output_features.transpose(1, 2)
20 |
21 |
22 | if __name__ == "__main__":
23 | import numpy as np
24 | mask = ~get_mask_from_lengths(torch.from_numpy(np.array([4,6])))
25 | import pdb; pdb.set_trace()
--------------------------------------------------------------------------------
/src/audio_models/wav2vec2.py:
--------------------------------------------------------------------------------
1 | from transformers import Wav2Vec2Config, Wav2Vec2Model
2 | from transformers.modeling_outputs import BaseModelOutput
3 |
4 | from .torch_utils import linear_interpolation
5 |
6 | # the implementation of Wav2Vec2Model is borrowed from
7 | # https://github.com/huggingface/transformers/blob/HEAD/src/transformers/models/wav2vec2/modeling_wav2vec2.py
8 | # initialize our encoder with the pre-trained wav2vec 2.0 weights.
9 | class Wav2Vec2Model(Wav2Vec2Model):
10 | def __init__(self, config: Wav2Vec2Config):
11 | super().__init__(config)
12 |
13 | def forward(
14 | self,
15 | input_values,
16 | seq_len,
17 | attention_mask=None,
18 | mask_time_indices=None,
19 | output_attentions=None,
20 | output_hidden_states=None,
21 | return_dict=None,
22 | ):
23 | self.config.output_attentions = True
24 |
25 | output_hidden_states = (
26 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
27 | )
28 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
29 |
30 | extract_features = self.feature_extractor(input_values)
31 | extract_features = extract_features.transpose(1, 2)
32 | extract_features = linear_interpolation(extract_features, seq_len=seq_len)
33 |
34 | if attention_mask is not None:
35 | # compute reduced attention_mask corresponding to feature vectors
36 | attention_mask = self._get_feature_vector_attention_mask(
37 | extract_features.shape[1], attention_mask, add_adapter=False
38 | )
39 |
40 | hidden_states, extract_features = self.feature_projection(extract_features)
41 | hidden_states = self._mask_hidden_states(
42 | hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
43 | )
44 |
45 | encoder_outputs = self.encoder(
46 | hidden_states,
47 | attention_mask=attention_mask,
48 | output_attentions=output_attentions,
49 | output_hidden_states=output_hidden_states,
50 | return_dict=return_dict,
51 | )
52 |
53 | hidden_states = encoder_outputs[0]
54 |
55 | if self.adapter is not None:
56 | hidden_states = self.adapter(hidden_states)
57 |
58 | if not return_dict:
59 | return (hidden_states, ) + encoder_outputs[1:]
60 | return BaseModelOutput(
61 | last_hidden_state=hidden_states,
62 | hidden_states=encoder_outputs.hidden_states,
63 | attentions=encoder_outputs.attentions,
64 | )
65 |
66 |
67 | def feature_extract(
68 | self,
69 | input_values,
70 | seq_len,
71 | ):
72 | extract_features = self.feature_extractor(input_values)
73 | extract_features = extract_features.transpose(1, 2)
74 | extract_features = linear_interpolation(extract_features, seq_len=seq_len)
75 |
76 | return extract_features
77 |
78 | def encode(
79 | self,
80 | extract_features,
81 | attention_mask=None,
82 | mask_time_indices=None,
83 | output_attentions=None,
84 | output_hidden_states=None,
85 | return_dict=None,
86 | ):
87 | self.config.output_attentions = True
88 |
89 | output_hidden_states = (
90 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
91 | )
92 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
93 |
94 | if attention_mask is not None:
95 | # compute reduced attention_mask corresponding to feature vectors
96 | attention_mask = self._get_feature_vector_attention_mask(
97 | extract_features.shape[1], attention_mask, add_adapter=False
98 | )
99 |
100 |
101 | hidden_states, extract_features = self.feature_projection(extract_features)
102 | hidden_states = self._mask_hidden_states(
103 | hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
104 | )
105 |
106 | encoder_outputs = self.encoder(
107 | hidden_states,
108 | attention_mask=attention_mask,
109 | output_attentions=output_attentions,
110 | output_hidden_states=output_hidden_states,
111 | return_dict=return_dict,
112 | )
113 |
114 | hidden_states = encoder_outputs[0]
115 |
116 | if self.adapter is not None:
117 | hidden_states = self.adapter(hidden_states)
118 |
119 | if not return_dict:
120 | return (hidden_states, ) + encoder_outputs[1:]
121 | return BaseModelOutput(
122 | last_hidden_state=hidden_states,
123 | hidden_states=encoder_outputs.hidden_states,
124 | attentions=encoder_outputs.attentions,
125 | )
126 |
--------------------------------------------------------------------------------
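
A minimal sketch of driving the patched encoder above, which resamples the audio features to a target video frame count via `seq_len`. The checkpoint name and import path are assumptions for illustration only.

```python
import torch
from src.audio_models.wav2vec2 import Wav2Vec2Model  # path assumed

model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h").eval()

waveform = torch.randn(1, 16000)   # one second of 16 kHz audio
num_frames = 30                    # resample features to the video frame count
with torch.no_grad():
    out = model(waveform, seq_len=num_frames, output_hidden_states=True)

print(out.last_hidden_state.shape)  # torch.Size([1, 30, 768])
print(len(out.hidden_states))       # 13 = embedding output + 12 transformer layers (base model)
```
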
/src/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/models/__init__.py
--------------------------------------------------------------------------------
/src/models/motion_module.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py
2 | import math
3 | from dataclasses import dataclass
4 | from typing import Callable, Optional
5 |
6 | import torch
7 | from diffusers.models.attention import FeedForward
8 | from diffusers.models.attention_processor import Attention, AttnProcessor
9 | from diffusers.utils import BaseOutput
10 | from diffusers.utils.import_utils import is_xformers_available
11 | from einops import rearrange, repeat
12 | from torch import nn
13 |
14 |
15 | def zero_module(module):
16 | # Zero out the parameters of a module and return it.
17 | for p in module.parameters():
18 | p.detach().zero_()
19 | return module
20 |
21 |
22 | @dataclass
23 | class TemporalTransformer3DModelOutput(BaseOutput):
24 | sample: torch.FloatTensor
25 |
26 |
27 | if is_xformers_available():
28 | import xformers
29 | import xformers.ops
30 | else:
31 | xformers = None
32 |
33 |
34 | def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict):
35 | if motion_module_type == "Vanilla":
36 | return VanillaTemporalModule(
37 | in_channels=in_channels,
38 | **motion_module_kwargs,
39 | )
40 | else:
41 | raise ValueError
42 |
43 |
44 | class VanillaTemporalModule(nn.Module):
45 | def __init__(
46 | self,
47 | in_channels,
48 | num_attention_heads=8,
49 | num_transformer_block=2,
50 | attention_block_types=("Temporal_Self", "Temporal_Self"),
51 | cross_frame_attention_mode=None,
52 | temporal_position_encoding=False,
53 | temporal_position_encoding_max_len=24,
54 | temporal_attention_dim_div=1,
55 | zero_initialize=True,
56 | ):
57 | super().__init__()
58 |
59 | self.temporal_transformer = TemporalTransformer3DModel(
60 | in_channels=in_channels,
61 | num_attention_heads=num_attention_heads,
62 | attention_head_dim=in_channels
63 | // num_attention_heads
64 | // temporal_attention_dim_div,
65 | num_layers=num_transformer_block,
66 | attention_block_types=attention_block_types,
67 | cross_frame_attention_mode=cross_frame_attention_mode,
68 | temporal_position_encoding=temporal_position_encoding,
69 | temporal_position_encoding_max_len=temporal_position_encoding_max_len,
70 | )
71 |
72 | if zero_initialize:
73 | self.temporal_transformer.proj_out = zero_module(
74 | self.temporal_transformer.proj_out
75 | )
76 |
77 | def forward(
78 | self,
79 | input_tensor,
80 | temb,
81 | encoder_hidden_states,
82 | attention_mask=None,
83 | anchor_frame_idx=None,
84 | ):
85 | hidden_states = input_tensor
86 | hidden_states = self.temporal_transformer(
87 | hidden_states, encoder_hidden_states, attention_mask
88 | )
89 |
90 | output = hidden_states
91 | return output
92 |
93 |
94 | class TemporalTransformer3DModel(nn.Module):
95 | def __init__(
96 | self,
97 | in_channels,
98 | num_attention_heads,
99 | attention_head_dim,
100 | num_layers,
101 | attention_block_types=(
102 | "Temporal_Self",
103 | "Temporal_Self",
104 | ),
105 | dropout=0.0,
106 | norm_num_groups=32,
107 | cross_attention_dim=768,
108 | activation_fn="geglu",
109 | attention_bias=False,
110 | upcast_attention=False,
111 | cross_frame_attention_mode=None,
112 | temporal_position_encoding=False,
113 | temporal_position_encoding_max_len=24,
114 | ):
115 | super().__init__()
116 |
117 | inner_dim = num_attention_heads * attention_head_dim
118 |
119 | self.norm = torch.nn.GroupNorm(
120 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
121 | )
122 | self.proj_in = nn.Linear(in_channels, inner_dim)
123 |
124 | self.transformer_blocks = nn.ModuleList(
125 | [
126 | TemporalTransformerBlock(
127 | dim=inner_dim,
128 | num_attention_heads=num_attention_heads,
129 | attention_head_dim=attention_head_dim,
130 | attention_block_types=attention_block_types,
131 | dropout=dropout,
132 | norm_num_groups=norm_num_groups,
133 | cross_attention_dim=cross_attention_dim,
134 | activation_fn=activation_fn,
135 | attention_bias=attention_bias,
136 | upcast_attention=upcast_attention,
137 | cross_frame_attention_mode=cross_frame_attention_mode,
138 | temporal_position_encoding=temporal_position_encoding,
139 | temporal_position_encoding_max_len=temporal_position_encoding_max_len,
140 | )
141 | for d in range(num_layers)
142 | ]
143 | )
144 | self.proj_out = nn.Linear(inner_dim, in_channels)
145 |
146 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
147 | assert (
148 | hidden_states.dim() == 5
149 | ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
150 | video_length = hidden_states.shape[2]
151 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
152 |
153 | batch, channel, height, weight = hidden_states.shape
154 | residual = hidden_states
155 |
156 | hidden_states = self.norm(hidden_states)
157 | inner_dim = hidden_states.shape[1]
158 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
159 | batch, height * weight, inner_dim
160 | )
161 | hidden_states = self.proj_in(hidden_states)
162 |
163 | # Transformer Blocks
164 | for block in self.transformer_blocks:
165 | hidden_states = block(
166 | hidden_states,
167 | encoder_hidden_states=encoder_hidden_states,
168 | video_length=video_length,
169 | )
170 |
171 | # output
172 | hidden_states = self.proj_out(hidden_states)
173 | hidden_states = (
174 | hidden_states.reshape(batch, height, weight, inner_dim)
175 | .permute(0, 3, 1, 2)
176 | .contiguous()
177 | )
178 |
179 | output = hidden_states + residual
180 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
181 |
182 | return output
183 |
184 |
185 | class TemporalTransformerBlock(nn.Module):
186 | def __init__(
187 | self,
188 | dim,
189 | num_attention_heads,
190 | attention_head_dim,
191 | attention_block_types=(
192 | "Temporal_Self",
193 | "Temporal_Self",
194 | ),
195 | dropout=0.0,
196 | norm_num_groups=32,
197 | cross_attention_dim=768,
198 | activation_fn="geglu",
199 | attention_bias=False,
200 | upcast_attention=False,
201 | cross_frame_attention_mode=None,
202 | temporal_position_encoding=False,
203 | temporal_position_encoding_max_len=24,
204 | ):
205 | super().__init__()
206 |
207 | attention_blocks = []
208 | norms = []
209 |
210 | for block_name in attention_block_types:
211 | attention_blocks.append(
212 | VersatileAttention(
213 | attention_mode=block_name.split("_")[0],
214 | cross_attention_dim=cross_attention_dim
215 | if block_name.endswith("_Cross")
216 | else None,
217 | query_dim=dim,
218 | heads=num_attention_heads,
219 | dim_head=attention_head_dim,
220 | dropout=dropout,
221 | bias=attention_bias,
222 | upcast_attention=upcast_attention,
223 | cross_frame_attention_mode=cross_frame_attention_mode,
224 | temporal_position_encoding=temporal_position_encoding,
225 | temporal_position_encoding_max_len=temporal_position_encoding_max_len,
226 | )
227 | )
228 | norms.append(nn.LayerNorm(dim))
229 |
230 | self.attention_blocks = nn.ModuleList(attention_blocks)
231 | self.norms = nn.ModuleList(norms)
232 |
233 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
234 | self.ff_norm = nn.LayerNorm(dim)
235 |
236 | def forward(
237 | self,
238 | hidden_states,
239 | encoder_hidden_states=None,
240 | attention_mask=None,
241 | video_length=None,
242 | ):
243 | for attention_block, norm in zip(self.attention_blocks, self.norms):
244 | norm_hidden_states = norm(hidden_states)
245 | hidden_states = (
246 | attention_block(
247 | norm_hidden_states,
248 | encoder_hidden_states=encoder_hidden_states
249 | if attention_block.is_cross_attention
250 | else None,
251 | video_length=video_length,
252 | )
253 | + hidden_states
254 | )
255 |
256 | hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
257 |
258 | output = hidden_states
259 | return output
260 |
261 |
262 | class PositionalEncoding(nn.Module):
263 | def __init__(self, d_model, dropout=0.0, max_len=24):
264 | super().__init__()
265 | self.dropout = nn.Dropout(p=dropout)
266 | position = torch.arange(max_len).unsqueeze(1)
267 | div_term = torch.exp(
268 | torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
269 | )
270 | pe = torch.zeros(1, max_len, d_model)
271 | pe[0, :, 0::2] = torch.sin(position * div_term)
272 | pe[0, :, 1::2] = torch.cos(position * div_term)
273 | self.register_buffer("pe", pe)
274 |
275 | def forward(self, x):
276 | x = x + self.pe[:, : x.size(1)]
277 | return self.dropout(x)
278 |
279 |
280 | class VersatileAttention(Attention):
281 | def __init__(
282 | self,
283 | attention_mode=None,
284 | cross_frame_attention_mode=None,
285 | temporal_position_encoding=False,
286 | temporal_position_encoding_max_len=24,
287 | *args,
288 | **kwargs,
289 | ):
290 | super().__init__(*args, **kwargs)
291 | assert attention_mode == "Temporal"
292 |
293 | self.attention_mode = attention_mode
294 | self.is_cross_attention = kwargs["cross_attention_dim"] is not None
295 |
296 | self.pos_encoder = (
297 | PositionalEncoding(
298 | kwargs["query_dim"],
299 | dropout=0.0,
300 | max_len=temporal_position_encoding_max_len,
301 | )
302 | if (temporal_position_encoding and attention_mode == "Temporal")
303 | else None
304 | )
305 |
306 | def extra_repr(self):
307 | return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
308 |
309 | def set_use_memory_efficient_attention_xformers(
310 | self,
311 | use_memory_efficient_attention_xformers: bool,
312 | attention_op: Optional[Callable] = None,
313 | ):
314 | if use_memory_efficient_attention_xformers:
315 | if not is_xformers_available():
316 | raise ModuleNotFoundError(
317 | (
318 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
319 | " xformers"
320 | ),
321 | name="xformers",
322 | )
323 | elif not torch.cuda.is_available():
324 | raise ValueError(
325 | "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
326 | " only available for GPU "
327 | )
328 | else:
329 | try:
330 | # Make sure we can run the memory efficient attention
331 | _ = xformers.ops.memory_efficient_attention(
332 | torch.randn((1, 2, 40), device="cuda"),
333 | torch.randn((1, 2, 40), device="cuda"),
334 | torch.randn((1, 2, 40), device="cuda"),
335 | )
336 | except Exception as e:
337 | raise e
338 |
339 |             # XFormersAttnProcessor corrupts video generation and only works with PyTorch 1.13.
340 |             # In PyTorch 2.0.1 the default AttnProcessor behaves the same as XFormersAttnProcessor
341 |             # did in PyTorch 1.13, so XFormersAttnProcessor is not needed here.
342 | # processor = XFormersAttnProcessor(
343 | # attention_op=attention_op,
344 | # )
345 | processor = AttnProcessor()
346 | else:
347 | processor = AttnProcessor()
348 |
349 | self.set_processor(processor)
350 |
351 | def forward(
352 | self,
353 | hidden_states,
354 | encoder_hidden_states=None,
355 | attention_mask=None,
356 | video_length=None,
357 | **cross_attention_kwargs,
358 | ):
359 | if self.attention_mode == "Temporal":
360 | d = hidden_states.shape[1] # d means HxW
361 | hidden_states = rearrange(
362 | hidden_states, "(b f) d c -> (b d) f c", f=video_length
363 | )
364 |
365 | if self.pos_encoder is not None:
366 | hidden_states = self.pos_encoder(hidden_states)
367 |
368 | encoder_hidden_states = (
369 | repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d)
370 | if encoder_hidden_states is not None
371 | else encoder_hidden_states
372 | )
373 |
374 | else:
375 | raise NotImplementedError
376 |
377 | hidden_states = self.processor(
378 | self,
379 | hidden_states,
380 | encoder_hidden_states=encoder_hidden_states,
381 | attention_mask=attention_mask,
382 | **cross_attention_kwargs,
383 | )
384 |
385 | if self.attention_mode == "Temporal":
386 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
387 |
388 | return hidden_states
389 |
--------------------------------------------------------------------------------
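
For reference, a small self-contained sketch that runs the vanilla motion module above on a 5-D latent. The channel count, frame count, and kwargs are illustrative values, not the ones used by the AniPortrait configs.

```python
import torch
from src.models.motion_module import get_motion_module  # path assumed

mm = get_motion_module(
    in_channels=320,
    motion_module_type="Vanilla",
    motion_module_kwargs=dict(
        num_attention_heads=8,
        num_transformer_block=1,
        temporal_position_encoding=True,
        temporal_position_encoding_max_len=24,
    ),
)

x = torch.randn(1, 320, 8, 32, 32)  # (batch, channels, frames, height, width)
out = mm(x, temb=None, encoder_hidden_states=None)
print(out.shape)                    # torch.Size([1, 320, 8, 32, 32])
# With zero-initialized proj_out, the module starts out as an identity mapping.
```
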
/src/models/mutual_self_attention.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/models/mutual_self_attention.py
2 | from typing import Any, Dict, Optional
3 |
4 | import torch
5 | from einops import rearrange
6 |
7 | from .attention import TemporalBasicTransformerBlock
8 |
9 | from .attention import BasicTransformerBlock
10 |
11 |
12 | def torch_dfs(model: torch.nn.Module):
13 | result = [model]
14 | for child in model.children():
15 | result += torch_dfs(child)
16 | return result
17 |
18 |
19 | class ReferenceAttentionControl:
20 | def __init__(
21 | self,
22 | unet,
23 | mode="write",
24 | do_classifier_free_guidance=False,
25 | attention_auto_machine_weight=float("inf"),
26 | gn_auto_machine_weight=1.0,
27 | style_fidelity=1.0,
28 | reference_attn=True,
29 | reference_adain=False,
30 | fusion_blocks="midup",
31 | batch_size=1,
32 | ) -> None:
33 | # 10. Modify self attention and group norm
34 | self.unet = unet
35 | assert mode in ["read", "write"]
36 | assert fusion_blocks in ["midup", "full"]
37 | self.reference_attn = reference_attn
38 | self.reference_adain = reference_adain
39 | self.fusion_blocks = fusion_blocks
40 | self.register_reference_hooks(
41 | mode,
42 | do_classifier_free_guidance,
43 | attention_auto_machine_weight,
44 | gn_auto_machine_weight,
45 | style_fidelity,
46 | reference_attn,
47 | reference_adain,
48 | fusion_blocks,
49 | batch_size=batch_size,
50 | )
51 |
52 | def register_reference_hooks(
53 | self,
54 | mode,
55 | do_classifier_free_guidance,
56 | attention_auto_machine_weight,
57 | gn_auto_machine_weight,
58 | style_fidelity,
59 | reference_attn,
60 | reference_adain,
61 | dtype=torch.float16,
62 | batch_size=1,
63 | num_images_per_prompt=1,
64 | device=torch.device("cpu"),
65 | fusion_blocks="midup",
66 | ):
67 | MODE = mode
68 | do_classifier_free_guidance = do_classifier_free_guidance
69 | attention_auto_machine_weight = attention_auto_machine_weight
70 | gn_auto_machine_weight = gn_auto_machine_weight
71 | style_fidelity = style_fidelity
72 | reference_attn = reference_attn
73 | reference_adain = reference_adain
74 | fusion_blocks = fusion_blocks
75 | num_images_per_prompt = num_images_per_prompt
76 | dtype = dtype
77 | if do_classifier_free_guidance:
78 | uc_mask = (
79 | torch.Tensor(
80 | [1] * batch_size * num_images_per_prompt * 16
81 | + [0] * batch_size * num_images_per_prompt * 16
82 | )
83 | .to(device)
84 | .bool()
85 | )
86 | else:
87 | uc_mask = (
88 | torch.Tensor([0] * batch_size * num_images_per_prompt * 2)
89 | .to(device)
90 | .bool()
91 | )
92 |
93 | def hacked_basic_transformer_inner_forward(
94 | self,
95 | hidden_states: torch.FloatTensor,
96 | attention_mask: Optional[torch.FloatTensor] = None,
97 | encoder_hidden_states: Optional[torch.FloatTensor] = None,
98 | encoder_attention_mask: Optional[torch.FloatTensor] = None,
99 | timestep: Optional[torch.LongTensor] = None,
100 | cross_attention_kwargs: Dict[str, Any] = None,
101 | class_labels: Optional[torch.LongTensor] = None,
102 | video_length=None,
103 | ):
104 | if self.use_ada_layer_norm: # False
105 | norm_hidden_states = self.norm1(hidden_states, timestep)
106 | elif self.use_ada_layer_norm_zero:
107 | (
108 | norm_hidden_states,
109 | gate_msa,
110 | shift_mlp,
111 | scale_mlp,
112 | gate_mlp,
113 | ) = self.norm1(
114 | hidden_states,
115 | timestep,
116 | class_labels,
117 | hidden_dtype=hidden_states.dtype,
118 | )
119 | else:
120 | norm_hidden_states = self.norm1(hidden_states)
121 |
122 | # 1. Self-Attention
123 | # self.only_cross_attention = False
124 | cross_attention_kwargs = (
125 | cross_attention_kwargs if cross_attention_kwargs is not None else {}
126 | )
127 | if self.only_cross_attention:
128 | attn_output = self.attn1(
129 | norm_hidden_states,
130 | encoder_hidden_states=encoder_hidden_states
131 | if self.only_cross_attention
132 | else None,
133 | attention_mask=attention_mask,
134 | **cross_attention_kwargs,
135 | )
136 | else:
137 | if MODE == "write":
138 | self.bank.append(norm_hidden_states.clone())
139 | attn_output = self.attn1(
140 | norm_hidden_states,
141 | encoder_hidden_states=encoder_hidden_states
142 | if self.only_cross_attention
143 | else None,
144 | attention_mask=attention_mask,
145 | **cross_attention_kwargs,
146 | )
147 | if MODE == "read":
148 | bank_fea = [
149 | rearrange(
150 | d.unsqueeze(1).repeat(1, video_length, 1, 1),
151 | "b t l c -> (b t) l c",
152 | )
153 | for d in self.bank
154 | ]
155 | modify_norm_hidden_states = torch.cat(
156 | [norm_hidden_states] + bank_fea, dim=1
157 | )
158 | hidden_states_uc = (
159 | self.attn1(
160 | norm_hidden_states,
161 | encoder_hidden_states=modify_norm_hidden_states,
162 | attention_mask=attention_mask,
163 | )
164 | + hidden_states
165 | )
166 | if do_classifier_free_guidance:
167 | hidden_states_c = hidden_states_uc.clone()
168 | _uc_mask = uc_mask.clone()
169 | if hidden_states.shape[0] != _uc_mask.shape[0]:
170 | _uc_mask = (
171 | torch.Tensor(
172 | [1] * (hidden_states.shape[0] // 2)
173 | + [0] * (hidden_states.shape[0] // 2)
174 | )
175 | .to(device)
176 | .bool()
177 | )
178 | hidden_states_c[_uc_mask] = (
179 | self.attn1(
180 | norm_hidden_states[_uc_mask],
181 | encoder_hidden_states=norm_hidden_states[_uc_mask],
182 | attention_mask=attention_mask,
183 | )
184 | + hidden_states[_uc_mask]
185 | )
186 | hidden_states = hidden_states_c.clone()
187 | else:
188 | hidden_states = hidden_states_uc
189 |
190 | # self.bank.clear()
191 | if self.attn2 is not None:
192 | # Cross-Attention
193 | norm_hidden_states = (
194 | self.norm2(hidden_states, timestep)
195 | if self.use_ada_layer_norm
196 | else self.norm2(hidden_states)
197 | )
198 | hidden_states = (
199 | self.attn2(
200 | norm_hidden_states,
201 | encoder_hidden_states=encoder_hidden_states,
202 | attention_mask=attention_mask,
203 | )
204 | + hidden_states
205 | )
206 |
207 | # Feed-forward
208 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
209 |
210 | # Temporal-Attention
211 | if self.unet_use_temporal_attention:
212 | d = hidden_states.shape[1]
213 | hidden_states = rearrange(
214 | hidden_states, "(b f) d c -> (b d) f c", f=video_length
215 | )
216 | norm_hidden_states = (
217 | self.norm_temp(hidden_states, timestep)
218 | if self.use_ada_layer_norm
219 | else self.norm_temp(hidden_states)
220 | )
221 | hidden_states = (
222 | self.attn_temp(norm_hidden_states) + hidden_states
223 | )
224 | hidden_states = rearrange(
225 | hidden_states, "(b d) f c -> (b f) d c", d=d
226 | )
227 |
228 | return hidden_states
229 |
230 | if self.use_ada_layer_norm_zero:
231 | attn_output = gate_msa.unsqueeze(1) * attn_output
232 | hidden_states = attn_output + hidden_states
233 |
234 | if self.attn2 is not None:
235 | norm_hidden_states = (
236 | self.norm2(hidden_states, timestep)
237 | if self.use_ada_layer_norm
238 | else self.norm2(hidden_states)
239 | )
240 |
241 | # 2. Cross-Attention
242 | attn_output = self.attn2(
243 | norm_hidden_states,
244 | encoder_hidden_states=encoder_hidden_states,
245 | attention_mask=encoder_attention_mask,
246 | **cross_attention_kwargs,
247 | )
248 | hidden_states = attn_output + hidden_states
249 |
250 | # 3. Feed-forward
251 | norm_hidden_states = self.norm3(hidden_states)
252 |
253 | if self.use_ada_layer_norm_zero:
254 | norm_hidden_states = (
255 | norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
256 | )
257 |
258 | ff_output = self.ff(norm_hidden_states)
259 |
260 | if self.use_ada_layer_norm_zero:
261 | ff_output = gate_mlp.unsqueeze(1) * ff_output
262 |
263 | hidden_states = ff_output + hidden_states
264 |
265 | return hidden_states
266 |
267 | if self.reference_attn:
268 | if self.fusion_blocks == "midup":
269 | attn_modules = [
270 | module
271 | for module in (
272 | torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
273 | )
274 | if isinstance(module, BasicTransformerBlock)
275 | or isinstance(module, TemporalBasicTransformerBlock)
276 | ]
277 | elif self.fusion_blocks == "full":
278 | attn_modules = [
279 | module
280 | for module in torch_dfs(self.unet)
281 | if isinstance(module, BasicTransformerBlock)
282 | or isinstance(module, TemporalBasicTransformerBlock)
283 | ]
284 | attn_modules = sorted(
285 | attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
286 | )
287 |
288 | for i, module in enumerate(attn_modules):
289 | module._original_inner_forward = module.forward
290 | if isinstance(module, BasicTransformerBlock):
291 | module.forward = hacked_basic_transformer_inner_forward.__get__(
292 | module, BasicTransformerBlock
293 | )
294 | if isinstance(module, TemporalBasicTransformerBlock):
295 | module.forward = hacked_basic_transformer_inner_forward.__get__(
296 | module, TemporalBasicTransformerBlock
297 | )
298 |
299 | module.bank = []
300 | module.attn_weight = float(i) / float(len(attn_modules))
301 |
302 | def update(self, writer, dtype=torch.float16):
303 | if self.reference_attn:
304 | if self.fusion_blocks == "midup":
305 | reader_attn_modules = [
306 | module
307 | for module in (
308 | torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
309 | )
310 | if isinstance(module, TemporalBasicTransformerBlock)
311 | ]
312 | writer_attn_modules = [
313 | module
314 | for module in (
315 | torch_dfs(writer.unet.mid_block)
316 | + torch_dfs(writer.unet.up_blocks)
317 | )
318 | if isinstance(module, BasicTransformerBlock)
319 | ]
320 | elif self.fusion_blocks == "full":
321 | reader_attn_modules = [
322 | module
323 | for module in torch_dfs(self.unet)
324 | if isinstance(module, TemporalBasicTransformerBlock)
325 | ]
326 | writer_attn_modules = [
327 | module
328 | for module in torch_dfs(writer.unet)
329 | if isinstance(module, BasicTransformerBlock)
330 | ]
331 | reader_attn_modules = sorted(
332 | reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
333 | )
334 | writer_attn_modules = sorted(
335 | writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
336 | )
337 | for r, w in zip(reader_attn_modules, writer_attn_modules):
338 | r.bank = [v.clone().to(dtype) for v in w.bank]
339 | # w.bank.clear()
340 |
341 | def clear(self):
342 | if self.reference_attn:
343 | if self.fusion_blocks == "midup":
344 | reader_attn_modules = [
345 | module
346 | for module in (
347 | torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
348 | )
349 | if isinstance(module, BasicTransformerBlock)
350 | or isinstance(module, TemporalBasicTransformerBlock)
351 | ]
352 | elif self.fusion_blocks == "full":
353 | reader_attn_modules = [
354 | module
355 | for module in torch_dfs(self.unet)
356 | if isinstance(module, BasicTransformerBlock)
357 | or isinstance(module, TemporalBasicTransformerBlock)
358 | ]
359 | reader_attn_modules = sorted(
360 | reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
361 | )
362 | for r in reader_attn_modules:
363 | r.bank.clear()
364 |
--------------------------------------------------------------------------------
/src/models/pose_guider.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.init as init
5 | from einops import rearrange
6 | import numpy as np
7 | from diffusers.models.modeling_utils import ModelMixin
8 |
9 | from typing import Any, Dict, Optional
10 | from .attention import BasicTransformerBlock
11 |
12 |
13 | class PoseGuider(ModelMixin):
14 | def __init__(self, noise_latent_channels=320, use_ca=True):
15 | super(PoseGuider, self).__init__()
16 |
17 | self.use_ca = use_ca
18 |
19 | self.conv_layers = nn.Sequential(
20 | nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1),
21 | nn.BatchNorm2d(3),
22 | nn.ReLU(),
23 | nn.Conv2d(in_channels=3, out_channels=16, kernel_size=4, stride=2, padding=1),
24 | nn.BatchNorm2d(16),
25 | nn.ReLU(),
26 |
27 | nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1),
28 | nn.BatchNorm2d(16),
29 | nn.ReLU(),
30 | nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=2, padding=1),
31 | nn.BatchNorm2d(32),
32 | nn.ReLU(),
33 |
34 | nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1),
35 | nn.BatchNorm2d(32),
36 | nn.ReLU(),
37 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1),
38 | nn.BatchNorm2d(64),
39 | nn.ReLU(),
40 |
41 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
42 | nn.BatchNorm2d(64),
43 | nn.ReLU(),
44 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
45 | nn.BatchNorm2d(128),
46 | nn.ReLU()
47 | )
48 |
49 | # Final projection layer
50 | self.final_proj = nn.Conv2d(in_channels=128, out_channels=noise_latent_channels, kernel_size=1)
51 |
52 | self.conv_layers_1 = nn.Sequential(
53 | nn.Conv2d(in_channels=noise_latent_channels, out_channels=noise_latent_channels, kernel_size=3, padding=1),
54 | nn.BatchNorm2d(noise_latent_channels),
55 | nn.ReLU(),
56 | nn.Conv2d(in_channels=noise_latent_channels, out_channels=noise_latent_channels, kernel_size=3, stride=2, padding=1),
57 | nn.BatchNorm2d(noise_latent_channels),
58 | nn.ReLU(),
59 | )
60 |
61 | self.conv_layers_2 = nn.Sequential(
62 | nn.Conv2d(in_channels=noise_latent_channels, out_channels=noise_latent_channels, kernel_size=3, padding=1),
63 | nn.BatchNorm2d(noise_latent_channels),
64 | nn.ReLU(),
65 | nn.Conv2d(in_channels=noise_latent_channels, out_channels=noise_latent_channels*2, kernel_size=3, stride=2, padding=1),
66 | nn.BatchNorm2d(noise_latent_channels*2),
67 | nn.ReLU(),
68 | )
69 |
70 | self.conv_layers_3 = nn.Sequential(
71 | nn.Conv2d(in_channels=noise_latent_channels*2, out_channels=noise_latent_channels*2, kernel_size=3, padding=1),
72 | nn.BatchNorm2d(noise_latent_channels*2),
73 | nn.ReLU(),
74 | nn.Conv2d(in_channels=noise_latent_channels*2, out_channels=noise_latent_channels*4, kernel_size=3, stride=2, padding=1),
75 | nn.BatchNorm2d(noise_latent_channels*4),
76 | nn.ReLU(),
77 | )
78 |
79 | self.conv_layers_4 = nn.Sequential(
80 | nn.Conv2d(in_channels=noise_latent_channels*4, out_channels=noise_latent_channels*4, kernel_size=3, padding=1),
81 | nn.BatchNorm2d(noise_latent_channels*4),
82 | nn.ReLU(),
83 | )
84 |
85 | if self.use_ca:
86 | self.cross_attn1 = Transformer2DModel(in_channels=noise_latent_channels)
87 | self.cross_attn2 = Transformer2DModel(in_channels=noise_latent_channels*2)
88 | self.cross_attn3 = Transformer2DModel(in_channels=noise_latent_channels*4)
89 | self.cross_attn4 = Transformer2DModel(in_channels=noise_latent_channels*4)
90 |
91 | # Initialize layers
92 | self._initialize_weights()
93 |
94 | self.scale = nn.Parameter(torch.ones(1) * 2)
95 |
96 | # def _initialize_weights(self):
97 | # # Initialize weights with Gaussian distribution and zero out the final layer
98 | # for m in self.conv_layers:
99 | # if isinstance(m, nn.Conv2d):
100 | # init.normal_(m.weight, mean=0.0, std=0.02)
101 | # if m.bias is not None:
102 | # init.zeros_(m.bias)
103 |
104 | # init.zeros_(self.final_proj.weight)
105 | # if self.final_proj.bias is not None:
106 | # init.zeros_(self.final_proj.bias)
107 |
108 | def _initialize_weights(self):
109 | # Initialize weights with He initialization and zero out the biases
110 | conv_blocks = [self.conv_layers, self.conv_layers_1, self.conv_layers_2, self.conv_layers_3, self.conv_layers_4]
111 | for block_item in conv_blocks:
112 | for m in block_item:
113 | if isinstance(m, nn.Conv2d):
114 | n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
115 | init.normal_(m.weight, mean=0.0, std=np.sqrt(2. / n))
116 | if m.bias is not None:
117 | init.zeros_(m.bias)
118 |
119 | # For the final projection layer, initialize weights to zero (or you may choose to use He initialization here as well)
120 | init.zeros_(self.final_proj.weight)
121 | if self.final_proj.bias is not None:
122 | init.zeros_(self.final_proj.bias)
123 |
124 | def forward(self, x, ref_x):
125 | fea = []
126 | b = x.shape[0]
127 |
128 | x = rearrange(x, "b c f h w -> (b f) c h w")
129 | x = self.conv_layers(x)
130 | x = self.final_proj(x)
131 | x = x * self.scale
132 | # x = rearrange(x, "(b f) c h w -> b c f h w", b=b)
133 | fea.append(rearrange(x, "(b f) c h w -> b c f h w", b=b))
134 |
135 | x = self.conv_layers_1(x)
136 | if self.use_ca:
137 | ref_x = self.conv_layers(ref_x)
138 | ref_x = self.final_proj(ref_x)
139 | ref_x = ref_x * self.scale
140 | ref_x = self.conv_layers_1(ref_x)
141 | x = self.cross_attn1(x, ref_x)
142 | fea.append(rearrange(x, "(b f) c h w -> b c f h w", b=b))
143 |
144 | x = self.conv_layers_2(x)
145 | if self.use_ca:
146 | ref_x = self.conv_layers_2(ref_x)
147 | x = self.cross_attn2(x, ref_x)
148 | fea.append(rearrange(x, "(b f) c h w -> b c f h w", b=b))
149 |
150 | x = self.conv_layers_3(x)
151 | if self.use_ca:
152 | ref_x = self.conv_layers_3(ref_x)
153 | x = self.cross_attn3(x, ref_x)
154 | fea.append(rearrange(x, "(b f) c h w -> b c f h w", b=b))
155 |
156 | x = self.conv_layers_4(x)
157 | if self.use_ca:
158 | ref_x = self.conv_layers_4(ref_x)
159 | x = self.cross_attn4(x, ref_x)
160 | fea.append(rearrange(x, "(b f) c h w -> b c f h w", b=b))
161 |
162 | return fea
163 |
164 | @classmethod
165 |     def from_pretrained(cls, pretrained_model_path):
166 |         if not os.path.exists(pretrained_model_path):
167 |             raise FileNotFoundError(f"There is no model file in {pretrained_model_path}")
168 |         print(f"Loading PoseGuider's pretrained weights from {pretrained_model_path} ...")
169 |
170 | state_dict = torch.load(pretrained_model_path, map_location="cpu")
171 |         model = cls(noise_latent_channels=320)
172 |
173 | m, u = model.load_state_dict(state_dict, strict=True)
174 | # print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
175 | params = [p.numel() for n, p in model.named_parameters()]
176 | print(f"### PoseGuider's Parameters: {sum(params) / 1e6} M")
177 |
178 | return model
179 |
180 |
181 | class Transformer2DModel(ModelMixin):
182 | _supports_gradient_checkpointing = True
183 | def __init__(
184 | self,
185 | num_attention_heads: int = 16,
186 | attention_head_dim: int = 88,
187 | in_channels: Optional[int] = None,
188 | num_layers: int = 1,
189 | dropout: float = 0.0,
190 | norm_num_groups: int = 32,
191 | cross_attention_dim: Optional[int] = None,
192 | attention_bias: bool = False,
193 | activation_fn: str = "geglu",
194 | num_embeds_ada_norm: Optional[int] = None,
195 | use_linear_projection: bool = False,
196 | only_cross_attention: bool = False,
197 | double_self_attention: bool = False,
198 | upcast_attention: bool = False,
199 | norm_type: str = "layer_norm",
200 | norm_elementwise_affine: bool = True,
201 | norm_eps: float = 1e-5,
202 | attention_type: str = "default",
203 | ):
204 | super().__init__()
205 | self.use_linear_projection = use_linear_projection
206 | self.num_attention_heads = num_attention_heads
207 | self.attention_head_dim = attention_head_dim
208 | inner_dim = num_attention_heads * attention_head_dim
209 |
210 | self.in_channels = in_channels
211 |
212 | self.norm = torch.nn.GroupNorm(
213 | num_groups=norm_num_groups,
214 | num_channels=in_channels,
215 | eps=1e-6,
216 | affine=True,
217 | )
218 | if use_linear_projection:
219 | self.proj_in = nn.Linear(in_channels, inner_dim)
220 | else:
221 | self.proj_in = nn.Conv2d(
222 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0
223 | )
224 |
225 | # 3. Define transformers blocks
226 | self.transformer_blocks = nn.ModuleList(
227 | [
228 | BasicTransformerBlock(
229 | inner_dim,
230 | num_attention_heads,
231 | attention_head_dim,
232 | dropout=dropout,
233 | cross_attention_dim=cross_attention_dim,
234 | activation_fn=activation_fn,
235 | num_embeds_ada_norm=num_embeds_ada_norm,
236 | attention_bias=attention_bias,
237 | only_cross_attention=only_cross_attention,
238 | double_self_attention=double_self_attention,
239 | upcast_attention=upcast_attention,
240 | norm_type=norm_type,
241 | norm_elementwise_affine=norm_elementwise_affine,
242 | norm_eps=norm_eps,
243 | attention_type=attention_type,
244 | )
245 | for d in range(num_layers)
246 | ]
247 | )
248 |
249 | if use_linear_projection:
250 | self.proj_out = nn.Linear(inner_dim, in_channels)
251 | else:
252 | self.proj_out = nn.Conv2d(
253 | inner_dim, in_channels, kernel_size=1, stride=1, padding=0
254 | )
255 |
256 | self.gradient_checkpointing = False
257 |
258 | def _set_gradient_checkpointing(self, module, value=False):
259 | if hasattr(module, "gradient_checkpointing"):
260 | module.gradient_checkpointing = value
261 |
262 | def forward(
263 | self,
264 | hidden_states: torch.Tensor,
265 | encoder_hidden_states: Optional[torch.Tensor] = None,
266 | timestep: Optional[torch.LongTensor] = None,
267 | ):
268 | batch, _, height, width = hidden_states.shape
269 | residual = hidden_states
270 |
271 | hidden_states = self.norm(hidden_states)
272 | if not self.use_linear_projection:
273 | hidden_states = self.proj_in(hidden_states)
274 | inner_dim = hidden_states.shape[1]
275 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
276 | batch, height * width, inner_dim
277 | )
278 | else:
279 | inner_dim = hidden_states.shape[1]
280 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
281 | batch, height * width, inner_dim
282 | )
283 | hidden_states = self.proj_in(hidden_states)
284 |
285 | for block in self.transformer_blocks:
286 | hidden_states = block(
287 | hidden_states,
288 | encoder_hidden_states=encoder_hidden_states,
289 | timestep=timestep,
290 | )
291 |
292 | if not self.use_linear_projection:
293 | hidden_states = (
294 | hidden_states.reshape(batch, height, width, inner_dim)
295 | .permute(0, 3, 1, 2)
296 | .contiguous()
297 | )
298 | hidden_states = self.proj_out(hidden_states)
299 | else:
300 | hidden_states = self.proj_out(hidden_states)
301 | hidden_states = (
302 | hidden_states.reshape(batch, height, width, inner_dim)
303 | .permute(0, 3, 1, 2)
304 | .contiguous()
305 | )
306 |
307 | output = hidden_states + residual
308 | return output
309 |
310 |
311 | if __name__ == '__main__':
312 | model = PoseGuider(noise_latent_channels=320).to(device="cuda")
313 |
314 | input_data = torch.randn(1,3,1,512,512).to(device="cuda")
315 | input_data1 = torch.randn(1,3,512,512).to(device="cuda")
316 |
317 | output = model(input_data, input_data1)
318 | for item in output:
319 | print(item.shape)
320 |
321 | # tf_model = Transformer2DModel(
322 | # in_channels=320
323 | # ).to('cuda')
324 |
325 | # input_data = torch.randn(4,320,32,32).to(device="cuda")
326 | # # input_emb = torch.randn(4,1,768).to(device="cuda")
327 | # input_emb = torch.randn(4,320,32,32).to(device="cuda")
328 | # o1 = tf_model(input_data, input_emb)
329 | # print(o1.shape)
330 |
--------------------------------------------------------------------------------
/src/models/resnet.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from einops import rearrange
7 | from typing import Dict, Optional
8 |
9 |
10 | class InflatedConv3d(nn.Conv2d):
11 | def forward(self, x):
12 | video_length = x.shape[2]
13 |
14 | x = rearrange(x, "b c f h w -> (b f) c h w")
15 | x = super().forward(x)
16 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
17 |
18 | return x
19 |
20 |
21 | class InflatedGroupNorm(nn.GroupNorm):
22 | def forward(self, x):
23 | video_length = x.shape[2]
24 |
25 | x = rearrange(x, "b c f h w -> (b f) c h w")
26 | x = super().forward(x)
27 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
28 |
29 | return x
30 |
31 |
32 | class Upsample3D(nn.Module):
33 | def __init__(
34 | self,
35 | channels,
36 | use_conv=False,
37 | use_conv_transpose=False,
38 | out_channels=None,
39 | name="conv",
40 | ):
41 | super().__init__()
42 | self.channels = channels
43 | self.out_channels = out_channels or channels
44 | self.use_conv = use_conv
45 | self.use_conv_transpose = use_conv_transpose
46 | self.name = name
47 |
48 | conv = None
49 | if use_conv_transpose:
50 | raise NotImplementedError
51 | elif use_conv:
52 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
53 |
54 | def forward(self, hidden_states, output_size=None):
55 | assert hidden_states.shape[1] == self.channels
56 |
57 | if self.use_conv_transpose:
58 | raise NotImplementedError
59 |
60 |         # Cast to float32, as the 'upsample_nearest2d_out_frame' op does not support bfloat16
61 | dtype = hidden_states.dtype
62 | if dtype == torch.bfloat16:
63 | hidden_states = hidden_states.to(torch.float32)
64 |
65 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
66 | if hidden_states.shape[0] >= 64:
67 | hidden_states = hidden_states.contiguous()
68 |
69 | # if `output_size` is passed we force the interpolation output
70 | # size and do not make use of `scale_factor=2`
71 | if output_size is None:
72 | hidden_states = F.interpolate(
73 | hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest"
74 | )
75 | else:
76 | hidden_states = F.interpolate(
77 | hidden_states, size=output_size, mode="nearest"
78 | )
79 |
80 | # If the input is bfloat16, we cast back to bfloat16
81 | if dtype == torch.bfloat16:
82 | hidden_states = hidden_states.to(dtype)
83 |
84 | # if self.use_conv:
85 | # if self.name == "conv":
86 | # hidden_states = self.conv(hidden_states)
87 | # else:
88 | # hidden_states = self.Conv2d_0(hidden_states)
89 | hidden_states = self.conv(hidden_states)
90 |
91 | return hidden_states
92 |
93 |
94 | class Downsample3D(nn.Module):
95 | def __init__(
96 | self, channels, use_conv=False, out_channels=None, padding=1, name="conv"
97 | ):
98 | super().__init__()
99 | self.channels = channels
100 | self.out_channels = out_channels or channels
101 | self.use_conv = use_conv
102 | self.padding = padding
103 | stride = 2
104 | self.name = name
105 |
106 | if use_conv:
107 | self.conv = InflatedConv3d(
108 | self.channels, self.out_channels, 3, stride=stride, padding=padding
109 | )
110 | else:
111 | raise NotImplementedError
112 |
113 | def forward(self, hidden_states):
114 | assert hidden_states.shape[1] == self.channels
115 | if self.use_conv and self.padding == 0:
116 | raise NotImplementedError
117 |
118 | assert hidden_states.shape[1] == self.channels
119 | hidden_states = self.conv(hidden_states)
120 |
121 | return hidden_states
122 |
123 |
124 | class ResnetBlock3D(nn.Module):
125 | def __init__(
126 | self,
127 | *,
128 | in_channels,
129 | out_channels=None,
130 | conv_shortcut=False,
131 | dropout=0.0,
132 | temb_channels=512,
133 | groups=32,
134 | groups_out=None,
135 | pre_norm=True,
136 | eps=1e-6,
137 | non_linearity="swish",
138 | time_embedding_norm="default",
139 | output_scale_factor=1.0,
140 | use_in_shortcut=None,
141 | use_inflated_groupnorm=None,
142 | ):
143 | super().__init__()
144 | self.pre_norm = pre_norm
145 | self.pre_norm = True
146 | self.in_channels = in_channels
147 | out_channels = in_channels if out_channels is None else out_channels
148 | self.out_channels = out_channels
149 | self.use_conv_shortcut = conv_shortcut
150 | self.time_embedding_norm = time_embedding_norm
151 | self.output_scale_factor = output_scale_factor
152 |
153 | if groups_out is None:
154 | groups_out = groups
155 |
156 |         assert use_inflated_groupnorm is not None
157 | if use_inflated_groupnorm:
158 | self.norm1 = InflatedGroupNorm(
159 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True
160 | )
161 | else:
162 | self.norm1 = torch.nn.GroupNorm(
163 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True
164 | )
165 |
166 | self.conv1 = InflatedConv3d(
167 | in_channels, out_channels, kernel_size=3, stride=1, padding=1
168 | )
169 |
170 | if temb_channels is not None:
171 | if self.time_embedding_norm == "default":
172 | time_emb_proj_out_channels = out_channels
173 | elif self.time_embedding_norm == "scale_shift":
174 | time_emb_proj_out_channels = out_channels * 2
175 | else:
176 | raise ValueError(
177 | f"unknown time_embedding_norm : {self.time_embedding_norm} "
178 | )
179 |
180 | self.time_emb_proj = torch.nn.Linear(
181 | temb_channels, time_emb_proj_out_channels
182 | )
183 | else:
184 | self.time_emb_proj = None
185 |
186 | if use_inflated_groupnorm:
187 | self.norm2 = InflatedGroupNorm(
188 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
189 | )
190 | else:
191 | self.norm2 = torch.nn.GroupNorm(
192 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
193 | )
194 | self.dropout = torch.nn.Dropout(dropout)
195 | self.conv2 = InflatedConv3d(
196 | out_channels, out_channels, kernel_size=3, stride=1, padding=1
197 | )
198 |
199 | if non_linearity == "swish":
200 | self.nonlinearity = lambda x: F.silu(x)
201 | elif non_linearity == "mish":
202 | self.nonlinearity = Mish()
203 | elif non_linearity == "silu":
204 | self.nonlinearity = nn.SiLU()
205 |
206 | self.use_in_shortcut = (
207 | self.in_channels != self.out_channels
208 | if use_in_shortcut is None
209 | else use_in_shortcut
210 | )
211 |
212 | self.conv_shortcut = None
213 | if self.use_in_shortcut:
214 | self.conv_shortcut = InflatedConv3d(
215 | in_channels, out_channels, kernel_size=1, stride=1, padding=0
216 | )
217 |
218 | def forward(self, input_tensor, temb):
219 | hidden_states = input_tensor
220 |
221 | hidden_states = self.norm1(hidden_states)
222 | hidden_states = self.nonlinearity(hidden_states)
223 |
224 | hidden_states = self.conv1(hidden_states)
225 |
226 | if temb is not None:
227 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
228 |
229 | if temb is not None and self.time_embedding_norm == "default":
230 | hidden_states = hidden_states + temb
231 |
232 | hidden_states = self.norm2(hidden_states)
233 |
234 | if temb is not None and self.time_embedding_norm == "scale_shift":
235 | scale, shift = torch.chunk(temb, 2, dim=1)
236 | hidden_states = hidden_states * (1 + scale) + shift
237 |
238 | hidden_states = self.nonlinearity(hidden_states)
239 |
240 | hidden_states = self.dropout(hidden_states)
241 | hidden_states = self.conv2(hidden_states)
242 |
243 | if self.conv_shortcut is not None:
244 | input_tensor = self.conv_shortcut(input_tensor)
245 |
246 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
247 |
248 | return output_tensor
249 |
250 | class Mish(torch.nn.Module):
251 | def forward(self, hidden_states):
252 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
253 |
--------------------------------------------------------------------------------
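
A tiny sketch of the inflated ResNet block above on a video latent; the channel counts and frame count are illustrative, and the import path is assumed.

```python
import torch
from src.models.resnet import ResnetBlock3D  # path assumed

block = ResnetBlock3D(
    in_channels=64,
    out_channels=128,
    temb_channels=512,
    use_inflated_groupnorm=True,   # must be set explicitly (see the assert above)
)

x = torch.randn(1, 64, 4, 32, 32)  # (batch, channels, frames, height, width)
temb = torch.randn(1, 512)         # timestep embedding
print(block(x, temb).shape)        # torch.Size([1, 128, 4, 32, 32])
```
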
/src/models/transformer_2d.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformer_2d.py
2 | from dataclasses import dataclass
3 | from typing import Any, Dict, Optional
4 |
5 | import torch
6 | from diffusers.configuration_utils import ConfigMixin, register_to_config
7 | from diffusers.models.embeddings import PixArtAlphaTextProjection
8 | from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
9 | from diffusers.models.modeling_utils import ModelMixin
10 | from diffusers.models.normalization import AdaLayerNormSingle
11 | from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
12 | from torch import nn
13 |
14 | from .attention import BasicTransformerBlock
15 |
16 |
17 | @dataclass
18 | class Transformer2DModelOutput(BaseOutput):
19 | """
20 | The output of [`Transformer2DModel`].
21 |
22 | Args:
23 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
24 | The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
25 | distributions for the unnoised latent pixels.
26 | """
27 |
28 | sample: torch.FloatTensor
29 | ref_feature: torch.FloatTensor
30 |
31 |
32 | class Transformer2DModel(ModelMixin, ConfigMixin):
33 | """
34 | A 2D Transformer model for image-like data.
35 |
36 | Parameters:
37 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
38 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
39 | in_channels (`int`, *optional*):
40 | The number of channels in the input and output (specify if the input is **continuous**).
41 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
42 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
43 | cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
44 | sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
45 | This is fixed during training since it is used to learn a number of position embeddings.
46 | num_vector_embeds (`int`, *optional*):
47 | The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
48 | Includes the class for the masked latent pixel.
49 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
50 | num_embeds_ada_norm ( `int`, *optional*):
51 | The number of diffusion steps used during training. Pass if at least one of the norm_layers is
52 | `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
53 | added to the hidden states.
54 |
55 | During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
56 | attention_bias (`bool`, *optional*):
57 | Configure if the `TransformerBlocks` attention should contain a bias parameter.
58 | """
59 |
60 | _supports_gradient_checkpointing = True
61 |
62 | @register_to_config
63 | def __init__(
64 | self,
65 | num_attention_heads: int = 16,
66 | attention_head_dim: int = 88,
67 | in_channels: Optional[int] = None,
68 | out_channels: Optional[int] = None,
69 | num_layers: int = 1,
70 | dropout: float = 0.0,
71 | norm_num_groups: int = 32,
72 | cross_attention_dim: Optional[int] = None,
73 | attention_bias: bool = False,
74 | sample_size: Optional[int] = None,
75 | num_vector_embeds: Optional[int] = None,
76 | patch_size: Optional[int] = None,
77 | activation_fn: str = "geglu",
78 | num_embeds_ada_norm: Optional[int] = None,
79 | use_linear_projection: bool = False,
80 | only_cross_attention: bool = False,
81 | double_self_attention: bool = False,
82 | upcast_attention: bool = False,
83 | norm_type: str = "layer_norm",
84 | norm_elementwise_affine: bool = True,
85 | norm_eps: float = 1e-5,
86 | attention_type: str = "default",
87 | caption_channels: int = None,
88 | ):
89 | super().__init__()
90 | self.use_linear_projection = use_linear_projection
91 | self.num_attention_heads = num_attention_heads
92 | self.attention_head_dim = attention_head_dim
93 | inner_dim = num_attention_heads * attention_head_dim
94 |
95 | conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
96 | linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
97 |
98 | # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
99 | # Define whether input is continuous or discrete depending on configuration
100 | self.is_input_continuous = (in_channels is not None) and (patch_size is None)
101 | self.is_input_vectorized = num_vector_embeds is not None
102 | self.is_input_patches = in_channels is not None and patch_size is not None
103 |
104 | if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
105 | deprecation_message = (
106 | f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
107 |                 " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
108 |                 " Please make sure to update the config accordingly, as leaving `norm_type` unset might lead to incorrect"
109 | " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
110 | " would be very nice if you could open a Pull request for the `transformer/config.json` file"
111 | )
112 | deprecate(
113 | "norm_type!=num_embeds_ada_norm",
114 | "1.0.0",
115 | deprecation_message,
116 | standard_warn=False,
117 | )
118 | norm_type = "ada_norm"
119 |
120 | if self.is_input_continuous and self.is_input_vectorized:
121 | raise ValueError(
122 | f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
123 | " sure that either `in_channels` or `num_vector_embeds` is None."
124 | )
125 | elif self.is_input_vectorized and self.is_input_patches:
126 | raise ValueError(
127 | f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
128 | " sure that either `num_vector_embeds` or `num_patches` is None."
129 | )
130 | elif (
131 | not self.is_input_continuous
132 | and not self.is_input_vectorized
133 | and not self.is_input_patches
134 | ):
135 | raise ValueError(
136 | f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
137 | f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
138 | )
139 |
140 | # 2. Define input layers
141 | self.in_channels = in_channels
142 |
143 | self.norm = torch.nn.GroupNorm(
144 | num_groups=norm_num_groups,
145 | num_channels=in_channels,
146 | eps=1e-6,
147 | affine=True,
148 | )
149 | if use_linear_projection:
150 | self.proj_in = linear_cls(in_channels, inner_dim)
151 | else:
152 | self.proj_in = conv_cls(
153 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0
154 | )
155 |
156 | # 3. Define transformers blocks
157 | self.transformer_blocks = nn.ModuleList(
158 | [
159 | BasicTransformerBlock(
160 | inner_dim,
161 | num_attention_heads,
162 | attention_head_dim,
163 | dropout=dropout,
164 | cross_attention_dim=cross_attention_dim,
165 | activation_fn=activation_fn,
166 | num_embeds_ada_norm=num_embeds_ada_norm,
167 | attention_bias=attention_bias,
168 | only_cross_attention=only_cross_attention,
169 | double_self_attention=double_self_attention,
170 | upcast_attention=upcast_attention,
171 | norm_type=norm_type,
172 | norm_elementwise_affine=norm_elementwise_affine,
173 | norm_eps=norm_eps,
174 | attention_type=attention_type,
175 | )
176 | for d in range(num_layers)
177 | ]
178 | )
179 |
180 | # 4. Define output layers
181 | self.out_channels = in_channels if out_channels is None else out_channels
182 | # TODO: should use out_channels for continuous projections
183 | if use_linear_projection:
184 | self.proj_out = linear_cls(inner_dim, in_channels)
185 | else:
186 | self.proj_out = conv_cls(
187 | inner_dim, in_channels, kernel_size=1, stride=1, padding=0
188 | )
189 |
190 | # 5. PixArt-Alpha blocks.
191 | self.adaln_single = None
192 | self.use_additional_conditions = False
193 | if norm_type == "ada_norm_single":
194 | self.use_additional_conditions = self.config.sample_size == 128
195 | # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
196 | # additional conditions until we find better name
197 | self.adaln_single = AdaLayerNormSingle(
198 | inner_dim, use_additional_conditions=self.use_additional_conditions
199 | )
200 |
201 | self.caption_projection = None
202 | if caption_channels is not None:
203 | self.caption_projection = PixArtAlphaTextProjection(
204 | in_features=caption_channels, hidden_size=inner_dim
205 | )
206 |
207 | self.gradient_checkpointing = False
208 |
209 | def _set_gradient_checkpointing(self, module, value=False):
210 | if hasattr(module, "gradient_checkpointing"):
211 | module.gradient_checkpointing = value
212 |
213 | def forward(
214 | self,
215 | hidden_states: torch.Tensor,
216 | encoder_hidden_states: Optional[torch.Tensor] = None,
217 | timestep: Optional[torch.LongTensor] = None,
218 | added_cond_kwargs: Dict[str, torch.Tensor] = None,
219 | class_labels: Optional[torch.LongTensor] = None,
220 | cross_attention_kwargs: Dict[str, Any] = None,
221 | attention_mask: Optional[torch.Tensor] = None,
222 | encoder_attention_mask: Optional[torch.Tensor] = None,
223 | return_dict: bool = True,
224 | ):
225 | """
226 | The [`Transformer2DModel`] forward method.
227 |
228 | Args:
229 | hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
230 | Input `hidden_states`.
231 | encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
232 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
233 | self-attention.
234 | timestep ( `torch.LongTensor`, *optional*):
235 | Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
236 | class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
237 | Used for class-label conditioning. Optional class labels to be applied as an embedding in
238 | `AdaLayerNormZero`.
239 | cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
240 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
241 | `self.processor` in
242 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
243 | attention_mask ( `torch.Tensor`, *optional*):
244 | An attention mask of shape `(batch, key_tokens)` applied to `encoder_hidden_states`. A value of `1`
245 | keeps the token and `0` discards it. The mask is converted into a bias, which adds large
246 | negative values to the attention scores corresponding to "discard" tokens.
247 | encoder_attention_mask ( `torch.Tensor`, *optional*):
248 | Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
249 |
250 | * Mask `(batch, sequence_length)` True = keep, False = discard.
251 | * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
252 |
253 | If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
254 | above. This bias will be added to the cross-attention scores.
255 | return_dict (`bool`, *optional*, defaults to `True`):
256 | Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
257 | tuple.
258 |
259 | Returns:
260 | If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
261 | `tuple` where the first element is the sample tensor.
262 | """
263 | # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
264 | # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
265 | # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
266 | # expects mask of shape:
267 | # [batch, key_tokens]
268 | # adds singleton query_tokens dimension:
269 | # [batch, 1, key_tokens]
270 | # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
271 | # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
272 | # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
273 | if attention_mask is not None and attention_mask.ndim == 2:
274 | # assume that mask is expressed as:
275 | # (1 = keep, 0 = discard)
276 | # convert mask into a bias that can be added to attention scores:
277 | # (keep = +0, discard = -10000.0)
278 | attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
279 | attention_mask = attention_mask.unsqueeze(1)
280 |
281 | # convert encoder_attention_mask to a bias the same way we do for attention_mask
282 | if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
283 | encoder_attention_mask = (
284 | 1 - encoder_attention_mask.to(hidden_states.dtype)
285 | ) * -10000.0
286 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
287 |
288 | # Retrieve lora scale.
289 | lora_scale = (
290 | cross_attention_kwargs.get("scale", 1.0)
291 | if cross_attention_kwargs is not None
292 | else 1.0
293 | )
294 |
295 | # 1. Input
296 | batch, _, height, width = hidden_states.shape
297 | residual = hidden_states
298 |
299 | hidden_states = self.norm(hidden_states)
300 | if not self.use_linear_projection:
301 | hidden_states = (
302 | self.proj_in(hidden_states, scale=lora_scale)
303 | if not USE_PEFT_BACKEND
304 | else self.proj_in(hidden_states)
305 | )
306 | inner_dim = hidden_states.shape[1]
307 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
308 | batch, height * width, inner_dim
309 | )
310 | else:
311 | inner_dim = hidden_states.shape[1]
312 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
313 | batch, height * width, inner_dim
314 | )
315 | hidden_states = (
316 | self.proj_in(hidden_states, scale=lora_scale)
317 | if not USE_PEFT_BACKEND
318 | else self.proj_in(hidden_states)
319 | )
320 |
321 | # 2. Blocks
322 | if self.caption_projection is not None:
323 | batch_size = hidden_states.shape[0]
324 | encoder_hidden_states = self.caption_projection(encoder_hidden_states)
325 | encoder_hidden_states = encoder_hidden_states.view(
326 | batch_size, -1, hidden_states.shape[-1]
327 | )
328 |
329 | ref_feature = hidden_states.reshape(batch, height, width, inner_dim)
330 | for block in self.transformer_blocks:
331 | if self.training and self.gradient_checkpointing:
332 |
333 | def create_custom_forward(module, return_dict=None):
334 | def custom_forward(*inputs):
335 | if return_dict is not None:
336 | return module(*inputs, return_dict=return_dict)
337 | else:
338 | return module(*inputs)
339 |
340 | return custom_forward
341 |
342 | ckpt_kwargs: Dict[str, Any] = (
343 | {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
344 | )
345 | hidden_states = torch.utils.checkpoint.checkpoint(
346 | create_custom_forward(block),
347 | hidden_states,
348 | attention_mask,
349 | encoder_hidden_states,
350 | encoder_attention_mask,
351 | timestep,
352 | cross_attention_kwargs,
353 | class_labels,
354 | **ckpt_kwargs,
355 | )
356 | else:
357 | hidden_states = block(
358 | hidden_states,
359 | attention_mask=attention_mask,
360 | encoder_hidden_states=encoder_hidden_states,
361 | encoder_attention_mask=encoder_attention_mask,
362 | timestep=timestep,
363 | cross_attention_kwargs=cross_attention_kwargs,
364 | class_labels=class_labels,
365 | )
366 |
367 | # 3. Output
368 | if self.is_input_continuous:
369 | if not self.use_linear_projection:
370 | hidden_states = (
371 | hidden_states.reshape(batch, height, width, inner_dim)
372 | .permute(0, 3, 1, 2)
373 | .contiguous()
374 | )
375 | hidden_states = (
376 | self.proj_out(hidden_states, scale=lora_scale)
377 | if not USE_PEFT_BACKEND
378 | else self.proj_out(hidden_states)
379 | )
380 | else:
381 | hidden_states = (
382 | self.proj_out(hidden_states, scale=lora_scale)
383 | if not USE_PEFT_BACKEND
384 | else self.proj_out(hidden_states)
385 | )
386 | hidden_states = (
387 | hidden_states.reshape(batch, height, width, inner_dim)
388 | .permute(0, 3, 1, 2)
389 | .contiguous()
390 | )
391 |
392 | output = hidden_states + residual
393 | if not return_dict:
394 | return (output, ref_feature)
395 |
396 | return Transformer2DModelOutput(sample=output, ref_feature=ref_feature)
397 |
--------------------------------------------------------------------------------
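Unlike the stock diffusers class, this Transformer2DModel also returns `ref_feature`, the pre-block hidden states reshaped to `(batch, height, width, inner_dim)`, presumably for the reference-attention machinery used elsewhere in this package. A minimal sketch of a forward pass follows; the import path is assumed, a recent diffusers install is assumed, and the dummy sizes are chosen so that `inner_dim == in_channels` (as in the Stable Diffusion UNet blocks this module is used with).

# Minimal sketch (assumed import path; dummy sizes chosen so inner_dim == in_channels).
import torch
from src.models.transformer_2d import Transformer2DModel

model = Transformer2DModel(
    num_attention_heads=8,
    attention_head_dim=40,      # inner_dim = 8 * 40 = 320
    in_channels=320,
    num_layers=1,
    cross_attention_dim=768,
)
hidden = torch.randn(1, 320, 32, 32)   # (batch, channels, height, width)
context = torch.randn(1, 77, 768)      # (batch, sequence, cross_attention_dim)
out = model(hidden, encoder_hidden_states=context, return_dict=True)
print(out.sample.shape)       # torch.Size([1, 320, 32, 32])
print(out.ref_feature.shape)  # torch.Size([1, 32, 32, 320])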
/src/models/transformer_3d.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional, Dict
3 |
4 | import torch
5 | from diffusers.configuration_utils import ConfigMixin, register_to_config
6 | from diffusers.models import ModelMixin
7 | from diffusers.utils import BaseOutput
8 | from diffusers.utils.import_utils import is_xformers_available
9 | from einops import rearrange, repeat
10 | from torch import nn
11 |
12 | from .attention import TemporalBasicTransformerBlock, ResidualTemporalBasicTransformerBlock
13 |
14 |
15 | @dataclass
16 | class Transformer3DModelOutput(BaseOutput):
17 | sample: torch.FloatTensor
18 |
19 |
20 | if is_xformers_available():
21 | import xformers
22 | import xformers.ops
23 | else:
24 | xformers = None
25 |
26 |
27 | class Transformer3DModel(ModelMixin, ConfigMixin):
28 | _supports_gradient_checkpointing = True
29 |
30 | @register_to_config
31 | def __init__(
32 | self,
33 | num_attention_heads: int = 16,
34 | attention_head_dim: int = 88,
35 | in_channels: Optional[int] = None,
36 | num_layers: int = 1,
37 | dropout: float = 0.0,
38 | norm_num_groups: int = 32,
39 | cross_attention_dim: Optional[int] = None,
40 | attention_bias: bool = False,
41 | activation_fn: str = "geglu",
42 | num_embeds_ada_norm: Optional[int] = None,
43 | use_linear_projection: bool = False,
44 | only_cross_attention: bool = False,
45 | upcast_attention: bool = False,
46 | unet_use_cross_frame_attention=None,
47 | unet_use_temporal_attention=None,
48 | ):
49 | super().__init__()
50 | self.use_linear_projection = use_linear_projection
51 | self.num_attention_heads = num_attention_heads
52 | self.attention_head_dim = attention_head_dim
53 | inner_dim = num_attention_heads * attention_head_dim
54 |
55 | # Define input layers
56 | self.in_channels = in_channels
57 |
58 | self.norm = torch.nn.GroupNorm(
59 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
60 | )
61 | if use_linear_projection:
62 | self.proj_in = nn.Linear(in_channels, inner_dim)
63 | else:
64 | self.proj_in = nn.Conv2d(
65 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0
66 | )
67 |
68 | # Define transformers blocks
69 | self.transformer_blocks = nn.ModuleList(
70 | [
71 | TemporalBasicTransformerBlock(
72 | inner_dim,
73 | num_attention_heads,
74 | attention_head_dim,
75 | dropout=dropout,
76 | cross_attention_dim=cross_attention_dim,
77 | activation_fn=activation_fn,
78 | num_embeds_ada_norm=num_embeds_ada_norm,
79 | attention_bias=attention_bias,
80 | only_cross_attention=only_cross_attention,
81 | upcast_attention=upcast_attention,
82 | unet_use_cross_frame_attention=unet_use_cross_frame_attention,
83 | unet_use_temporal_attention=unet_use_temporal_attention,
84 | )
85 | for d in range(num_layers)
86 | ]
87 | )
88 |
89 | # Define output layers
90 | if use_linear_projection:
91 | self.proj_out = nn.Linear(inner_dim, in_channels)
92 | else:
93 | self.proj_out = nn.Conv2d(
94 | inner_dim, in_channels, kernel_size=1, stride=1, padding=0
95 | )
96 |
97 | self.gradient_checkpointing = False
98 |
99 | def _set_gradient_checkpointing(self, module, value=False):
100 | if hasattr(module, "gradient_checkpointing"):
101 | module.gradient_checkpointing = value
102 |
103 | def forward(
104 | self,
105 | hidden_states,
106 | encoder_hidden_states=None,
107 | timestep=None,
108 | return_dict: bool = True,
109 | ):
110 | # Input
111 | assert (
112 | hidden_states.dim() == 5
113 | ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
114 | video_length = hidden_states.shape[2]
115 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
116 | if encoder_hidden_states.shape[0] != hidden_states.shape[0]:
117 | encoder_hidden_states = repeat(
118 | encoder_hidden_states, "b n c -> (b f) n c", f=video_length
119 | )
120 |
121 | batch, channel, height, weight = hidden_states.shape
122 | residual = hidden_states
123 |
124 | hidden_states = self.norm(hidden_states)
125 | if not self.use_linear_projection:
126 | hidden_states = self.proj_in(hidden_states)
127 | inner_dim = hidden_states.shape[1]
128 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
129 | batch, height * weight, inner_dim
130 | )
131 | else:
132 | inner_dim = hidden_states.shape[1]
133 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
134 | batch, height * weight, inner_dim
135 | )
136 | hidden_states = self.proj_in(hidden_states)
137 |
138 | # Blocks
139 | for i, block in enumerate(self.transformer_blocks):
140 | hidden_states = block(
141 | hidden_states,
142 | encoder_hidden_states=encoder_hidden_states,
143 | timestep=timestep,
144 | video_length=video_length,
145 | )
146 |
147 | # Output
148 | if not self.use_linear_projection:
149 | hidden_states = (
150 | hidden_states.reshape(batch, height, weight, inner_dim)
151 | .permute(0, 3, 1, 2)
152 | .contiguous()
153 | )
154 | hidden_states = self.proj_out(hidden_states)
155 | else:
156 | hidden_states = self.proj_out(hidden_states)
157 | hidden_states = (
158 | hidden_states.reshape(batch, height, weight, inner_dim)
159 | .permute(0, 3, 1, 2)
160 | .contiguous()
161 | )
162 |
163 | output = hidden_states + residual
164 |
165 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
166 | if not return_dict:
167 | return (output,)
168 |
169 | return Transformer3DModelOutput(sample=output)
170 |
--------------------------------------------------------------------------------
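The 3-D variant expects a 5-D tensor, folds the frame axis into the batch, repeats the conditioning per frame, and runs the temporal transformer blocks before reshaping back. A sketch under the same assumptions as above (assumed import path, sizes chosen so `inner_dim == in_channels`; the behaviour of the local `TemporalBasicTransformerBlock` is taken on faith from its call signature here):

# Minimal sketch (assumed import path; sizes chosen so inner_dim == in_channels).
import torch
from src.models.transformer_3d import Transformer3DModel

model = Transformer3DModel(
    num_attention_heads=8,
    attention_head_dim=40,      # inner_dim = 320
    in_channels=320,
    num_layers=1,
    cross_attention_dim=768,
)
video = torch.randn(1, 320, 8, 32, 32)   # (batch, channels, frames, height, width)
context = torch.randn(1, 77, 768)        # repeated internally to (batch * frames, 77, 768)
out = model(video, encoder_hidden_states=context)
print(out.sample.shape)  # torch.Size([1, 320, 8, 32, 32])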
/src/pipelines/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/pipelines/__init__.py
--------------------------------------------------------------------------------
/src/pipelines/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/pipelines/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/src/pipelines/__pycache__/context.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/pipelines/__pycache__/context.cpython-310.pyc
--------------------------------------------------------------------------------
/src/pipelines/__pycache__/pipeline_pose2vid_long.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/pipelines/__pycache__/pipeline_pose2vid_long.cpython-310.pyc
--------------------------------------------------------------------------------
/src/pipelines/__pycache__/utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/pipelines/__pycache__/utils.cpython-310.pyc
--------------------------------------------------------------------------------
/src/pipelines/context.py:
--------------------------------------------------------------------------------
1 | # TODO: Adapted from cli
2 | from typing import Callable, List, Optional
3 |
4 | import numpy as np
5 |
6 |
7 | def ordered_halving(val):
8 | bin_str = f"{val:064b}"
9 | bin_flip = bin_str[::-1]
10 | as_int = int(bin_flip, 2)
11 |
12 | return as_int / (1 << 64)
13 |
14 |
15 | def uniform(
16 | step: int = ...,
17 | num_steps: Optional[int] = None,
18 | num_frames: int = ...,
19 | context_size: Optional[int] = None,
20 | context_stride: int = 3,
21 | context_overlap: int = 4,
22 | closed_loop: bool = True,
23 | ):
24 | if num_frames <= context_size:
25 | yield list(range(num_frames))
26 | return
27 |
28 | context_stride = min(
29 | context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1
30 | )
31 |
32 | for context_step in 1 << np.arange(context_stride):
33 | pad = int(round(num_frames * ordered_halving(step)))
34 | for j in range(
35 | int(ordered_halving(step) * context_step) + pad,
36 | num_frames + pad + (0 if closed_loop else -context_overlap),
37 | (context_size * context_step - context_overlap),
38 | ):
39 | yield [
40 | e % num_frames
41 | for e in range(j, j + context_size * context_step, context_step)
42 | ]
43 |
44 |
45 | def get_context_scheduler(name: str) -> Callable:
46 | if name == "uniform":
47 | return uniform
48 | else:
49 | raise ValueError(f"Unknown context scheduler {name}")
50 |
51 |
52 | def get_total_steps(
53 | scheduler,
54 | timesteps: List[int],
55 | num_steps: Optional[int] = None,
56 | num_frames: int = ...,
57 | context_size: Optional[int] = None,
58 | context_stride: int = 3,
59 | context_overlap: int = 4,
60 | closed_loop: bool = True,
61 | ):
62 | return sum(
63 | len(
64 | list(
65 | scheduler(
66 | i,
67 | num_steps,
68 | num_frames,
69 | context_size,
70 | context_stride,
71 | context_overlap,
72 | )
73 | )
74 | )
75 | for i in range(len(timesteps))
76 | )
77 |
--------------------------------------------------------------------------------
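`uniform` is a generator that yields overlapping windows of frame indices (wrapping around when `closed_loop=True`); each window is denoised separately and the results are fused, which is what makes long videos tractable. A quick illustration, needing only numpy plus this module (the import path is assumed):

# Quick illustration of the uniform context scheduler (assumed import path).
from src.pipelines.context import get_context_scheduler

scheduler = get_context_scheduler("uniform")
windows = list(
    scheduler(
        step=0,            # current denoising step; shifts the windows across steps
        num_steps=25,
        num_frames=24,
        context_size=16,
        context_overlap=4,
    )
)
for w in windows:
    print(w)  # [0, 1, ..., 15], a wrapped window starting at 12, then a stride-2 variant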
/src/pipelines/pipeline_pose2vid.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | from dataclasses import dataclass
3 | from typing import Callable, List, Optional, Union
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn.functional as F
8 | import torchvision.transforms as transforms
9 | from diffusers import DiffusionPipeline
10 | from diffusers.image_processor import VaeImageProcessor
11 | from diffusers.schedulers import (DDIMScheduler, DPMSolverMultistepScheduler,
12 | EulerAncestralDiscreteScheduler,
13 | EulerDiscreteScheduler, LMSDiscreteScheduler,
14 | PNDMScheduler)
15 | from diffusers.utils import BaseOutput, is_accelerate_available
16 | from diffusers.utils.torch_utils import randn_tensor
17 | from einops import rearrange
18 | from tqdm import tqdm
19 | from transformers import CLIPImageProcessor
20 |
21 | from ..models.mutual_self_attention import ReferenceAttentionControl
22 |
23 |
24 | @dataclass
25 | class Pose2VideoPipelineOutput(BaseOutput):
26 | videos: Union[torch.Tensor, np.ndarray]
27 |
28 |
29 | class Pose2VideoPipeline(DiffusionPipeline):
30 | _optional_components = []
31 |
32 | def __init__(
33 | self,
34 | vae,
35 | image_encoder,
36 | reference_unet,
37 | denoising_unet,
38 | pose_guider,
39 | scheduler: Union[
40 | DDIMScheduler,
41 | PNDMScheduler,
42 | LMSDiscreteScheduler,
43 | EulerDiscreteScheduler,
44 | EulerAncestralDiscreteScheduler,
45 | DPMSolverMultistepScheduler,
46 | ],
47 | image_proj_model=None,
48 | tokenizer=None,
49 | text_encoder=None,
50 | ):
51 | super().__init__()
52 |
53 | self.register_modules(
54 | vae=vae,
55 | image_encoder=image_encoder,
56 | reference_unet=reference_unet,
57 | denoising_unet=denoising_unet,
58 | pose_guider=pose_guider,
59 | scheduler=scheduler,
60 | image_proj_model=image_proj_model,
61 | tokenizer=tokenizer,
62 | text_encoder=text_encoder,
63 | )
64 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
65 | self.clip_image_processor = CLIPImageProcessor()
66 | self.ref_image_processor = VaeImageProcessor(
67 | vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
68 | )
69 | self.cond_image_processor = VaeImageProcessor(
70 | vae_scale_factor=self.vae_scale_factor,
71 | do_convert_rgb=True,
72 | do_normalize=True,
73 | )
74 |
75 | def enable_vae_slicing(self):
76 | self.vae.enable_slicing()
77 |
78 | def disable_vae_slicing(self):
79 | self.vae.disable_slicing()
80 |
81 | def enable_sequential_cpu_offload(self, gpu_id=0):
82 | if is_accelerate_available():
83 | from accelerate import cpu_offload
84 | else:
85 | raise ImportError("Please install accelerate via `pip install accelerate`")
86 |
87 | device = torch.device(f"cuda:{gpu_id}")
88 |
89 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
90 | if cpu_offloaded_model is not None:
91 | cpu_offload(cpu_offloaded_model, device)
92 |
93 | @property
94 | def _execution_device(self):
95 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
96 | return self.device
97 | for module in self.unet.modules():
98 | if (
99 | hasattr(module, "_hf_hook")
100 | and hasattr(module._hf_hook, "execution_device")
101 | and module._hf_hook.execution_device is not None
102 | ):
103 | return torch.device(module._hf_hook.execution_device)
104 | return self.device
105 |
106 | def decode_latents(self, latents):
107 | video_length = latents.shape[2]
108 | latents = 1 / 0.18215 * latents
109 | latents = rearrange(latents, "b c f h w -> (b f) c h w")
110 | # video = self.vae.decode(latents).sample
111 | video = []
112 | for frame_idx in tqdm(range(latents.shape[0])):
113 | video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
114 | video = torch.cat(video)
115 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
116 | video = (video / 2 + 0.5).clamp(0, 1)
117 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
118 | video = video.cpu().float().numpy()
119 | return video
120 |
121 | def prepare_extra_step_kwargs(self, generator, eta):
122 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
123 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
124 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
125 | # and should be between [0, 1]
126 |
127 | accepts_eta = "eta" in set(
128 | inspect.signature(self.scheduler.step).parameters.keys()
129 | )
130 | extra_step_kwargs = {}
131 | if accepts_eta:
132 | extra_step_kwargs["eta"] = eta
133 |
134 | # check if the scheduler accepts generator
135 | accepts_generator = "generator" in set(
136 | inspect.signature(self.scheduler.step).parameters.keys()
137 | )
138 | if accepts_generator:
139 | extra_step_kwargs["generator"] = generator
140 | return extra_step_kwargs
141 |
142 | def prepare_latents(
143 | self,
144 | batch_size,
145 | num_channels_latents,
146 | width,
147 | height,
148 | video_length,
149 | dtype,
150 | device,
151 | generator,
152 | latents=None,
153 | ):
154 | shape = (
155 | batch_size,
156 | num_channels_latents,
157 | video_length,
158 | height // self.vae_scale_factor,
159 | width // self.vae_scale_factor,
160 | )
161 | if isinstance(generator, list) and len(generator) != batch_size:
162 | raise ValueError(
163 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
164 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
165 | )
166 |
167 | if latents is None:
168 | latents = randn_tensor(
169 | shape, generator=generator, device=device, dtype=dtype
170 | )
171 | else:
172 | latents = latents.to(device)
173 |
174 | # scale the initial noise by the standard deviation required by the scheduler
175 | latents = latents * self.scheduler.init_noise_sigma
176 | return latents
177 |
178 | def _encode_prompt(
179 | self,
180 | prompt,
181 | device,
182 | num_videos_per_prompt,
183 | do_classifier_free_guidance,
184 | negative_prompt,
185 | ):
186 | batch_size = len(prompt) if isinstance(prompt, list) else 1
187 |
188 | text_inputs = self.tokenizer(
189 | prompt,
190 | padding="max_length",
191 | max_length=self.tokenizer.model_max_length,
192 | truncation=True,
193 | return_tensors="pt",
194 | )
195 | text_input_ids = text_inputs.input_ids
196 | untruncated_ids = self.tokenizer(
197 | prompt, padding="longest", return_tensors="pt"
198 | ).input_ids
199 |
200 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
201 | text_input_ids, untruncated_ids
202 | ):
203 | removed_text = self.tokenizer.batch_decode(
204 | untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
205 | )
206 |
207 | if (
208 | hasattr(self.text_encoder.config, "use_attention_mask")
209 | and self.text_encoder.config.use_attention_mask
210 | ):
211 | attention_mask = text_inputs.attention_mask.to(device)
212 | else:
213 | attention_mask = None
214 |
215 | text_embeddings = self.text_encoder(
216 | text_input_ids.to(device),
217 | attention_mask=attention_mask,
218 | )
219 | text_embeddings = text_embeddings[0]
220 |
221 | # duplicate text embeddings for each generation per prompt, using mps friendly method
222 | bs_embed, seq_len, _ = text_embeddings.shape
223 | text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
224 | text_embeddings = text_embeddings.view(
225 | bs_embed * num_videos_per_prompt, seq_len, -1
226 | )
227 |
228 | # get unconditional embeddings for classifier free guidance
229 | if do_classifier_free_guidance:
230 | uncond_tokens: List[str]
231 | if negative_prompt is None:
232 | uncond_tokens = [""] * batch_size
233 | elif type(prompt) is not type(negative_prompt):
234 | raise TypeError(
235 | f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
236 | f" {type(prompt)}."
237 | )
238 | elif isinstance(negative_prompt, str):
239 | uncond_tokens = [negative_prompt]
240 | elif batch_size != len(negative_prompt):
241 | raise ValueError(
242 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
243 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
244 | " the batch size of `prompt`."
245 | )
246 | else:
247 | uncond_tokens = negative_prompt
248 |
249 | max_length = text_input_ids.shape[-1]
250 | uncond_input = self.tokenizer(
251 | uncond_tokens,
252 | padding="max_length",
253 | max_length=max_length,
254 | truncation=True,
255 | return_tensors="pt",
256 | )
257 |
258 | if (
259 | hasattr(self.text_encoder.config, "use_attention_mask")
260 | and self.text_encoder.config.use_attention_mask
261 | ):
262 | attention_mask = uncond_input.attention_mask.to(device)
263 | else:
264 | attention_mask = None
265 |
266 | uncond_embeddings = self.text_encoder(
267 | uncond_input.input_ids.to(device),
268 | attention_mask=attention_mask,
269 | )
270 | uncond_embeddings = uncond_embeddings[0]
271 |
272 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
273 | seq_len = uncond_embeddings.shape[1]
274 | uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
275 | uncond_embeddings = uncond_embeddings.view(
276 | batch_size * num_videos_per_prompt, seq_len, -1
277 | )
278 |
279 | # For classifier free guidance, we need to do two forward passes.
280 | # Here we concatenate the unconditional and text embeddings into a single batch
281 | # to avoid doing two forward passes
282 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
283 |
284 | return text_embeddings
285 |
286 | @torch.no_grad()
287 | def __call__(
288 | self,
289 | ref_image,
290 | pose_images,
291 | ref_pose_image,
292 | width,
293 | height,
294 | video_length,
295 | num_inference_steps,
296 | guidance_scale,
297 | num_images_per_prompt=1,
298 | eta: float = 0.0,
299 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
300 | output_type: Optional[str] = "tensor",
301 | return_dict: bool = True,
302 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
303 | callback_steps: Optional[int] = 1,
304 | **kwargs,
305 | ):
306 | # Default height and width to unet
307 | height = height or self.unet.config.sample_size * self.vae_scale_factor
308 | width = width or self.unet.config.sample_size * self.vae_scale_factor
309 |
310 | device = self._execution_device
311 |
312 | do_classifier_free_guidance = guidance_scale > 1.0
313 |
314 | # Prepare timesteps
315 | self.scheduler.set_timesteps(num_inference_steps, device=device)
316 | timesteps = self.scheduler.timesteps
317 |
318 | batch_size = 1
319 |
320 | # Prepare clip image embeds
321 | clip_image = self.clip_image_processor.preprocess(
322 | ref_image, return_tensors="pt"
323 | ).pixel_values
324 | clip_image_embeds = self.image_encoder(
325 | clip_image.to(device, dtype=self.image_encoder.dtype)
326 | ).image_embeds
327 | encoder_hidden_states = clip_image_embeds.unsqueeze(1)
328 |
329 | uncond_encoder_hidden_states = torch.zeros_like(encoder_hidden_states)
330 |
331 | if do_classifier_free_guidance:
332 | encoder_hidden_states = torch.cat(
333 | [uncond_encoder_hidden_states, encoder_hidden_states], dim=0
334 | )
335 | reference_control_writer = ReferenceAttentionControl(
336 | self.reference_unet,
337 | do_classifier_free_guidance=do_classifier_free_guidance,
338 | mode="write",
339 | batch_size=batch_size,
340 | fusion_blocks="full",
341 | )
342 | reference_control_reader = ReferenceAttentionControl(
343 | self.denoising_unet,
344 | do_classifier_free_guidance=do_classifier_free_guidance,
345 | mode="read",
346 | batch_size=batch_size,
347 | fusion_blocks="full",
348 | )
349 |
350 | num_channels_latents = self.denoising_unet.in_channels
351 | latents = self.prepare_latents(
352 | batch_size * num_images_per_prompt,
353 | num_channels_latents,
354 | width,
355 | height,
356 | video_length,
357 | clip_image_embeds.dtype,
358 | device,
359 | generator,
360 | )
361 |
362 | # Prepare extra step kwargs.
363 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
364 |
365 | # Prepare ref image latents
366 | ref_image_tensor = self.ref_image_processor.preprocess(
367 | ref_image, height=height, width=width
368 | ) # (bs, c, width, height)
369 | ref_image_tensor = ref_image_tensor.to(
370 | dtype=self.vae.dtype, device=self.vae.device
371 | )
372 | ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
373 | ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w)
374 |
375 | # Prepare a list of pose condition images
376 | pose_cond_tensor_list = []
377 | for pose_image in pose_images:
378 | pose_cond_tensor = self.cond_image_processor.preprocess(
379 | pose_image, height=height, width=width
380 | ).transpose(0, 1) # (c, 1, h, w)
381 | pose_cond_tensor_list.append(pose_cond_tensor)
382 | pose_cond_tensor = torch.cat(pose_cond_tensor_list, dim=1) # (c, t, h, w)
383 |
384 | pose_cond_tensor = pose_cond_tensor.unsqueeze(0) # (1, c, t, h, w)
385 | pose_cond_tensor = pose_cond_tensor.to(
386 | device=device, dtype=self.pose_guider.dtype
387 | )
388 |
389 | ref_pose_tensor = self.cond_image_processor.preprocess(
390 | ref_pose_image, height=height, width=width
391 | )
392 | ref_pose_tensor = ref_pose_tensor.to(
393 | device=device, dtype=self.pose_guider.dtype
394 | )
395 |
396 | pose_fea = self.pose_guider(pose_cond_tensor, ref_pose_tensor)
397 | if do_classifier_free_guidance:
398 | for idxx in range(len(pose_fea)):
399 | pose_fea[idxx] = torch.cat([pose_fea[idxx]] * 2)
400 |
401 | # denoising loop
402 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
403 | with self.progress_bar(total=num_inference_steps) as progress_bar:
404 | for i, t in enumerate(timesteps):
405 | # 1. Forward reference image
406 | if i == 0:
407 | self.reference_unet(
408 | ref_image_latents.repeat(
409 | (2 if do_classifier_free_guidance else 1), 1, 1, 1
410 | ),
411 | torch.zeros_like(t),
412 | # t,
413 | encoder_hidden_states=encoder_hidden_states,
414 | return_dict=False,
415 | )
416 | reference_control_reader.update(reference_control_writer)
417 |
418 | # 3.1 expand the latents if we are doing classifier free guidance
419 | latent_model_input = (
420 | torch.cat([latents] * 2) if do_classifier_free_guidance else latents
421 | )
422 | latent_model_input = self.scheduler.scale_model_input(
423 | latent_model_input, t
424 | )
425 |
426 | noise_pred = self.denoising_unet(
427 | latent_model_input,
428 | t,
429 | encoder_hidden_states=encoder_hidden_states,
430 | pose_cond_fea=pose_fea,
431 | return_dict=False,
432 | )[0]
433 |
434 | # perform guidance
435 | if do_classifier_free_guidance:
436 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
437 | noise_pred = noise_pred_uncond + guidance_scale * (
438 | noise_pred_text - noise_pred_uncond
439 | )
440 |
441 | # compute the previous noisy sample x_t -> x_t-1
442 | latents = self.scheduler.step(
443 | noise_pred, t, latents, **extra_step_kwargs, return_dict=False
444 | )[0]
445 |
446 | # call the callback, if provided
447 | if i == len(timesteps) - 1 or (
448 | (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
449 | ):
450 | progress_bar.update()
451 | if callback is not None and i % callback_steps == 0:
452 | step_idx = i // getattr(self.scheduler, "order", 1)
453 | callback(step_idx, t, latents)
454 |
455 | reference_control_reader.clear()
456 | reference_control_writer.clear()
457 |
458 | # Post-processing
459 | images = self.decode_latents(latents) # (b, c, f, h, w)
460 |
461 | # Convert to tensor
462 | if output_type == "tensor":
463 | images = torch.from_numpy(images)
464 |
465 | if not return_dict:
466 | return images
467 |
468 | return Pose2VideoPipelineOutput(videos=images)
469 |
--------------------------------------------------------------------------------
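At call time the pipeline takes one reference portrait, a list of per-frame pose images, and the pose drawn from the reference, and returns the decoded video as a float tensor in [0, 1]. A hedged call sketch follows; `pipe`, `ref_image`, `pose_images`, and `ref_pose_image` are assumed to exist already (a constructed `Pose2VideoPipeline` and PIL images loaded elsewhere).

# Hypothetical call sketch: `pipe`, `ref_image`, `pose_images`, and `ref_pose_image`
# are assumed to exist (pipe is a constructed Pose2VideoPipeline; images are PIL.Image).
import torch

output = pipe(
    ref_image=ref_image,              # single reference portrait
    pose_images=pose_images,          # list of per-frame pose/landmark images
    ref_pose_image=ref_pose_image,    # pose drawn from the reference portrait
    width=512,
    height=512,
    video_length=len(pose_images),
    num_inference_steps=25,
    guidance_scale=3.5,
    generator=torch.Generator(device="cpu").manual_seed(42),
)
video = output.videos   # (batch, channels, frames, height, width), values in [0, 1]

Setting `guidance_scale` above 1.0 enables classifier-free guidance, which doubles the batch through the denoising UNet as shown in the loop above.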
/src/pipelines/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | tensor_interpolation = None
4 |
5 |
6 | def get_tensor_interpolation_method():
7 | return tensor_interpolation
8 |
9 |
10 | def set_tensor_interpolation_method(is_slerp):
11 | global tensor_interpolation
12 | tensor_interpolation = slerp if is_slerp else linear
13 |
14 |
15 | def linear(v1, v2, t):
16 | return (1.0 - t) * v1 + t * v2
17 |
18 |
19 | def slerp(
20 | v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995
21 | ) -> torch.Tensor:
22 | u0 = v0 / v0.norm()
23 | u1 = v1 / v1.norm()
24 | dot = (u0 * u1).sum()
25 | if dot.abs() > DOT_THRESHOLD:
26 | # logger.info(f'warning: v0 and v1 close to parallel, using linear interpolation instead.')
27 | return (1.0 - t) * v0 + t * v1
28 | omega = dot.acos()
29 | return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin()
30 |
--------------------------------------------------------------------------------
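`set_tensor_interpolation_method` installs either linear or spherical interpolation in a module-level global; callers later retrieve the chosen function with `get_tensor_interpolation_method()` when they need to blend latents. A self-contained example (assumed import path):

# Example: choose slerp and interpolate halfway between two latent tensors
# (assumed import path).
import torch
from src.pipelines.utils import (
    get_tensor_interpolation_method,
    set_tensor_interpolation_method,
)

set_tensor_interpolation_method(is_slerp=True)
interp = get_tensor_interpolation_method()

v0 = torch.randn(4, 64)
v1 = torch.randn(4, 64)
mid = interp(v0, v1, 0.5)   # falls back to linear if v0 and v1 are nearly parallel
print(mid.shape)            # torch.Size([4, 64])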
/src/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/utils/__init__.py
--------------------------------------------------------------------------------
/src/utils/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/utils/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/audio_util.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/utils/__pycache__/audio_util.cpython-310.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/draw_util.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/utils/__pycache__/draw_util.cpython-310.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/face_landmark.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/utils/__pycache__/face_landmark.cpython-310.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/logger.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/utils/__pycache__/logger.cpython-310.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/mp_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/utils/__pycache__/mp_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/pose_util.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/utils/__pycache__/pose_util.cpython-310.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/util.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/utils/__pycache__/util.cpython-310.pyc
--------------------------------------------------------------------------------
/src/utils/audio_util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 |
4 | import librosa
5 | import numpy as np
6 | from transformers import Wav2Vec2FeatureExtractor
7 |
8 |
9 | class DataProcessor:
10 | def __init__(self, sampling_rate, wav2vec_model_path):
11 | self._processor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_model_path, local_files_only=True)
12 | self._sampling_rate = sampling_rate
13 |
14 | def extract_feature(self, audio_path):
15 | speech_array, sampling_rate = librosa.load(audio_path, sr=self._sampling_rate)
16 | input_value = np.squeeze(self._processor(speech_array, sampling_rate=sampling_rate).input_values)
17 | return input_value
18 |
19 |
20 | def prepare_audio_feature(wav_file, fps=30, sampling_rate=16000, wav2vec_model_path=None):
21 | data_preprocessor = DataProcessor(sampling_rate, wav2vec_model_path)
22 |
23 | input_value = data_preprocessor.extract_feature(wav_file)
24 | seq_len = math.ceil(len(input_value)/sampling_rate*fps)
25 | return {
26 | "audio_feature": input_value,
27 | "seq_len": seq_len
28 | }
29 |
30 |
31 |
--------------------------------------------------------------------------------
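`prepare_audio_feature` loads a WAV at the target sampling rate, runs the Wav2Vec2 feature extractor on it, and derives `seq_len`, the number of video frames the clip spans at the given fps. A hedged usage sketch; the audio path and the wav2vec2 model directory below are placeholders, and because the extractor is created with `local_files_only=True` the model must already be on disk.

# Hypothetical paths: replace with a real WAV file and a locally downloaded
# wav2vec2 model directory (e.g. facebook/wav2vec2-base-960h fetched beforehand).
from src.utils.audio_util import prepare_audio_feature  # assumed import path

sample = prepare_audio_feature(
    "some_audio.wav",
    fps=30,
    sampling_rate=16000,
    wav2vec_model_path="./pretrained_model/wav2vec2-base-960h",
)
print(sample["audio_feature"].shape)  # (num_audio_samples,) float array
print(sample["seq_len"])              # ceil(duration_in_seconds * fps)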
/src/utils/draw_util.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import mediapipe as mp
3 | import numpy as np
4 | from mediapipe.framework.formats import landmark_pb2
5 |
6 | class FaceMeshVisualizer:
7 | def __init__(self, forehead_edge=False):
8 | self.mp_drawing = mp.solutions.drawing_utils
9 | mp_face_mesh = mp.solutions.face_mesh
10 | self.mp_face_mesh = mp_face_mesh
11 | self.forehead_edge = forehead_edge
12 |
13 | DrawingSpec = mp.solutions.drawing_styles.DrawingSpec
14 | f_thick = 2
15 | f_rad = 1
16 | right_iris_draw = DrawingSpec(color=(10, 200, 250), thickness=f_thick, circle_radius=f_rad)
17 | right_eye_draw = DrawingSpec(color=(10, 200, 180), thickness=f_thick, circle_radius=f_rad)
18 | right_eyebrow_draw = DrawingSpec(color=(10, 220, 180), thickness=f_thick, circle_radius=f_rad)
19 | left_iris_draw = DrawingSpec(color=(250, 200, 10), thickness=f_thick, circle_radius=f_rad)
20 | left_eye_draw = DrawingSpec(color=(180, 200, 10), thickness=f_thick, circle_radius=f_rad)
21 | left_eyebrow_draw = DrawingSpec(color=(180, 220, 10), thickness=f_thick, circle_radius=f_rad)
22 | head_draw = DrawingSpec(color=(10, 200, 10), thickness=f_thick, circle_radius=f_rad)
23 |
24 | mouth_draw_obl = DrawingSpec(color=(10, 180, 20), thickness=f_thick, circle_radius=f_rad)
25 | mouth_draw_obr = DrawingSpec(color=(20, 10, 180), thickness=f_thick, circle_radius=f_rad)
26 |
27 | mouth_draw_ibl = DrawingSpec(color=(100, 100, 30), thickness=f_thick, circle_radius=f_rad)
28 | mouth_draw_ibr = DrawingSpec(color=(100, 150, 50), thickness=f_thick, circle_radius=f_rad)
29 |
30 | mouth_draw_otl = DrawingSpec(color=(20, 80, 100), thickness=f_thick, circle_radius=f_rad)
31 | mouth_draw_otr = DrawingSpec(color=(80, 100, 20), thickness=f_thick, circle_radius=f_rad)
32 |
33 | mouth_draw_itl = DrawingSpec(color=(120, 100, 200), thickness=f_thick, circle_radius=f_rad)
34 | mouth_draw_itr = DrawingSpec(color=(150 ,120, 100), thickness=f_thick, circle_radius=f_rad)
35 |
36 | FACEMESH_LIPS_OUTER_BOTTOM_LEFT = [(61,146),(146,91),(91,181),(181,84),(84,17)]
37 | FACEMESH_LIPS_OUTER_BOTTOM_RIGHT = [(17,314),(314,405),(405,321),(321,375),(375,291)]
38 |
39 | FACEMESH_LIPS_INNER_BOTTOM_LEFT = [(78,95),(95,88),(88,178),(178,87),(87,14)]
40 | FACEMESH_LIPS_INNER_BOTTOM_RIGHT = [(14,317),(317,402),(402,318),(318,324),(324,308)]
41 |
42 | FACEMESH_LIPS_OUTER_TOP_LEFT = [(61,185),(185,40),(40,39),(39,37),(37,0)]
43 | FACEMESH_LIPS_OUTER_TOP_RIGHT = [(0,267),(267,269),(269,270),(270,409),(409,291)]
44 |
45 | FACEMESH_LIPS_INNER_TOP_LEFT = [(78,191),(191,80),(80,81),(81,82),(82,13)]
46 | FACEMESH_LIPS_INNER_TOP_RIGHT = [(13,312),(312,311),(311,310),(310,415),(415,308)]
47 |
48 | FACEMESH_CUSTOM_FACE_OVAL = [(176, 149), (150, 136), (356, 454), (58, 132), (152, 148), (361, 288), (251, 389), (132, 93), (389, 356), (400, 377), (136, 172), (377, 152), (323, 361), (172, 58), (454, 323), (365, 379), (379, 378), (148, 176), (93, 234), (397, 365), (149, 150), (288, 397), (234, 127), (378, 400), (127, 162), (162, 21)]
49 |
50 | # mp_face_mesh.FACEMESH_CONTOURS has all the items we care about.
51 | face_connection_spec = {}
52 | if self.forehead_edge:
53 | for edge in mp_face_mesh.FACEMESH_FACE_OVAL:
54 | face_connection_spec[edge] = head_draw
55 | else:
56 | for edge in FACEMESH_CUSTOM_FACE_OVAL:
57 | face_connection_spec[edge] = head_draw
58 | for edge in mp_face_mesh.FACEMESH_LEFT_EYE:
59 | face_connection_spec[edge] = left_eye_draw
60 | for edge in mp_face_mesh.FACEMESH_LEFT_EYEBROW:
61 | face_connection_spec[edge] = left_eyebrow_draw
62 | # for edge in mp_face_mesh.FACEMESH_LEFT_IRIS:
63 | # face_connection_spec[edge] = left_iris_draw
64 | for edge in mp_face_mesh.FACEMESH_RIGHT_EYE:
65 | face_connection_spec[edge] = right_eye_draw
66 | for edge in mp_face_mesh.FACEMESH_RIGHT_EYEBROW:
67 | face_connection_spec[edge] = right_eyebrow_draw
68 | # for edge in mp_face_mesh.FACEMESH_RIGHT_IRIS:
69 | # face_connection_spec[edge] = right_iris_draw
70 | # for edge in mp_face_mesh.FACEMESH_LIPS:
71 | # face_connection_spec[edge] = mouth_draw
72 |
73 | for edge in FACEMESH_LIPS_OUTER_BOTTOM_LEFT:
74 | face_connection_spec[edge] = mouth_draw_obl
75 | for edge in FACEMESH_LIPS_OUTER_BOTTOM_RIGHT:
76 | face_connection_spec[edge] = mouth_draw_obr
77 | for edge in FACEMESH_LIPS_INNER_BOTTOM_LEFT:
78 | face_connection_spec[edge] = mouth_draw_ibl
79 | for edge in FACEMESH_LIPS_INNER_BOTTOM_RIGHT:
80 | face_connection_spec[edge] = mouth_draw_ibr
81 | for edge in FACEMESH_LIPS_OUTER_TOP_LEFT:
82 | face_connection_spec[edge] = mouth_draw_otl
83 | for edge in FACEMESH_LIPS_OUTER_TOP_RIGHT:
84 | face_connection_spec[edge] = mouth_draw_otr
85 | for edge in FACEMESH_LIPS_INNER_TOP_LEFT:
86 | face_connection_spec[edge] = mouth_draw_itl
87 | for edge in FACEMESH_LIPS_INNER_TOP_RIGHT:
88 | face_connection_spec[edge] = mouth_draw_itr
89 |
90 |
91 | iris_landmark_spec = {468: right_iris_draw, 473: left_iris_draw}
92 |
93 | self.face_connection_spec = face_connection_spec
94 | def draw_pupils(self, image, landmark_list, drawing_spec, halfwidth: int = 2):
95 | """We have a custom function to draw the pupils because the mp.draw_landmarks method requires a parameter for all
96 | landmarks. Until our PR is merged into mediapipe, we need this separate method."""
97 | if len(image.shape) != 3:
98 | raise ValueError("Input image must be H,W,C.")
99 | image_rows, image_cols, image_channels = image.shape
100 | if image_channels != 3: # BGR channels
101 | raise ValueError('Input image must contain three channel bgr data.')
102 | for idx, landmark in enumerate(landmark_list.landmark):
103 | if (
104 | (landmark.HasField('visibility') and landmark.visibility < 0.9) or
105 | (landmark.HasField('presence') and landmark.presence < 0.5)
106 | ):
107 | continue
108 | if landmark.x >= 1.0 or landmark.x < 0 or landmark.y >= 1.0 or landmark.y < 0:
109 | continue
110 | image_x = int(image_cols*landmark.x)
111 | image_y = int(image_rows*landmark.y)
112 | draw_color = None
113 | if isinstance(drawing_spec, dict):
114 | if drawing_spec.get(idx) is None:
115 | continue
116 | else:
117 | draw_color = drawing_spec[idx].color
118 | elif isinstance(drawing_spec, mp.solutions.drawing_styles.DrawingSpec):
119 | draw_color = drawing_spec.color
120 | image[image_y-halfwidth:image_y+halfwidth, image_x-halfwidth:image_x+halfwidth, :] = draw_color
121 |
122 |
123 |
124 | def draw_landmarks(self, image_size, keypoints, normed=False):
125 | ini_size = [512, 512]
126 | image = np.zeros([ini_size[1], ini_size[0], 3], dtype=np.uint8)
127 | new_landmarks = landmark_pb2.NormalizedLandmarkList()
128 | for i in range(keypoints.shape[0]):
129 | landmark = new_landmarks.landmark.add()
130 | if normed:
131 | landmark.x = keypoints[i, 0]
132 | landmark.y = keypoints[i, 1]
133 | else:
134 | landmark.x = keypoints[i, 0] / image_size[0]
135 | landmark.y = keypoints[i, 1] / image_size[1]
136 | landmark.z = 1.0
137 |
138 | self.mp_drawing.draw_landmarks(
139 | image=image,
140 | landmark_list=new_landmarks,
141 | connections=self.face_connection_spec.keys(),
142 | landmark_drawing_spec=None,
143 | connection_drawing_spec=self.face_connection_spec
144 | )
145 | # draw_pupils(image, face_landmarks, iris_landmark_spec, 2)
146 | image = cv2.resize(image, (image_size[0], image_size[1]))
147 |
148 | return image
149 |
150 |
--------------------------------------------------------------------------------
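`draw_landmarks` rasterizes a face mesh onto a blank 512x512 canvas using the per-region colors configured in `__init__`, then resizes the result to the requested size. A small sketch with random normalized landmarks; in practice the 478 MediaPipe landmarks come from `LMKExtractor` in `mp_utils.py` later in this listing, and the import path here is assumed.

# Sketch: draw a (random) 478-point face mesh as a pose image (assumed import path).
import numpy as np
from src.utils.draw_util import FaceMeshVisualizer

vis = FaceMeshVisualizer(forehead_edge=False)
lmks = np.random.rand(478, 2)                  # normalized (x, y) in [0, 1); normally from LMKExtractor
pose_img = vis.draw_landmarks((512, 512), lmks, normed=True)
print(pose_img.shape, pose_img.dtype)          # (512, 512, 3) uint8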
/src/utils/frame_interpolation.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/dajes/frame-interpolation-pytorch
2 | import os
3 | import cv2
4 | import numpy as np
5 | import torch
6 | import bisect
7 | import shutil
8 | import pdb
9 | from tqdm import tqdm
10 |
11 | def init_frame_interpolation_model():
12 | print("Initializing frame interpolation model")
13 | checkpoint_name = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),"pretrained_model/film_net_fp16.pt")
14 | #checkpoint_name = os.path.join("./pretrained_model/film_net_fp16.pt")
15 |
16 | model = torch.jit.load(checkpoint_name, map_location='cpu')
17 | model.eval()
18 | model = model.half()
19 | model = model.to(device="cuda")
20 | return model
21 |
22 |
23 | def batch_images_interpolation_tool(input_tensor, model, inter_frames=1):
24 |
25 | video_tensor = []
26 | frame_num = input_tensor.shape[2] # bs, channel, frame, height, width
27 |
28 | for idx in tqdm(range(frame_num-1)):
29 | image1 = input_tensor[:,:,idx]
30 | image2 = input_tensor[:,:,idx+1]
31 |
32 | results = [image1, image2]
33 |
34 | inter_frames = int(inter_frames)
35 | idxes = [0, inter_frames + 1]
36 | remains = list(range(1, inter_frames + 1))
37 |
38 | splits = torch.linspace(0, 1, inter_frames + 2)
39 |
40 | for _ in range(len(remains)):
41 | starts = splits[idxes[:-1]]
42 | ends = splits[idxes[1:]]
43 | distances = ((splits[None, remains] - starts[:, None]) / (ends[:, None] - starts[:, None]) - .5).abs()
44 | matrix = torch.argmin(distances).item()
45 | start_i, step = np.unravel_index(matrix, distances.shape)
46 | end_i = start_i + 1
47 |
48 | x0 = results[start_i]
49 | x1 = results[end_i]
50 |
51 | x0 = x0.half()
52 | x1 = x1.half()
53 | x0 = x0.cuda()
54 | x1 = x1.cuda()
55 |
56 | dt = x0.new_full((1, 1), (splits[remains[step]] - splits[idxes[start_i]])) / (splits[idxes[end_i]] - splits[idxes[start_i]])
57 |
58 | with torch.no_grad():
59 | prediction = model(x0, x1, dt)
60 | insert_position = bisect.bisect_left(idxes, remains[step])
61 | idxes.insert(insert_position, remains[step])
62 | results.insert(insert_position, prediction.clamp(0, 1).cpu().float())
63 | del remains[step]
64 |
65 | for sub_idx in range(len(results)-1):
66 | video_tensor.append(results[sub_idx].unsqueeze(2))
67 |
68 | video_tensor.append(input_tensor[:,:,-1].unsqueeze(2))
69 | video_tensor = torch.cat(video_tensor, dim=2)
70 | return video_tensor
71 |
--------------------------------------------------------------------------------
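`batch_images_interpolation_tool` walks consecutive frame pairs and inserts `inter_frames` FILM-predicted frames between each pair, so a clip with F frames comes back with (F - 1) * (inter_frames + 1) + 1 frames. A hedged sketch; it assumes the `pretrained_model/film_net_fp16.pt` checkpoint referenced above is present and a CUDA device is available.

# Sketch: double the effective frame rate of a generated clip with FILM interpolation.
# Assumes the film_net_fp16.pt checkpoint exists and CUDA is available.
import torch
from src.utils.frame_interpolation import (   # assumed import path
    batch_images_interpolation_tool,
    init_frame_interpolation_model,
)

model = init_frame_interpolation_model()
video = torch.rand(1, 3, 16, 512, 512)     # (batch, channels, frames, height, width), values in [0, 1]
smooth = batch_images_interpolation_tool(video, model, inter_frames=1)
print(smooth.shape)                        # torch.Size([1, 3, 31, 512, 512]) = (16 - 1) * 2 + 1 frames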
/src/utils/logger.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import copy
3 | import logging
4 |
5 | class ColoredFormatter(logging.Formatter):
6 | COLORS = {
7 | "DEBUG": "\033[0;36m", # CYAN
8 | "INFO": "\033[0;32m", # GREEN
9 | "WARNING": "\033[0;33m", # YELLOW
10 | "ERROR": "\033[0;31m", # RED
11 | "CRITICAL": "\033[0;37;41m", # WHITE ON RED
12 | "RESET": "\033[0m", # RESET COLOR
13 | }
14 |
15 | def format(self, record):
16 | colored_record = copy.copy(record)
17 | levelname = colored_record.levelname
18 | seq = self.COLORS.get(levelname, self.COLORS["RESET"])
19 | colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}"
20 | return super().format(colored_record)
21 |
22 |
23 | # Create a new logger
24 | logger = logging.getLogger("AniPortrait")
25 | logger.propagate = False
26 |
27 | # Add handler if we don't have one.
28 | if not logger.handlers:
29 | handler = logging.StreamHandler(sys.stdout)
30 | handler.setFormatter(ColoredFormatter("[%(name)s] - %(levelname)s - %(message)s"))
31 | logger.addHandler(handler)
32 |
33 | # Configure logger
34 | loglevel = logging.INFO
35 | logger.setLevel(loglevel)
36 |
--------------------------------------------------------------------------------
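The module exposes a single pre-configured, color-coded logger named "AniPortrait"; importing it anywhere in the package reuses the same stdout handler. For example (import path assumed):

from src.utils.logger import logger  # assumed import path

logger.info("pose sequence extracted")   # green INFO prefix on stdout
logger.warning("falling back to CPU")    # yellow WARNING prefix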
/src/utils/mp_models/blaze_face_short_range.tflite:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/utils/mp_models/blaze_face_short_range.tflite
--------------------------------------------------------------------------------
/src/utils/mp_models/face_landmarker_v2_with_blendshapes.task:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/utils/mp_models/face_landmarker_v2_with_blendshapes.task
--------------------------------------------------------------------------------
/src/utils/mp_models/pose_landmarker_heavy.task:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frankchieng/ComfyUI_Aniportrait/001673aa0c190b79aa0c051e690984c5e8542cab/src/utils/mp_models/pose_landmarker_heavy.task
--------------------------------------------------------------------------------
/src/utils/mp_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import cv2
4 | import time
5 | from tqdm import tqdm
6 | import multiprocessing
7 | import glob
8 |
9 | import mediapipe as mp
10 | from mediapipe import solutions
11 | from mediapipe.framework.formats import landmark_pb2
12 | from mediapipe.tasks import python
13 | from mediapipe.tasks.python import vision
14 | from . import face_landmark
15 |
16 | CUR_DIR = os.path.dirname(__file__)
17 |
18 |
19 | class LMKExtractor():
20 | def __init__(self, FPS=25):
21 | # Create an FaceLandmarker object.
22 | self.mode = mp.tasks.vision.FaceDetectorOptions.running_mode.IMAGE
23 | base_options = python.BaseOptions(model_asset_path=os.path.join(CUR_DIR, 'mp_models/face_landmarker_v2_with_blendshapes.task'))
24 | base_options.delegate = mp.tasks.BaseOptions.Delegate.CPU
25 | options = vision.FaceLandmarkerOptions(base_options=base_options,
26 | running_mode=self.mode,
27 | output_face_blendshapes=True,
28 | output_facial_transformation_matrixes=True,
29 | num_faces=1)
30 | self.detector = face_landmark.FaceLandmarker.create_from_options(options)
31 | self.last_ts = 0
32 | self.frame_ms = int(1000 / FPS)
33 |
34 | det_base_options = python.BaseOptions(model_asset_path=os.path.join(CUR_DIR, 'mp_models/blaze_face_short_range.tflite'))
35 | det_options = vision.FaceDetectorOptions(base_options=det_base_options)
36 | self.det_detector = vision.FaceDetector.create_from_options(det_options)
37 |
38 |
39 | def __call__(self, img):
40 | frame = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
41 | image = mp.Image(image_format=mp.ImageFormat.SRGB, data=frame)
42 | t0 = time.time()
43 | if self.mode == mp.tasks.vision.FaceDetectorOptions.running_mode.VIDEO:
44 | det_result = self.det_detector.detect(image)
45 | if len(det_result.detections) != 1:
46 | return None
47 | self.last_ts += self.frame_ms
48 | try:
49 | detection_result, mesh3d = self.detector.detect_for_video(image, timestamp_ms=self.last_ts)
50 | except Exception:
51 | return None
52 | elif self.mode == mp.tasks.vision.FaceDetectorOptions.running_mode.IMAGE:
53 | # det_result = self.det_detector.detect(image)
54 |
55 | # if len(det_result.detections) != 1:
56 | # return None
57 | try:
58 | detection_result, mesh3d = self.detector.detect(image)
59 | except:
60 | return None
61 |
62 |
63 | bs_list = detection_result.face_blendshapes
64 | if len(bs_list) == 1:
65 | bs = bs_list[0]
66 | bs_values = []
67 | for index in range(len(bs)):
68 | bs_values.append(bs[index].score)
69 | bs_values = bs_values[1:] # remove neutral
70 | trans_mat = detection_result.facial_transformation_matrixes[0]
71 | face_landmarks_list = detection_result.face_landmarks
72 | face_landmarks = face_landmarks_list[0]
73 | lmks = []
74 | for index in range(len(face_landmarks)):
75 | x = face_landmarks[index].x
76 | y = face_landmarks[index].y
77 | z = face_landmarks[index].z
78 | lmks.append([x, y, z])
79 | lmks = np.array(lmks)
80 |
81 | lmks3d = np.array(mesh3d.vertex_buffer)
82 | lmks3d = lmks3d.reshape(-1, 5)[:, :3]
83 | mp_tris = np.array(mesh3d.index_buffer).reshape(-1, 3) + 1
84 |
85 | return {
86 | "lmks": lmks,
87 | 'lmks3d': lmks3d,
88 | "trans_mat": trans_mat,
89 | 'faces': mp_tris,
90 | "bs": bs_values
91 | }
92 | else:
93 | # print('multiple faces in the image: {}'.format(img_path))
94 | return None
95 |
96 |
--------------------------------------------------------------------------------
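A minimal usage sketch for the `LMKExtractor` class above, assuming the script runs from the repository root so that `src` is importable; the image path is only an example (any single-face image works):

import cv2
from src.utils.mp_utils import LMKExtractor

lmk_extractor = LMKExtractor()                 # IMAGE running mode by default
frame_bgr = cv2.imread("assets/solo.png")      # expects a BGR frame; converted to RGB internally
result = lmk_extractor(frame_bgr)

if result is None:
    print("no usable face detected")
else:
    print(result["lmks"].shape)       # (478, 3) normalized landmarks from the face landmarker
    print(result["trans_mat"].shape)  # (4, 4) facial transformation matrix
    print(len(result["bs"]))          # 51 blendshape scores (neutral removed)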
/src/utils/pose_util.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import numpy as np
4 | from scipy.spatial.transform import Rotation as R
5 |
6 |
7 | def create_perspective_matrix(aspect_ratio):
8 | kDegreesToRadians = np.pi / 180.
9 | near = 1
10 | far = 10000
11 | perspective_matrix = np.zeros(16, dtype=np.float32)
12 |
13 | # Standard perspective projection matrix calculations.
14 | f = 1.0 / np.tan(kDegreesToRadians * 63 / 2.)
15 |
16 | denom = 1.0 / (near - far)
17 | perspective_matrix[0] = f / aspect_ratio
18 | perspective_matrix[5] = f
19 | perspective_matrix[10] = (near + far) * denom
20 | perspective_matrix[11] = -1.
21 | perspective_matrix[14] = 1. * far * near * denom
22 |
23 | # The environment's origin is in the top-left corner, so an additional
24 | # flip along the Y-axis is required to render correctly.
25 |
26 | perspective_matrix[5] *= -1.
27 | return perspective_matrix
28 |
29 |
30 | def project_points(points_3d, transformation_matrix, pose_vectors, image_shape):
31 | P = create_perspective_matrix(image_shape[1] / image_shape[0]).reshape(4, 4).T
32 | L, N, _ = points_3d.shape
33 | projected_points = np.zeros((L, N, 2))
34 | for i in range(L):
35 | points_3d_frame = points_3d[i]
36 | ones = np.ones((points_3d_frame.shape[0], 1))
37 | points_3d_homogeneous = np.hstack([points_3d_frame, ones])
38 | transformed_points = points_3d_homogeneous @ (transformation_matrix @ euler_and_translation_to_matrix(pose_vectors[i][:3], pose_vectors[i][3:])).T @ P
39 | projected_points_frame = transformed_points[:, :2] / transformed_points[:, 3, np.newaxis] # -1 ~ 1
40 | projected_points_frame[:, 0] = (projected_points_frame[:, 0] + 1) * 0.5 * image_shape[1]
41 | projected_points_frame[:, 1] = (projected_points_frame[:, 1] + 1) * 0.5 * image_shape[0]
42 | projected_points[i] = projected_points_frame
43 | return projected_points
44 |
45 |
46 | def project_points_with_trans(points_3d, transformation_matrix, image_shape):
47 | P = create_perspective_matrix(image_shape[1] / image_shape[0]).reshape(4, 4).T
48 | L, N, _ = points_3d.shape
49 | projected_points = np.zeros((L, N, 2))
50 | for i in range(L):
51 | points_3d_frame = points_3d[i]
52 | ones = np.ones((points_3d_frame.shape[0], 1))
53 | points_3d_homogeneous = np.hstack([points_3d_frame, ones])
54 | transformed_points = points_3d_homogeneous @ transformation_matrix[i].T @ P
55 | projected_points_frame = transformed_points[:, :2] / transformed_points[:, 3, np.newaxis] # -1 ~ 1
56 | projected_points_frame[:, 0] = (projected_points_frame[:, 0] + 1) * 0.5 * image_shape[1]
57 | projected_points_frame[:, 1] = (projected_points_frame[:, 1] + 1) * 0.5 * image_shape[0]
58 | projected_points[i] = projected_points_frame
59 | return projected_points
60 |
61 |
62 | def euler_and_translation_to_matrix(euler_angles, translation_vector):
63 | rotation = R.from_euler('xyz', euler_angles, degrees=True)
64 | rotation_matrix = rotation.as_matrix()
65 |
66 | matrix = np.eye(4)
67 | matrix[:3, :3] = rotation_matrix
68 | matrix[:3, 3] = translation_vector
69 |
70 | return matrix
71 |
72 |
73 | def matrix_to_euler_and_translation(matrix):
74 | rotation_matrix = matrix[:3, :3]
75 | translation_vector = matrix[:3, 3]
76 | rotation = R.from_matrix(rotation_matrix)
77 | euler_angles = rotation.as_euler('xyz', degrees=True)
78 | return euler_angles, translation_vector
79 |
80 |
81 | def smooth_pose_seq(pose_seq, window_size=5):
82 | smoothed_pose_seq = np.zeros_like(pose_seq)
83 |
84 | for i in range(len(pose_seq)):
85 | start = max(0, i - window_size // 2)
86 | end = min(len(pose_seq), i + window_size // 2 + 1)
87 | smoothed_pose_seq[i] = np.mean(pose_seq[start:end], axis=0)
88 |
89 | return smoothed_pose_seq
90 |
--------------------------------------------------------------------------------
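A short sketch of how these helpers compose: decompose per-frame 4x4 facial transformation matrices into 6-DoF pose vectors, smooth them, and rebuild matrices for projection. The `trans_mats` input is hypothetical (e.g. the `trans_mat` entries collected from `LMKExtractor` over a clip):

import numpy as np
from src.utils.pose_util import (
    euler_and_translation_to_matrix,
    matrix_to_euler_and_translation,
    smooth_pose_seq,
)

trans_mats = np.stack([np.eye(4)] * 30)    # hypothetical (L, 4, 4) sequence

# 4x4 matrix -> [rx, ry, rz, tx, ty, tz] per frame
pose_seq = np.array([
    np.concatenate(matrix_to_euler_and_translation(mat)) for mat in trans_mats
])                                         # (L, 6)

# moving-average smoothing over a 5-frame window
pose_seq_smooth = smooth_pose_seq(pose_seq, window_size=5)

# back to 4x4 matrices, ready for project_points_with_trans
smoothed_mats = np.array([
    euler_and_translation_to_matrix(p[:3], p[3:]) for p in pose_seq_smooth
])                                         # (L, 4, 4)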
/src/utils/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 | import os.path as osp
4 | import shutil
5 | import sys
6 | from pathlib import Path
7 | import hashlib
8 |
9 | from typing import Iterable
10 | import subprocess
11 | import re
12 |
13 | from .logger import logger
14 |
15 | import av
16 | import numpy as np
17 | import torch
18 | import torchvision
19 | from einops import rearrange
20 | from PIL import Image
21 |
22 | def seed_everything(seed):
23 | import random
24 |
25 | import numpy as np
26 |
27 | torch.manual_seed(seed)
28 | torch.cuda.manual_seed_all(seed)
29 | np.random.seed(seed % (2**32))
30 | random.seed(seed)
31 |
32 |
33 | def import_filename(filename):
34 | spec = importlib.util.spec_from_file_location("mymodule", filename)
35 | module = importlib.util.module_from_spec(spec)
36 | sys.modules[spec.name] = module
37 | spec.loader.exec_module(module)
38 | return module
39 |
40 |
41 | def delete_additional_ckpt(base_path, num_keep):
42 | dirs = []
43 | for d in os.listdir(base_path):
44 | if d.startswith("checkpoint-"):
45 | dirs.append(d)
46 | num_tot = len(dirs)
47 | if num_tot <= num_keep:
48 | return
49 | # ensure checkpoints are sorted and delete the earlier ones
50 | del_dirs = sorted(dirs, key=lambda x: int(x.split("-")[-1]))[: num_tot - num_keep]
51 | for d in del_dirs:
52 | path_to_dir = osp.join(base_path, d)
53 | if osp.exists(path_to_dir):
54 | shutil.rmtree(path_to_dir)
55 |
56 |
57 | def save_videos_from_pil(pil_images, path, fps=8):
58 | import av
59 |
60 | save_fmt = Path(path).suffix
61 | os.makedirs(os.path.dirname(path), exist_ok=True)
62 | width, height = pil_images[0].size
63 |
64 | if save_fmt == ".mp4":
65 | codec = "libx264"
66 | container = av.open(path, "w")
67 | stream = container.add_stream(codec, rate=fps)
68 |
69 | stream.width = width
70 | stream.height = height
71 |
72 | for pil_image in pil_images:
73 | # pil_image = Image.fromarray(image_arr).convert("RGB")
74 | av_frame = av.VideoFrame.from_image(pil_image)
75 | container.mux(stream.encode(av_frame))
76 | container.mux(stream.encode())
77 | container.close()
78 |
79 | elif save_fmt == ".gif":
80 | pil_images[0].save(
81 | fp=path,
82 | format="GIF",
83 | append_images=pil_images[1:],
84 | save_all=True,
85 | duration=(1 / fps * 1000),
86 | loop=0,
87 | )
88 | else:
89 | raise ValueError("Unsupported file type. Use .mp4 or .gif.")
90 |
91 |
92 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
93 | videos = rearrange(videos, "b c t h w -> t b c h w")
94 | height, width = videos.shape[-2:]
95 | outputs = []
96 |
97 | for x in videos:
98 | x = torchvision.utils.make_grid(x, nrow=n_rows) # (c h w)
99 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) # (h w c)
100 | if rescale:
101 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1
102 | x = (x * 255).numpy().astype(np.uint8)
103 | x = Image.fromarray(x)
104 |
105 | outputs.append(x)
106 |
107 | os.makedirs(os.path.dirname(path), exist_ok=True)
108 |
109 | save_videos_from_pil(outputs, path, fps)
110 |
111 |
112 | def read_frames(video_path):
113 | container = av.open(video_path)
114 |
115 | video_stream = next(s for s in container.streams if s.type == "video")
116 | frames = []
117 | for packet in container.demux(video_stream):
118 | for frame in packet.decode():
119 | image = Image.frombytes(
120 | "RGB",
121 | (frame.width, frame.height),
122 | frame.to_rgb().to_ndarray(),
123 | )
124 | frames.append(image)
125 |
126 | return frames
127 |
128 |
129 | def get_fps(video_path):
130 | container = av.open(video_path)
131 | video_stream = next(s for s in container.streams if s.type == "video")
132 | fps = video_stream.average_rate
133 | container.close()
134 | return fps
135 |
136 | def ffmpeg_suitability(path):
137 | try:
138 | version = subprocess.run([path, "-version"], check=True,
139 | capture_output=True).stdout.decode("utf-8")
140 | except:
141 | return 0
142 | score = 0
143 | #rough layout of the importance of various features
144 | simple_criterion = [("libvpx", 20),("264",10), ("265",3),
145 | ("svtav1",5),("libopus", 1)]
146 | for criterion in simple_criterion:
147 | if version.find(criterion[0]) >= 0:
148 | score += criterion[1]
149 | #obtain rough compile year from copyright information
150 | copyright_index = version.find('2000-2')
151 | if copyright_index >= 0:
152 | copyright_year = version[copyright_index+6:copyright_index+9]
153 | if copyright_year.isnumeric():
154 | score += int(copyright_year)
155 | return score
156 |
157 |
158 | if "VHS_FORCE_FFMPEG_PATH" in os.environ:
159 | ffmpeg_path = os.environ["VHS_FORCE_FFMPEG_PATH"]
160 | else:
161 | ffmpeg_paths = []
162 | try:
163 | from imageio_ffmpeg import get_ffmpeg_exe
164 | imageio_ffmpeg_path = get_ffmpeg_exe()
165 | ffmpeg_paths.append(imageio_ffmpeg_path)
166 | except:
167 | if "VHS_USE_IMAGEIO_FFMPEG" in os.environ:
168 | raise
169 | logger.warning("Failed to import imageio_ffmpeg")
170 | if "VHS_USE_IMAGEIO_FFMPEG" in os.environ:
171 | ffmpeg_path = imageio_ffmpeg_path
172 | else:
173 | system_ffmpeg = shutil.which("ffmpeg")
174 | if system_ffmpeg is not None:
175 | ffmpeg_paths.append(system_ffmpeg)
176 | if len(ffmpeg_paths) == 0:
177 | logger.error("No valid ffmpeg found.")
178 | ffmpeg_path = None
179 | else:
180 | ffmpeg_path = max(ffmpeg_paths, key=ffmpeg_suitability)
181 |
182 |
183 | def get_sorted_dir_files_from_directory(directory: str, skip_first_images: int=0, select_every_nth: int=1, extensions: Iterable=None):
184 | directory = directory.strip()
185 | dir_files = os.listdir(directory)
186 | dir_files = sorted(dir_files)
187 | dir_files = [os.path.join(directory, x) for x in dir_files]
188 | dir_files = list(filter(lambda filepath: os.path.isfile(filepath), dir_files))
189 | # filter by extension, if needed
190 | if extensions is not None:
191 | extensions = list(extensions)
192 | new_dir_files = []
193 | for filepath in dir_files:
194 | ext = "." + filepath.split(".")[-1]
195 | if ext.lower() in extensions:
196 | new_dir_files.append(filepath)
197 | dir_files = new_dir_files
198 | # start at skip_first_images
199 | dir_files = dir_files[skip_first_images:]
200 | dir_files = dir_files[0::select_every_nth]
201 | return dir_files
202 |
203 |
204 | # modified from https://stackoverflow.com/questions/22058048/hashing-a-file-in-python
205 | def calculate_file_hash(filename: str, hash_every_n: int = 1):
206 | h = hashlib.sha256()
207 | b = bytearray(10*1024*1024) # read 10 megabytes at a time
208 | mv = memoryview(b)
209 | with open(filename, 'rb', buffering=0) as f:
210 | i = 0
211 | # don't hash entire file, only portions of it if requested
212 | while n := f.readinto(mv):
213 | if i%hash_every_n == 0:
214 | h.update(mv[:n])
215 | i += 1
216 | return h.hexdigest()
217 |
218 |
219 | def get_audio(file, start_time=0, duration=0):
220 | args = [ffmpeg_path, "-v", "error", "-i", file]
221 | if start_time > 0:
222 | args += ["-ss", str(start_time)]
223 | if duration > 0:
224 | args += ["-t", str(duration)]
225 | return subprocess.run(args + ["-f", "wav", "-"],
226 | stdout=subprocess.PIPE, check=True).stdout
227 |
228 |
229 | def lazy_eval(func):
230 | class Cache:
231 | def __init__(self, func):
232 | self.res = None
233 | self.func = func
234 | def get(self):
235 | if self.res is None:
236 | self.res = self.func()
237 | return self.res
238 | cache = Cache(func)
239 | return lambda : cache.get()
240 |
241 |
242 | def is_url(url):
243 | return url.split("://")[0] in ["http", "https"]
244 |
245 | def validate_sequence(path):
246 | #Check if path is a valid ffmpeg sequence that points to at least one file
247 | (path, file) = os.path.split(path)
248 | if not os.path.isdir(path):
249 | return False
250 | match = re.search(r'%0?\d+d', file)
251 | if not match:
252 | return False
253 | seq = match.group()
254 | if seq == '%d':
255 | seq = '\\\\d+'
256 | else:
257 | seq = '\\\\d{%s}' % seq[1:-1]
258 | file_matcher = re.compile(re.sub(r'%0?\d+d', seq, file))
259 | for file in os.listdir(path):
260 | if file_matcher.fullmatch(file):
261 | return True
262 | return False
263 |
264 | def hash_path(path):
265 | if path is None:
266 | return "input"
267 | if is_url(path):
268 | return "url"
269 | return calculate_file_hash(path.strip("\""))
270 |
271 |
272 | def validate_path(path, allow_none=False, allow_url=True):
273 | if path is None:
274 | return allow_none
275 | if is_url(path):
276 | #Probably not feasible to check if url resolves here
277 | return True if allow_url else "URLs are unsupported for this path"
278 | if not os.path.isfile(path.strip("\"")):
279 | return "Invalid file path: {}".format(path)
280 | return True
281 |
--------------------------------------------------------------------------------
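A minimal sketch exercising the video helpers above; the clip path is just an example, and the output directory is created by `save_videos_from_pil` itself:

from src.utils.util import get_fps, read_frames, save_videos_from_pil

video_path = "assets/pose_ref_video.mp4"   # example input clip
frames = read_frames(video_path)           # list of RGB PIL.Image frames
fps = get_fps(video_path)                  # average frame rate as a Fraction

# re-encode roughly the first two seconds (use a .gif suffix for a GIF instead)
save_videos_from_pil(frames[: int(fps) * 2], "output/preview.mp4", fps=int(fps))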