├── .gitignore
├── LICENSE.txt
├── README.md
├── app.py
├── assets
│   ├── demo.gif
│   ├── driving_audio
│   │   ├── 1.wav
│   │   ├── 2.wav
│   │   ├── 3.wav
│   │   ├── 4.wav
│   │   ├── 5.wav
│   │   └── 6.wav
│   ├── driving_video
│   │   ├── .DS_Store
│   │   ├── 1.mp4
│   │   ├── 2.mp4
│   │   ├── 3.mp4
│   │   ├── 4.mp4
│   │   ├── 5.mp4
│   │   ├── 6.mp4
│   │   ├── 7.mp4
│   │   └── 8.mp4
│   ├── gradio.png
│   ├── logo.png
│   └── ref_images
│       ├── 1.png
│       ├── 10.png
│       ├── 11.png
│       ├── 12.png
│       ├── 13.png
│       ├── 14.png
│       ├── 15.png
│       ├── 16.png
│       ├── 17.png
│       ├── 18.png
│       ├── 19.png
│       ├── 2.png
│       ├── 20.png
│       ├── 3.png
│       ├── 4.png
│       ├── 5.png
│       ├── 6.png
│       ├── 7.png
│       └── 8.png
├── diffposetalk
│   ├── common.py
│   ├── diff_talking_head.py
│   ├── diffposetalk.py
│   ├── hubert.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── common.py
│   │   ├── media.py
│   │   ├── renderer.py
│   │   └── rotation_conversions.py
│   └── wav2vec2.py
├── eval
│   ├── arc_score.py
│   ├── curricularface
│   │   ├── __init__.py
│   │   ├── common.py
│   │   ├── model_irse.py
│   │   └── model_resnet.py
│   ├── expression_score.py
│   └── pose_score.py
├── inference.py
├── inference_audio.py
├── inference_audio_long_video.py
├── inference_long_video.py
├── requirements.txt
├── scripts
│   ├── __init__.py
│   └── demo.py
└── skyreels_a1
    ├── __init__.py
    ├── ddim_solver.py
    ├── models
    │   ├── __init__.py
    │   └── transformer3d.py
    ├── pipeline_output.py
    ├── pre_process_lmk3d.py
    ├── skyreels_a1_i2v_long_pipeline.py
    ├── skyreels_a1_i2v_pipeline.py
    └── src
        ├── FLAME
        │   ├── FLAME.py
        │   └── lbs.py
        ├── __init__.py
        ├── frame_interpolation.py
        ├── lmk3d_test.py
        ├── media_pipe
        │   ├── draw_util.py
        │   ├── draw_util_2d.py
        │   ├── face_landmark.py
        │   ├── mp_models
        │   │   ├── blaze_face_short_range.tflite
        │   │   ├── face_landmarker_v2_with_blendshapes.task
        │   │   └── pose_landmarker_heavy.task
        │   ├── mp_utils.py
        │   └── readme
        ├── multi_fps.py
        ├── renderer.py
        ├── smirk_encoder.py
        └── utils
            └── mediapipe_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | pretrained_models/
2 | __pycache__/
3 | .gradio/
4 | outputs/
5 | outputs_audio/
6 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | ---
2 | language:
3 | - en
4 | - zh
5 | license: other
6 | tasks:
7 | - text-generation
8 |
9 | ---
10 |
11 |
12 |
13 |
14 | # 声明与协议/Terms and Conditions
15 |
16 | ## 声明
17 |
18 | 我们在此声明,不要利用Skywork模型进行任何危害国家社会安全或违法的活动。另外,我们也要求使用者不要将 Skywork 模型用于未经适当安全审查和备案的互联网服务。我们希望所有的使用者都能遵守这个原则,确保科技的发展能在规范和合法的环境下进行。
19 |
20 | 我们已经尽我们所能,来确保模型训练过程中使用的数据的合规性。然而,尽管我们已经做出了巨大的努力,但由于模型和数据的复杂性,仍有可能存在一些无法预见的问题。因此,如果由于使用skywork开源模型而导致的任何问题,包括但不限于数据安全问题、公共舆论风险,或模型被误导、滥用、传播或不当利用所带来的任何风险和问题,我们将不承担任何责任。
21 |
22 | We hereby declare that the Skywork model should not be used for any activities that pose a threat to national or societal security or engage in unlawful actions. Additionally, we request users not to deploy the Skywork model for internet services without appropriate security reviews and records. We hope that all users will adhere to this principle to ensure that technological advancements occur in a regulated and lawful environment.
23 |
24 | We have done our utmost to ensure the compliance of the data used during the model's training process. However, despite our extensive efforts, due to the complexity of the model and data, there may still be unpredictable risks and issues. Therefore, if any problems arise as a result of using the Skywork open-source model, including but not limited to data security issues, public opinion risks, or any risks and problems arising from the model being misled, abused, disseminated, or improperly utilized, we will not assume any responsibility.
25 |
26 | ## 协议
27 |
28 | 社区使用Skywork模型需要遵循[《Skywork 模型社区许可协议》](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20模型社区许可协议.pdf)。Skywork模型支持商业用途,如果您计划将Skywork模型或其衍生品用于商业目的,无需再次申请, 但请您仔细阅读[《Skywork 模型社区许可协议》](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20模型社区许可协议.pdf)并严格遵守相关条款。
29 |
30 |
31 | The community usage of Skywork model requires [Skywork Community License](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20Community%20License.pdf). The Skywork model supports commercial use. If you plan to use the Skywork model or its derivatives for commercial purposes, you must abide by terms and conditions within [Skywork Community License](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20Community%20License.pdf).
32 |
33 |
34 |
35 | [《Skywork 模型社区许可协议》]:https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20模型社区许可协议.pdf
36 |
37 |
38 | [skywork-opensource@kunlun-inc.com]: mailto:skywork-opensource@kunlun-inc.com
39 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 6 | # SkyReels-A1: Expressive Portrait Animation in Video Diffusion Transformers
 7 |
23 | Skywork AI, Kunlun Inc.
24 |
42 | 🔥 For more results, visit our homepage 🔥
43 |
46 | 👋 Join our Discord
47 |
48 |
49 |
50 | This repo, named **SkyReels-A1**, contains the official PyTorch implementation of our paper [SkyReels-A1: Expressive Portrait Animation in Video Diffusion Transformers](https://arxiv.org/abs/2502.10841).
51 |
52 |
53 | ## 🔥🔥🔥 News!!
54 | * Apr 3, 2025: 🔥 We release [SkyReels-A2](https://github.com/SkyworkAI/SkyReels-A2). This is an open-sourced controllable video generation framework capable of assembling arbitrary visual elements.
55 | * Mar 4, 2025: 🔥 We release audio-driven portrait image animation pipeline. Try out on [Huggingface Spaces Demo](https://huggingface.co/spaces/Skywork/skyreels-a1-talking-head) !
56 | * Feb 18, 2025: 👋 We release the inference code and model weights of SkyReels-A1. [Download](https://huggingface.co/Skywork/SkyReels-A1)
57 | * Feb 18, 2025: 🎉 We have made our technical report available as open source. [Read](https://skyworkai.github.io/skyreels-a1.github.io/report.pdf)
58 | * Feb 18, 2025: 🔥 Our online demo of LipSync is available on SkyReels now! Try out on [LipSync](https://www.skyreels.ai/home/tools/lip-sync?refer=navbar) .
59 | * Feb 18, 2025: 🔥 We have open-sourced I2V video generation model [SkyReels-V1](https://github.com/SkyworkAI/SkyReels-V1). This is the first and most advanced open-source human-centric video foundation model.
60 |
61 | ## 📑 TODO List
62 | - [x] Checkpoints
63 | - [x] Inference Code
64 | - [x] Web Demo (Gradio)
65 | - [x] Audio-driven Portrait Image Animation Pipeline
66 | - [x] Inference Code for Long Videos
67 | - [ ] User-Level GPU Inference on RTX4090
68 | - [ ] ComfyUI
69 |
70 |
71 | ## Getting Started 🏁
72 |
73 | ### 1. Clone the code and prepare the environment 🛠️
74 | First git clone the repository with code:
75 | ```bash
76 | git clone https://github.com/SkyworkAI/SkyReels-A1.git
77 | cd SkyReels-A1
78 |
79 | # create env using conda
80 | conda create -n skyreels-a1 python=3.10
81 | conda activate skyreels-a1
82 | ```
83 | Then, install the remaining dependencies:
84 | ```bash
85 | pip install -r requirements.txt
86 | ```
87 |
88 |
89 | ### 2. Download pretrained weights 📥
90 | You can download the pretrained weights from HuggingFace:
91 | ```bash
92 | # !pip install -U "huggingface_hub[cli]"
93 | huggingface-cli download Skywork/SkyReels-A1 --local-dir local_path --exclude "*.git*" "README.md" "docs"
94 | ```
95 |
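If you prefer to download from Python instead of the CLI, `huggingface_hub.snapshot_download` does the same job. A minimal sketch follows; the `pretrained_models` target directory is an assumption chosen to match the layout shown below:

```python
# Sketch: Python equivalent of the CLI download above.
# "pretrained_models" is assumed as local_dir so it matches the layout shown below.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="Skywork/SkyReels-A1",
    local_dir="pretrained_models",
    ignore_patterns=["*.git*", "README.md", "docs"],  # mirrors the --exclude flags above
)
```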
96 | The FLAME, mediapipe, and smirk models are located in the SkyReels-A1/extra_models folder.
97 |
98 | The directory structure of our SkyReels-A1 code is formulated as:
99 | ```text
100 | pretrained_models
101 | ├── FLAME
102 | ├── SkyReels-A1-5B
103 | │ ├── pose_guider
104 | │ ├── scheduler
105 | │ ├── tokenizer
106 | │ ├── siglip-so400m-patch14-384
107 | │ ├── transformer
108 | │ ├── vae
109 | │ └── text_encoder
110 | ├── mediapipe
111 | └── smirk
112 |
113 | ```
114 |
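A quick way to confirm the weights ended up in the right places is a small check like the one below (a throwaway sketch, not part of the repo):

```python
# Throwaway sanity check (not part of the repo): verify the pretrained_models layout above.
from pathlib import Path

required = [
    "FLAME",
    "SkyReels-A1-5B/transformer",
    "SkyReels-A1-5B/vae",
    "SkyReels-A1-5B/text_encoder",
    "SkyReels-A1-5B/siglip-so400m-patch14-384",
    "mediapipe",
    "smirk",
]
missing = [p for p in required if not (Path("pretrained_models") / p).exists()]
print("missing:", missing if missing else "none")
```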
115 | #### Download DiffPoseTalk assets and pretrained weights (for audio-driven animation)
116 |
117 | - We use [DiffPoseTalk](https://github.com/DiffPoseTalk/DiffPoseTalk/tree/main) to generate FLAME coefficients from audio, which serve as the motion signals (a usage sketch follows the directory layout below).
118 |
119 | - Download the diffposetalk code and follow its README to download the weights and related data.
120 |
121 | - Then place them in the specified directory.
122 |
123 | ```bash
124 | cp -r ${diffposetalk_root}/style pretrained_models/diffposetalk
125 | cp ${diffposetalk_root}/experiments/DPT/head-SA-hubert-WM/checkpoints/iter_0110000.pt pretrained_models/diffposetalk
126 | cp ${diffposetalk_root}/datasets/HDTF_TFHP/lmdb/stats_train.npz pretrained_models/diffposetalk
127 | ```
128 |
129 | - Or you can download style files from [link](https://drive.google.com/file/d/1XT426b-jt7RUkRTYsjGvG-wS4Jed2U1T/view?usp=sharing) and stats_train.npz from [link](https://drive.google.com/file/d/1_I5XRzkMP7xULCSGVuaN8q1Upplth9xR/view?usp=sharing).
130 |
131 | ```text
132 | pretrained_models
133 | ├── FLAME
134 | ├── SkyReels-A1-5B
135 | ├── mediapipe
136 | ├── diffposetalk
137 | │ ├── style
138 | │ ├── iter_0110000.pt
139 | │ ├── stats_train.npz
140 | └── smirk
141 |
142 | ```
143 |
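With these files in place, the `DiffPoseTalk` wrapper in `diffposetalk/diffposetalk.py` can be used on its own to turn audio into per-frame FLAME coefficients, as referenced above. A minimal sketch: the default `DiffPoseTalkConfig` already points at the paths set up here, while `shape.npy` (a FLAME shape-coefficient file for the reference identity) is a hypothetical placeholder.

```python
# Sketch based on diffposetalk/diffposetalk.py in this repo.
# DiffPoseTalkConfig defaults to the pretrained_models/diffposetalk paths above;
# "shape.npy" is a hypothetical FLAME shape-coefficient file for the reference identity.
from diffposetalk.diffposetalk import DiffPoseTalk, DiffPoseTalkConfig

dpt = DiffPoseTalk(DiffPoseTalkConfig(), device="cuda")

# Returns a list with one dict per output frame, holding expression_params,
# jaw_params, pose_params, eye_pose_params and eyelid_params.
coef_list = dpt.infer_from_file("assets/driving_audio/1.wav", "shape.npy")
print(len(coef_list), coef_list[0].keys())
```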
144 | #### Download Frame interpolation Model pretrained weights (For Long Video Inference and Dynamic Resolution)
145 |
146 | - We use [FILM](https://github.com/dajes/frame-interpolation-pytorch) to generate transition frames, making the video transitions smoother (Set `use_interpolation` to True).
147 |
148 | - Download [film_net_fp16.pt](https://github.com/dajes/frame-interpolation-pytorch/releases), and place it in the specified directory.
149 |
150 | ```text
151 | pretrained_models
152 | ├── FLAME
153 | ├── SkyReels-A1-5B
154 | ├── mediapipe
155 | ├── diffposetalk
156 | ├── film_net
157 | │ ├── film_net_fp16.pt
158 | └── smirk
159 | ```
160 |
161 |
162 | ### 3. Inference 🚀
163 | You can simply run the inference scripts as:
164 | ```bash
165 | python inference.py
166 |
167 | # inference audio to video
168 | python inference_audio.py
169 | ```
170 |
171 | If the script runs successfully, you will get an output MP4 file containing the driving video, the input image or video, and the generated result.
172 |
173 | #### Long Video Inference
174 |
175 | Now, you can run the long video inference scripts to obtain portrait animation of any length:
176 | ```bash
177 | python inference_long_video.py
178 |
179 | # inference audio to video
180 | python inference_audio_long_video.py
181 | ```
182 |
183 | #### Dynamic Resolution
184 |
185 | All inference scripts now support dynamic resolution: simply set `target_fps` to the desired frame rate. Recommended settings are 12fps (native), 24fps, 48fps, and 60fps; other values such as 25fps and 30fps may cause unstable frame rates.
186 |
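The recommended values are all integer multiples of the 12 fps native output, which keeps the frame-interpolation factor whole; 25 fps and 30 fps are not, which is presumably why they can be unstable. A quick illustration of the arithmetic (not repo code):

```python
# Illustration only (not repo code): which target rates divide evenly by the 12 fps native output.
NATIVE_FPS = 12
for target in (12, 24, 48, 60, 25, 30):
    factor = target / NATIVE_FPS
    status = "integer factor" if factor.is_integer() else "non-integer factor (may be unstable)"
    print(f"{target:>2} fps -> x{factor:.2f} ({status})")
```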
187 |
188 | ## Gradio Interface 🤗
189 |
190 | We provide a [Gradio](https://huggingface.co/docs/hub/spaces-sdks-gradio) interface for a better experience. To start it, simply run:
191 |
192 | ```bash
193 | python app.py
194 | ```
195 |
196 | The graphical interactive interface is shown below:
197 |
198 | 
199 |
200 |
201 | ## Metric Evaluation 👓
202 |
203 | We also provide the scripts for automatically calculating the metrics reported in the paper, including SimFace, FID, and the L1 distance of expression and pose.
204 |
205 | All code can be found in the ```eval``` folder. After setting the video result path, run the following commands in sequence:
206 |
207 | ```bash
208 | python arc_score.py
209 | python expression_score.py
210 | python pose_score.py
211 | ```
212 |
213 |
214 | ## Acknowledgements 💐
215 | We would like to thank the contributors of the [CogvideoX](https://github.com/THUDM/CogVideo), [finetrainers](https://github.com/a-r-r-o-w/finetrainers) and [DiffPoseTalk](https://github.com/DiffPoseTalk/DiffPoseTalk) repositories for their open research and contributions.
216 |
217 | ## Citation 💖
218 | If you find SkyReels-A1 useful for your research, please 🌟 this repo and cite our work using the following BibTeX:
219 | ```bibtex
220 | @article{qiu2025skyreels,
221 | title={Skyreels-a1: Expressive portrait animation in video diffusion transformers},
222 | author={Qiu, Di and Fei, Zhengcong and Wang, Rui and Bai, Jialin and Yu, Changqian and Fan, Mingyuan and Chen, Guibin and Wen, Xiang},
223 | journal={arXiv preprint arXiv:2502.10841},
224 | year={2025}
225 | }
226 | ```
227 |
228 | ## Star History
229 |
230 | [](https://www.star-history.com/#SkyworkAI/SkyReels-A1&Date)
231 |
232 |
233 |
234 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | from scripts.demo import init_model, generate_video, crop_and_resize
3 | import os
4 | import os.path as osp
5 | import stat
6 | from datetime import datetime
7 | import torch
8 | import numpy as np
9 | from diffusers.utils import export_to_video, load_image
10 |
11 | os.environ['GRADIO_TEMP_DIR'] = 'tmp'
12 |
13 | example_portrait_dir = "assets/ref_images"
14 | example_video_dir = "assets/driving_video"
15 |
16 |
17 | pipe, face_helper, processor, lmk_extractor, vis = init_model()
18 | # Gradio interface using Interface
19 | with gr.Blocks() as demo:
20 | gr.Markdown("""
21 |
22 | SkyReels-A1: Expressive Portrait Animation in Video Diffusion Transformers
23 |
24 |
29 | """)
30 |
31 |     with gr.Row():  # create a horizontally arranged row
32 | with gr.Accordion(open=True, label="Portrait Image"):
33 | image_input = gr.Image(type="filepath")
34 | gr.Examples(
35 | examples=[
36 | [osp.join(example_portrait_dir, "1.png")],
37 | [osp.join(example_portrait_dir, "2.png")],
38 | [osp.join(example_portrait_dir, "3.png")],
39 | [osp.join(example_portrait_dir, "4.png")],
40 | [osp.join(example_portrait_dir, "5.png")],
41 | [osp.join(example_portrait_dir, "6.png")],
42 | [osp.join(example_portrait_dir, "7.png")],
43 | [osp.join(example_portrait_dir, "8.png")],
44 | ],
45 | inputs=[image_input],
46 | cache_examples=False,
47 | )
48 | with gr.Accordion(open=True, label="Driving Video"):
49 | control_video_input = gr.Video()
50 | gr.Examples(
51 | examples=[
52 | [osp.join(example_video_dir, "1.mp4")],
53 | [osp.join(example_video_dir, "2.mp4")],
54 | [osp.join(example_video_dir, "3.mp4")],
55 | [osp.join(example_video_dir, "4.mp4")],
56 | [osp.join(example_video_dir, "5.mp4")],
57 | [osp.join(example_video_dir, "6.mp4")],
58 | [osp.join(example_video_dir, "7.mp4")],
59 | [osp.join(example_video_dir, "8.mp4")],
60 | ],
61 | inputs=[control_video_input],
62 | cache_examples=False,
63 | )
64 |
65 | def face_check(image_path):
66 | image = load_image(image=image_path)
67 | image = crop_and_resize(image, 480, 720)
68 |
69 | with torch.no_grad():
70 | face_helper.clean_all()
71 | face_helper.read_image(np.array(image)[:, :, ::-1])
72 | face_helper.get_face_landmarks_5(only_center_face=True)
73 | face_helper.align_warp_face()
74 | if len(face_helper.cropped_faces) == 0:
75 | return False
76 | face = face_helper.det_faces
77 | face_w = int(face[2] - face[0])
78 | if face_w < 50:
79 | return False
80 | return True
81 |
82 |
83 | def gradio_generate_video(control_video_path, image_path, progress=gr.Progress(track_tqdm=True)):
84 | try:
85 | save_dir = "./outputs/"
86 | if not os.path.exists(save_dir):
87 | os.makedirs(save_dir, exist_ok=True)
88 | current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
89 | save_path = os.path.join(save_dir, f"generated_video_{current_time}.mp4")
90 | print(control_video_path, image_path)
91 |
92 | face = face_check(image_path)
93 |         if not face:
94 | return "Face too small or no face.", None, None
95 |
96 | generate_video(
97 | pipe,
98 | face_helper,
99 | processor,
100 | lmk_extractor,
101 | vis,
102 | control_video_path=control_video_path,
103 | image_path=image_path,
104 | save_path=save_path,
105 | guidance_scale=3,
106 | seed=43,
107 | num_inference_steps=20,
108 | sample_size=[480, 720],
109 | max_frame_num=49,
110 | )
111 |
112 | print("finished.")
113 | print(save_path)
114 | if not os.path.exists(save_path):
115 | print("Error: Video file not found")
116 | return "Error: Video file not found", None
117 |
118 | video_update = gr.update(visible=True, value=save_path)
119 | return "Video generated successfully.", save_path, video_update
120 | except Exception as e:
121 | return f"Error occurred: {str(e)}", None, None
122 |
123 |
124 | generate_button = gr.Button("Generate Video")
125 | output_text = gr.Textbox(label="Output")
126 | output_video = gr.Video(label="Output Video")
127 | with gr.Row():
128 | download_video_button = gr.File(label="📥 Download Video", visible=False)
129 |
130 | generate_button.click(
131 | gradio_generate_video,
132 | inputs=[
133 | control_video_input,
134 | image_input
135 | ],
136 |         outputs=[output_text, output_video, download_video_button],  # update the outputs to include the video
137 | show_progress=True,
138 | )
139 |
140 |
141 | if __name__ == "__main__":
142 | # demo.queue(concurrency_count=8)
143 | demo.launch(share=True, enable_queue=True)
144 |
--------------------------------------------------------------------------------
/assets/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/demo.gif
--------------------------------------------------------------------------------
/assets/driving_audio/1.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/driving_audio/1.wav
--------------------------------------------------------------------------------
/assets/driving_audio/2.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/driving_audio/2.wav
--------------------------------------------------------------------------------
/assets/driving_audio/3.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/driving_audio/3.wav
--------------------------------------------------------------------------------
/assets/driving_audio/4.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/driving_audio/4.wav
--------------------------------------------------------------------------------
/assets/driving_audio/5.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/driving_audio/5.wav
--------------------------------------------------------------------------------
/assets/driving_audio/6.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/driving_audio/6.wav
--------------------------------------------------------------------------------
/assets/driving_video/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/driving_video/.DS_Store
--------------------------------------------------------------------------------
/assets/driving_video/1.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/driving_video/1.mp4
--------------------------------------------------------------------------------
/assets/driving_video/2.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/driving_video/2.mp4
--------------------------------------------------------------------------------
/assets/driving_video/3.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/driving_video/3.mp4
--------------------------------------------------------------------------------
/assets/driving_video/4.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/driving_video/4.mp4
--------------------------------------------------------------------------------
/assets/driving_video/5.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/driving_video/5.mp4
--------------------------------------------------------------------------------
/assets/driving_video/6.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/driving_video/6.mp4
--------------------------------------------------------------------------------
/assets/driving_video/7.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/driving_video/7.mp4
--------------------------------------------------------------------------------
/assets/driving_video/8.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/driving_video/8.mp4
--------------------------------------------------------------------------------
/assets/gradio.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/gradio.png
--------------------------------------------------------------------------------
/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/logo.png
--------------------------------------------------------------------------------
/assets/ref_images/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/1.png
--------------------------------------------------------------------------------
/assets/ref_images/10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/10.png
--------------------------------------------------------------------------------
/assets/ref_images/11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/11.png
--------------------------------------------------------------------------------
/assets/ref_images/12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/12.png
--------------------------------------------------------------------------------
/assets/ref_images/13.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/13.png
--------------------------------------------------------------------------------
/assets/ref_images/14.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/14.png
--------------------------------------------------------------------------------
/assets/ref_images/15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/15.png
--------------------------------------------------------------------------------
/assets/ref_images/16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/16.png
--------------------------------------------------------------------------------
/assets/ref_images/17.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/17.png
--------------------------------------------------------------------------------
/assets/ref_images/18.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/18.png
--------------------------------------------------------------------------------
/assets/ref_images/19.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/19.png
--------------------------------------------------------------------------------
/assets/ref_images/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/2.png
--------------------------------------------------------------------------------
/assets/ref_images/20.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/20.png
--------------------------------------------------------------------------------
/assets/ref_images/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/3.png
--------------------------------------------------------------------------------
/assets/ref_images/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/4.png
--------------------------------------------------------------------------------
/assets/ref_images/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/5.png
--------------------------------------------------------------------------------
/assets/ref_images/6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/6.png
--------------------------------------------------------------------------------
/assets/ref_images/7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/7.png
--------------------------------------------------------------------------------
/assets/ref_images/8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/8.png
--------------------------------------------------------------------------------
/diffposetalk/common.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class PositionalEncoding(nn.Module):
9 | def __init__(self, d_model, dropout=0.1, max_len=600):
10 | super().__init__()
11 | self.dropout = nn.Dropout(p=dropout)
12 | # vanilla sinusoidal encoding
13 | pe = torch.zeros(max_len, d_model)
14 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
15 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
16 | pe[:, 0::2] = torch.sin(position * div_term)
17 | pe[:, 1::2] = torch.cos(position * div_term)
18 | pe = pe.unsqueeze(0)
19 | self.register_buffer('pe', pe)
20 |
21 | def forward(self, x):
22 |         x = x + self.pe[:, :x.shape[1], :]
23 | return self.dropout(x)
24 |
25 |
26 | def enc_dec_mask(T, S, frame_width=2, expansion=0, device='cuda'):
27 | mask = torch.ones(T, S)
28 | for i in range(T):
29 | mask[i, max(0, (i - expansion) * frame_width):(i + expansion + 1) * frame_width] = 0
30 | return (mask == 1).to(device=device)
31 |
32 |
33 | def pad_audio(audio, audio_unit=320, pad_threshold=80):
34 | batch_size, audio_len = audio.shape
35 | n_units = audio_len // audio_unit
36 | side_len = math.ceil((audio_unit * n_units + pad_threshold - audio_len) / 2)
37 | if side_len >= 0:
38 | reflect_len = side_len // 2
39 | replicate_len = side_len % 2
40 | if reflect_len > 0:
41 | audio = F.pad(audio, (reflect_len, reflect_len), mode='reflect')
42 | audio = F.pad(audio, (reflect_len, reflect_len), mode='reflect')
43 | if replicate_len > 0:
44 | audio = F.pad(audio, (1, 1), mode='replicate')
45 |
46 | return audio
47 |
--------------------------------------------------------------------------------
/diffposetalk/diffposetalk.py:
--------------------------------------------------------------------------------
1 | import math
2 | import tempfile
3 | import warnings
4 | from pathlib import Path
5 |
6 | import cv2
7 | import librosa
8 | import numpy as np
9 | import torch
10 | import torch.nn.functional as F
11 | from tqdm import tqdm
12 | from pydantic import BaseModel
13 |
14 | from .diff_talking_head import DiffTalkingHead
15 | from .utils import NullableArgs, coef_dict_to_vertices, get_coef_dict
16 | from .utils.media import combine_video_and_audio, convert_video, reencode_audio
17 |
18 | warnings.filterwarnings('ignore', message='PySoundFile failed. Trying audioread instead.')
19 |
20 | class DiffPoseTalkConfig(BaseModel):
21 | no_context_audio_feat: bool = False
22 | model_path: str = "pretrained_models/diffposetalk/iter_0110000.pt" # DPT/head-SA-hubert-WM
23 | coef_stats: str = "pretrained_models/diffposetalk/stats_train.npz"
24 | style_path: str = "pretrained_models/diffposetalk/style/L4H4-T0.1-BS32/iter_0034000/normal.npy"
25 | dynamic_threshold_ratio: float = 0.99
26 | dynamic_threshold_min: float = 1.0
27 | dynamic_threshold_max: float = 4.0
28 | scale_audio: float = 1.15
29 | scale_style: float = 3.0
30 |
31 | class DiffPoseTalk:
32 | def __init__(self, config: DiffPoseTalkConfig = DiffPoseTalkConfig(), device="cuda"):
33 | self.cfg = config
34 | self.device = device
35 |
36 | self.no_context_audio_feat = self.cfg.no_context_audio_feat
37 | model_data = torch.load(self.cfg.model_path, map_location=self.device)
38 |
39 | self.model_args = NullableArgs(model_data['args'])
40 | self.model = DiffTalkingHead(self.model_args, self.device)
41 | model_data['model'].pop('denoising_net.TE.pe')
42 | self.model.load_state_dict(model_data['model'], strict=False)
43 | self.model.to(self.device)
44 | self.model.eval()
45 |
46 | self.use_indicator = self.model_args.use_indicator
47 | self.rot_repr = self.model_args.rot_repr
48 | self.predict_head_pose = not self.model_args.no_head_pose
49 | if self.model.use_style:
50 | style_dir = Path(self.model_args.style_enc_ckpt)
51 | style_dir = Path(*style_dir.with_suffix('').parts[-3::2])
52 | self.style_dir = style_dir
53 |
54 | # sequence
55 | self.n_motions = self.model_args.n_motions
56 | self.n_prev_motions = self.model_args.n_prev_motions
57 | self.fps = self.model_args.fps
58 | self.audio_unit = 16000. / self.fps # num of samples per frame
59 | self.n_audio_samples = round(self.audio_unit * self.n_motions)
60 | self.pad_mode = self.model_args.pad_mode
61 |
62 | self.coef_stats = dict(np.load(self.cfg.coef_stats))
63 | self.coef_stats = {k: torch.from_numpy(v).to(self.device) for k, v in self.coef_stats.items()}
64 |
65 | if self.cfg.dynamic_threshold_ratio > 0:
66 | self.dynamic_threshold = (self.cfg.dynamic_threshold_ratio, self.cfg.dynamic_threshold_min,
67 | self.cfg.dynamic_threshold_max)
68 | else:
69 | self.dynamic_threshold = None
70 |
71 |
72 | def infer_from_file(self, audio_path, shape_coef):
73 | n_repetitions = 1
74 | cfg_mode = None
75 | cfg_cond = self.model.guiding_conditions
76 | cfg_scale = []
77 | for cond in cfg_cond:
78 | if cond == 'audio':
79 | cfg_scale.append(self.cfg.scale_audio)
80 | elif cond == 'style':
81 | cfg_scale.append(self.cfg.scale_style)
82 |
83 | coef_dict = self.infer_coeffs(audio_path, shape_coef, self.cfg.style_path, n_repetitions,
84 | cfg_mode, cfg_cond, cfg_scale, include_shape=True)
85 | return coef_dict
86 |
87 | @torch.no_grad()
88 | def infer_coeffs(self, audio, shape_coef, style_feat=None, n_repetitions=1,
89 | cfg_mode=None, cfg_cond=None, cfg_scale=1.15, include_shape=False):
90 | # Returns dict[str, (n_repetitions, L, *)]
91 | # Step 1: Preprocessing
92 | # Preprocess audio
93 | if isinstance(audio, (str, Path)):
94 | audio, _ = librosa.load(audio, sr=16000, mono=True)
95 | if isinstance(audio, np.ndarray):
96 | audio = torch.from_numpy(audio).to(self.device)
97 | assert audio.ndim == 1, 'Audio must be 1D tensor.'
98 | audio_mean, audio_std = torch.mean(audio), torch.std(audio)
99 | audio = (audio - audio_mean) / (audio_std + 1e-5)
100 |
101 | # Preprocess shape coefficient
102 | if isinstance(shape_coef, (str, Path)):
103 | shape_coef = np.load(shape_coef)
104 | if not isinstance(shape_coef, np.ndarray):
105 | shape_coef = shape_coef['shape']
106 | if isinstance(shape_coef, np.ndarray):
107 | shape_coef = torch.from_numpy(shape_coef).float().to(self.device)
108 | assert shape_coef.ndim <= 2, 'Shape coefficient must be 1D or 2D tensor.'
109 | if shape_coef.ndim > 1:
110 | # use the first frame as the shape coefficient
111 | shape_coef = shape_coef[0]
112 | original_shape_coef = shape_coef.clone()
113 | if self.coef_stats is not None:
114 | shape_coef = (shape_coef - self.coef_stats['shape_mean']) / self.coef_stats['shape_std']
115 | shape_coef = shape_coef.unsqueeze(0).expand(n_repetitions, -1)
116 |
117 | # Preprocess style feature if given
118 | if style_feat is not None:
119 | assert self.model.use_style
120 | if isinstance(style_feat, (str, Path)):
121 | style_feat = Path(style_feat)
122 | if not style_feat.exists() and not style_feat.is_absolute():
123 | style_feat = style_feat.parent / self.style_dir / style_feat.name
124 | style_feat = np.load(style_feat)
125 | if not isinstance(style_feat, np.ndarray):
126 | style_feat = style_feat['style']
127 | if isinstance(style_feat, np.ndarray):
128 | style_feat = torch.from_numpy(style_feat).float().to(self.device)
129 | assert style_feat.ndim == 1, 'Style feature must be 1D tensor.'
130 | style_feat = style_feat.unsqueeze(0).expand(n_repetitions, -1)
131 |
132 | # Step 2: Predict motion coef
133 | # divide into synthesize units and do synthesize
134 | clip_len = int(len(audio) / 16000 * self.fps)
135 | stride = self.n_motions
136 | if clip_len <= self.n_motions:
137 | n_subdivision = 1
138 | else:
139 | n_subdivision = math.ceil(clip_len / stride)
140 |
141 | # Prepare audio input
142 | n_padding_audio_samples = self.n_audio_samples * n_subdivision - len(audio)
143 | n_padding_frames = math.ceil(n_padding_audio_samples / self.audio_unit)
144 | if n_padding_audio_samples > 0:
145 | if self.pad_mode == 'zero':
146 | padding_value = 0
147 | elif self.pad_mode == 'replicate':
148 | padding_value = audio[-1]
149 | else:
150 | raise ValueError(f'Unknown pad mode: {self.pad_mode}')
151 | audio = F.pad(audio, (0, n_padding_audio_samples), value=padding_value)
152 |
153 | if not self.no_context_audio_feat:
154 | audio_feat = self.model.extract_audio_feature(audio.unsqueeze(0), self.n_motions * n_subdivision)
155 |
156 | # Generate `self.n_motions` new frames at one time, and use the last `self.n_prev_motions` frames
157 | # from the previous generation as the initial motion condition
158 | coef_list = []
159 | for i in range(0, n_subdivision):
160 | start_idx = i * stride
161 | end_idx = start_idx + self.n_motions
162 | indicator = torch.ones((n_repetitions, self.n_motions)).to(self.device) if self.use_indicator else None
163 | if indicator is not None and i == n_subdivision - 1 and n_padding_frames > 0:
164 | indicator[:, -n_padding_frames:] = 0
165 | if not self.no_context_audio_feat:
166 | audio_in = audio_feat[:, start_idx:end_idx].expand(n_repetitions, -1, -1)
167 | else:
168 | audio_in = audio[round(start_idx * self.audio_unit):round(end_idx * self.audio_unit)].unsqueeze(0)
169 |
170 | # generate motion coefficients
171 | if i == 0:
172 | # -> (N, L, d_motion=n_code_per_frame * code_dim)
173 | motion_feat, noise, prev_audio_feat = self.model.sample(audio_in, shape_coef, style_feat,
174 | indicator=indicator, cfg_mode=cfg_mode,
175 | cfg_cond=cfg_cond, cfg_scale=cfg_scale,
176 | dynamic_threshold=self.dynamic_threshold)
177 | else:
178 | motion_feat, noise, prev_audio_feat = self.model.sample(audio_in, shape_coef, style_feat,
179 | prev_motion_feat, prev_audio_feat, noise,
180 | indicator=indicator, cfg_mode=cfg_mode,
181 | cfg_cond=cfg_cond, cfg_scale=cfg_scale,
182 | dynamic_threshold=self.dynamic_threshold)
183 | prev_motion_feat = motion_feat[:, -self.n_prev_motions:].clone()
184 | prev_audio_feat = prev_audio_feat[:, -self.n_prev_motions:]
185 |
186 | motion_coef = motion_feat
187 | if i == n_subdivision - 1 and n_padding_frames > 0:
188 | motion_coef = motion_coef[:, :-n_padding_frames] # delete padded frames
189 | coef_list.append(motion_coef)
190 |
191 | motion_coef = torch.cat(coef_list, dim=1)
192 |
193 | # Step 3: restore to coef dict
194 | coef_dict = get_coef_dict(motion_coef, None, self.coef_stats, self.predict_head_pose, self.rot_repr)
195 | if include_shape:
196 | coef_dict['shape'] = original_shape_coef[None, None].expand(n_repetitions, motion_coef.shape[1], -1)
197 | return self.coef_to_a1_format(coef_dict)
198 |
199 | def coef_to_a1_format(self, coef_dict):
200 | n_frames = coef_dict['exp'].shape[1]
201 | new_coef_dict = []
202 | for i in range(n_frames):
203 |
204 | new_coef_dict.append({
205 | "expression_params": coef_dict["exp"][0, i:i+1],
206 | "jaw_params": coef_dict["pose"][0, i:i+1, 3:],
207 | "eye_pose_params": torch.zeros(1, 6).type_as(coef_dict["pose"]),
208 | "pose_params": coef_dict["pose"][0, i:i+1, :3],
209 | "eyelid_params": None
210 | })
211 | return new_coef_dict
212 |
213 |
214 |
215 |
216 |
217 | @staticmethod
218 | def _pad_coef(coef, n_frames, elem_ndim=1):
219 | if coef.ndim == elem_ndim:
220 | coef = coef[None]
221 | elem_shape = coef.shape[1:]
222 | if coef.shape[0] >= n_frames:
223 | new_coef = coef[:n_frames]
224 | else:
225 | # repeat the last coef frame
226 | new_coef = torch.cat([coef, coef[[-1]].expand(n_frames - coef.shape[0], *elem_shape)], dim=0)
227 | return new_coef # (n_frames, *elem_shape)
228 |
229 |
--------------------------------------------------------------------------------
/diffposetalk/hubert.py:
--------------------------------------------------------------------------------
1 | from transformers import HubertModel
2 | from transformers.modeling_outputs import BaseModelOutput
3 |
4 | from .wav2vec2 import linear_interpolation
5 |
6 | _CONFIG_FOR_DOC = 'HubertConfig'
7 |
8 |
9 | class HubertModel(HubertModel):
10 | def __init__(self, config):
11 | super().__init__(config)
12 |
13 | def forward(self, input_values, output_fps=25, attention_mask=None, output_attentions=None,
14 | output_hidden_states=None, return_dict=None, frame_num=None):
15 | self.config.output_attentions = True
16 |
17 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
18 | output_hidden_states = (
19 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
20 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
21 |
22 | extract_features = self.feature_extractor(input_values) # (N, C, L)
23 | # Resample the audio feature @ 50 fps to `output_fps`.
24 | if frame_num is not None:
25 | extract_features_len = round(frame_num * 50 / output_fps)
26 | extract_features = extract_features[:, :, :extract_features_len]
27 | extract_features = linear_interpolation(extract_features, 50, output_fps, output_len=frame_num)
28 | extract_features = extract_features.transpose(1, 2) # (N, L, C)
29 |
30 | if attention_mask is not None:
31 | # compute reduced attention_mask corresponding to feature vectors
32 | attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
33 |
34 | hidden_states = self.feature_projection(extract_features)
35 | hidden_states = self._mask_hidden_states(hidden_states)
36 |
37 | encoder_outputs = self.encoder(
38 | hidden_states,
39 | attention_mask=attention_mask,
40 | output_attentions=output_attentions,
41 | output_hidden_states=output_hidden_states,
42 | return_dict=return_dict,
43 | )
44 |
45 | hidden_states = encoder_outputs[0]
46 |
47 | if not return_dict:
48 | return (hidden_states,) + encoder_outputs[1:]
49 |
50 | return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_outputs.hidden_states,
51 | attentions=encoder_outputs.attentions, )
52 |
--------------------------------------------------------------------------------
/diffposetalk/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .common import *
2 |
--------------------------------------------------------------------------------
/diffposetalk/utils/media.py:
--------------------------------------------------------------------------------
1 | import shlex
2 | import subprocess
3 | from pathlib import Path
4 |
5 |
6 | def combine_video_and_audio(video_file, audio_file, output, quality=17, copy_audio=True):
7 | audio_codec = '-c:a copy' if copy_audio else ''
8 | cmd = f'ffmpeg -i {video_file} -i {audio_file} -c:v libx264 -crf {quality} -pix_fmt yuv420p ' \
9 | f'{audio_codec} -fflags +shortest -y -hide_banner -loglevel error {output}'
10 | assert subprocess.run(shlex.split(cmd)).returncode == 0
11 |
12 |
13 | def combine_frames_and_audio(frame_files, audio_file, fps, output, quality=17):
14 | cmd = f'ffmpeg -framerate {fps} -i {frame_files} -i {audio_file} -c:v libx264 -crf {quality} -pix_fmt yuv420p ' \
15 | f'-c:a copy -fflags +shortest -y -hide_banner -loglevel error {output}'
16 | assert subprocess.run(shlex.split(cmd)).returncode == 0
17 |
18 |
19 | def convert_video(video_file, output, quality=17):
20 | cmd = f'ffmpeg -i {video_file} -c:v libx264 -crf {quality} -pix_fmt yuv420p ' \
21 | f'-fflags +shortest -y -hide_banner -loglevel error {output}'
22 | assert subprocess.run(shlex.split(cmd)).returncode == 0
23 |
24 |
25 | def reencode_audio(audio_file, output):
26 | cmd = f'ffmpeg -i {audio_file} -y -hide_banner -loglevel error {output}'
27 | assert subprocess.run(shlex.split(cmd)).returncode == 0
28 |
29 |
30 | def extract_frames(filename, output_dir, quality=1):
31 | output_dir = Path(output_dir)
32 | output_dir.mkdir(parents=True, exist_ok=True)
33 | cmd = f'ffmpeg -i {filename} -qmin 1 -qscale:v {quality} -y -start_number 0 -hide_banner -loglevel error ' \
34 | f'{output_dir / "%06d.jpg"}'
35 | assert subprocess.run(shlex.split(cmd)).returncode == 0
36 |
--------------------------------------------------------------------------------
/diffposetalk/utils/renderer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 |
4 | import cv2
5 | import kiui.mesh
6 | import numpy as np
7 |
8 | # os.environ['PYOPENGL_PLATFORM'] = 'osmesa' # osmesa or egl
9 | os.environ['PYOPENGL_PLATFORM'] = 'egl'
10 | import pyrender
11 | import trimesh
12 | # from psbody.mesh import Mesh
13 |
14 |
15 | class MeshRenderer:
16 | def __init__(self, size, fov=16 / 180 * np.pi, camera_pose=None, light_pose=None, black_bg=False):
17 | # Camera
18 | self.frustum = {'near': 0.01, 'far': 3.0}
19 | self.camera = pyrender.PerspectiveCamera(yfov=fov, znear=self.frustum['near'],
20 | zfar=self.frustum['far'], aspectRatio=1.0)
21 |
22 | # Material
23 | self.primitive_material = pyrender.material.MetallicRoughnessMaterial(
24 | alphaMode='BLEND',
25 | baseColorFactor=[0.3, 0.3, 0.3, 1.0],
26 | metallicFactor=0.8,
27 | roughnessFactor=0.8
28 | )
29 |
30 | # Lighting
31 | light_color = np.array([1., 1., 1.])
32 | self.light = pyrender.DirectionalLight(color=light_color, intensity=2)
33 | self.light_angle = np.pi / 6.0
34 |
35 | # Scene
36 | self.scene = None
37 | self._init_scene(black_bg)
38 |
39 | # add camera and lighting
40 | self._init_camera(camera_pose)
41 | self._init_lighting(light_pose)
42 |
43 | # Renderer
44 | self.renderer = pyrender.OffscreenRenderer(*size, point_size=1.0)
45 |
46 | def _init_scene(self, black_bg=False):
47 | if black_bg:
48 | self.scene = pyrender.Scene(ambient_light=[.2, .2, .2], bg_color=[0, 0, 0])
49 | else:
50 | self.scene = pyrender.Scene(ambient_light=[.2, .2, .2], bg_color=[255, 255, 255])
51 |
52 | def _init_camera(self, camera_pose=None):
53 | if camera_pose is None:
54 | camera_pose = np.eye(4)
55 | camera_pose[:3, 3] = np.array([0, 0, 1])
56 | self.camera_pose = camera_pose.copy()
57 | self.camera_node = self.scene.add(self.camera, pose=camera_pose)
58 |
59 | def _init_lighting(self, light_pose=None):
60 | if light_pose is None:
61 | light_pose = np.eye(4)
62 | light_pose[:3, 3] = np.array([0, 0, 1])
63 | self.light_pose = light_pose.copy()
64 |
65 | light_poses = self._get_light_poses(self.light_angle, light_pose)
66 | self.light_nodes = [self.scene.add(self.light, pose=light_pose) for light_pose in light_poses]
67 |
68 | def set_camera_pose(self, camera_pose):
69 | self.camera_pose = camera_pose.copy()
70 | self.scene.set_pose(self.camera_node, pose=camera_pose)
71 |
72 | def set_lighting_pose(self, light_pose):
73 | self.light_pose = light_pose.copy()
74 |
75 | light_poses = self._get_light_poses(self.light_angle, light_pose)
76 | for light_node, light_pose in zip(self.light_nodes, light_poses):
77 | self.scene.set_pose(light_node, pose=light_pose)
78 |
79 | def render_mesh(self, v, f, t_center, rot=np.zeros(3), tex_img=None, tex_uv=None,
80 | camera_pose=None, light_pose=None):
81 | # Prepare mesh
82 | v[:] = cv2.Rodrigues(rot)[0].dot((v - t_center).T).T + t_center
83 | if tex_img is not None:
84 | tex = pyrender.Texture(source=tex_img, source_channels='RGB')
85 | tex_material = pyrender.material.MetallicRoughnessMaterial(baseColorTexture=tex)
86 | from kiui.mesh import Mesh
87 | import torch
88 | mesh = Mesh(
89 | v=torch.from_numpy(v),
90 | f=torch.from_numpy(f),
91 | vt=tex_uv['vt'],
92 | ft=tex_uv['ft']
93 | )
94 | with tempfile.NamedTemporaryFile(suffix='.obj') as f:
95 | mesh.write_obj(f.name)
96 | tri_mesh = trimesh.load(f.name, process=False)
97 |             # fall through: build and render the textured pyrender mesh below
98 | # tri_mesh = self._pyrender_mesh_workaround(mesh)
99 | render_mesh = pyrender.Mesh.from_trimesh(tri_mesh, material=tex_material)
100 | else:
101 | tri_mesh = trimesh.Trimesh(vertices=v, faces=f)
102 | render_mesh = pyrender.Mesh.from_trimesh(tri_mesh, material=self.primitive_material, smooth=True)
103 | mesh_node = self.scene.add(render_mesh, pose=np.eye(4))
104 |
105 | # Change camera and lighting pose if necessary
106 | if camera_pose is not None:
107 | self.set_camera_pose(camera_pose)
108 | if light_pose is not None:
109 | self.set_lighting_pose(light_pose)
110 |
111 | # Render
112 | flags = pyrender.RenderFlags.SKIP_CULL_FACES
113 | color, depth = self.renderer.render(self.scene, flags=flags)
114 |
115 | # Remove mesh
116 | self.scene.remove_node(mesh_node)
117 |
118 | return color, depth
119 |
120 | @staticmethod
121 | def _get_light_poses(light_angle, light_pose):
122 | light_poses = []
123 | init_pos = light_pose[:3, 3].copy()
124 |
125 | light_poses.append(light_pose.copy())
126 |
127 | light_pose[:3, 3] = cv2.Rodrigues(np.array([light_angle, 0, 0]))[0].dot(init_pos)
128 | light_poses.append(light_pose.copy())
129 |
130 | light_pose[:3, 3] = cv2.Rodrigues(np.array([-light_angle, 0, 0]))[0].dot(init_pos)
131 | light_poses.append(light_pose.copy())
132 |
133 | light_pose[:3, 3] = cv2.Rodrigues(np.array([0, -light_angle, 0]))[0].dot(init_pos)
134 | light_poses.append(light_pose.copy())
135 |
136 | light_pose[:3, 3] = cv2.Rodrigues(np.array([0, light_angle, 0]))[0].dot(init_pos)
137 | light_poses.append(light_pose.copy())
138 |
139 | return light_poses
140 |
141 | @staticmethod
142 | def _pyrender_mesh_workaround(mesh):
143 | # Workaround as pyrender requires number of vertices and uv coordinates to be the same
144 | with tempfile.NamedTemporaryFile(suffix='.obj') as f:
145 | mesh.write_obj(f.name)
146 | tri_mesh = trimesh.load(f.name, process=False)
147 | return tri_mesh
148 |
--------------------------------------------------------------------------------
/diffposetalk/wav2vec2.py:
--------------------------------------------------------------------------------
1 | from packaging import version
2 | from typing import Optional, Tuple
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | import transformers
8 | from transformers import Wav2Vec2Model
9 | from transformers.modeling_outputs import BaseModelOutput
10 |
11 | _CONFIG_FOR_DOC = 'Wav2Vec2Config'
12 |
13 |
14 | # the implementation of Wav2Vec2Model is borrowed from
15 | # https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model
16 | # initialize our encoder with the pre-trained wav2vec 2.0 weights.
17 | def _compute_mask_indices(shape: Tuple[int, int], mask_prob: float, mask_length: int,
18 | attention_mask: Optional[torch.Tensor] = None, min_masks: int = 0, ) -> np.ndarray:
19 | bsz, all_sz = shape
20 | mask = np.full((bsz, all_sz), False)
21 |
22 | all_num_mask = int(mask_prob * all_sz / float(mask_length) + np.random.rand())
23 | all_num_mask = max(min_masks, all_num_mask)
24 | mask_idcs = []
25 | padding_mask = attention_mask.ne(1) if attention_mask is not None else None
26 | for i in range(bsz):
27 | if padding_mask is not None:
28 | sz = all_sz - padding_mask[i].long().sum().item()
29 | num_mask = int(mask_prob * sz / float(mask_length) + np.random.rand())
30 | num_mask = max(min_masks, num_mask)
31 | else:
32 | sz = all_sz
33 | num_mask = all_num_mask
34 |
35 | lengths = np.full(num_mask, mask_length)
36 |
37 | if sum(lengths) == 0:
38 | lengths[0] = min(mask_length, sz - 1)
39 |
40 | min_len = min(lengths)
41 | if sz - min_len <= num_mask:
42 | min_len = sz - num_mask - 1
43 |
44 | mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
45 | mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
46 | mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
47 |
48 | min_len = min([len(m) for m in mask_idcs])
49 | for i, mask_idc in enumerate(mask_idcs):
50 | if len(mask_idc) > min_len:
51 | mask_idc = np.random.choice(mask_idc, min_len, replace=False)
52 | mask[i, mask_idc] = True
53 | return mask
54 |
55 |
56 | # linear interpolation layer
57 | def linear_interpolation(features, input_fps, output_fps, output_len=None):
58 | # features: (N, C, L)
59 | seq_len = features.shape[2] / float(input_fps)
60 | if output_len is None:
61 | output_len = int(seq_len * output_fps)
62 | output_features = F.interpolate(features, size=output_len, align_corners=False, mode='linear')
63 | return output_features
64 |
65 |
66 | class Wav2Vec2Model(Wav2Vec2Model):
67 | def __init__(self, config):
68 | super().__init__(config)
69 | self.is_old_version = version.parse(transformers.__version__) < version.parse('4.7.0')
70 |
71 | def forward(self, input_values, output_fps=25, attention_mask=None, output_attentions=None,
72 | output_hidden_states=None, return_dict=None, frame_num=None):
73 | self.config.output_attentions = True
74 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
75 | output_hidden_states = (
76 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
77 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
78 |
79 | hidden_states = self.feature_extractor(input_values) # (N, C, L)
80 | # Resample the audio feature @ 50 fps to `output_fps`.
81 | if frame_num is not None:
82 | hidden_states_len = round(frame_num * 50 / output_fps)
83 | hidden_states = hidden_states[:, :, :hidden_states_len]
84 | hidden_states = linear_interpolation(hidden_states, 50, output_fps, output_len=frame_num)
85 | hidden_states = hidden_states.transpose(1, 2) # (N, L, C)
86 |
87 | if attention_mask is not None:
88 | output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
89 | attention_mask = torch.zeros(hidden_states.shape[:2], dtype=hidden_states.dtype,
90 | device=hidden_states.device)
91 | attention_mask[(torch.arange(attention_mask.shape[0], device=hidden_states.device), output_lengths - 1)] = 1
92 | attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
93 |
94 | if self.is_old_version:
95 | hidden_states = self.feature_projection(hidden_states)
96 | else:
97 | hidden_states = self.feature_projection(hidden_states)[0]
98 |
99 | if self.config.apply_spec_augment and self.training:
100 | batch_size, sequence_length, hidden_size = hidden_states.size()
101 | if self.config.mask_time_prob > 0:
102 | mask_time_indices = _compute_mask_indices((batch_size, sequence_length), self.config.mask_time_prob,
103 | self.config.mask_time_length, attention_mask=attention_mask,
104 | min_masks=2, )
105 | hidden_states[torch.from_numpy(mask_time_indices)] = self.masked_spec_embed.to(hidden_states.dtype)
106 | if self.config.mask_feature_prob > 0:
107 | mask_feature_indices = _compute_mask_indices((batch_size, hidden_size), self.config.mask_feature_prob,
108 | self.config.mask_feature_length, )
109 | mask_feature_indices = torch.from_numpy(mask_feature_indices).to(hidden_states.device)
110 | hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
111 | encoder_outputs = self.encoder(hidden_states, attention_mask=attention_mask,
112 | output_attentions=output_attentions, output_hidden_states=output_hidden_states,
113 | return_dict=return_dict, )
114 | hidden_states = encoder_outputs[0]
115 | if not return_dict:
116 | return (hidden_states,) + encoder_outputs[1:]
117 |
118 | return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_outputs.hidden_states,
119 | attentions=encoder_outputs.attentions, )
120 |
--------------------------------------------------------------------------------
/eval/arc_score.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from insightface.app import FaceAnalysis
4 | from insightface.utils import face_align
5 | from PIL import Image
6 | from torchvision import models, transforms
7 | from curricularface import get_model
8 | import cv2
9 | import numpy as np
10 | import numpy
11 |
12 |
13 | def matrix_sqrt(matrix):
14 | eigenvalues, eigenvectors = torch.linalg.eigh(matrix)
15 | sqrt_eigenvalues = torch.sqrt(torch.clamp(eigenvalues, min=0))
16 | sqrt_matrix = (eigenvectors * sqrt_eigenvalues).mm(eigenvectors.T)
17 | return sqrt_matrix
18 |
19 | def sample_video_frames(video_path, num_frames=16):
20 | cap = cv2.VideoCapture(video_path)
21 | total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
22 | frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
23 |
24 | frames = []
25 | for idx in frame_indices:
26 | cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
27 | ret, frame = cap.read()
28 | if ret:
29 | # print(frame.shape)
30 | #if frame.shape[1] > 1024:
31 | # frame = frame[:, 1440:, :]
32 | # print(frame.shape)
33 | frames.append(frame)
34 | cap.release()
35 | return frames
36 |
37 |
38 | def get_face_keypoints(face_model, image_bgr):
39 | face_info = face_model.get(image_bgr)
40 | if len(face_info) > 0:
41 | return sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]
42 | return None
43 |
44 | def load_image(image):
45 | img = image.convert('RGB')
46 | img = transforms.Resize((299, 299))(img) # Resize to Inception input size
47 | img = transforms.ToTensor()(img)
48 | return img.unsqueeze(0) # Add batch dimension
49 |
50 | def calculate_fid(real_activations, fake_activations, device="cuda"):
51 | real_activations_tensor = torch.tensor(real_activations).to(device)
52 | fake_activations_tensor = torch.tensor(fake_activations).to(device)
53 |
54 | mu1 = real_activations_tensor.mean(dim=0)
55 | sigma1 = torch.cov(real_activations_tensor.T)
56 | mu2 = fake_activations_tensor.mean(dim=0)
57 | sigma2 = torch.cov(fake_activations_tensor.T)
58 |
59 | ssdiff = torch.sum((mu1 - mu2) ** 2)
60 | covmean = matrix_sqrt(sigma1.mm(sigma2))
61 | if torch.is_complex(covmean):
62 | covmean = covmean.real
63 | fid = ssdiff + torch.trace(sigma1 + sigma2 - 2 * covmean)
64 | return fid.item()
65 |
66 | def batch_cosine_similarity(embedding_image, embedding_frames, device="cuda"):
67 | embedding_image = torch.tensor(embedding_image).to(device)
68 | embedding_frames = torch.tensor(embedding_frames).to(device)
69 | return torch.nn.functional.cosine_similarity(embedding_image, embedding_frames, dim=-1).cpu().numpy()
70 |
71 |
72 | def get_activations(images, model, batch_size=16):
73 | model.eval()
74 | activations = []
75 | with torch.no_grad():
76 | for i in range(0, len(images), batch_size):
77 | batch = images[i:i + batch_size]
78 | pred = model(batch)
79 | activations.append(pred)
80 | activations = torch.cat(activations, dim=0).cpu().numpy()
81 | if activations.shape[0] == 1:  # duplicate a lone sample so the covariance in calculate_fid is defined
82 | activations = np.repeat(activations, 2, axis=0)
83 | return activations
84 |
85 | def pad_np_bgr_image(np_image, scale=1.25):
86 | assert scale >= 1.0, "scale should be >= 1.0"
87 | pad_scale = scale - 1.0
88 | h, w = np_image.shape[:2]
89 | top = bottom = int(h * pad_scale)
90 | left = right = int(w * pad_scale)
91 | return cv2.copyMakeBorder(np_image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(128, 128, 128)), (left, top)
92 |
93 |
94 | def process_image(face_model, image_path):
95 | if isinstance(image_path, str):
96 | np_faceid_image = np.array(Image.open(image_path).convert("RGB"))
97 | elif isinstance(image_path, numpy.ndarray):
98 | np_faceid_image = image_path
99 | else:
100 | raise TypeError("image_path should be a file path string or a numpy.ndarray")
101 |
102 | image_bgr = cv2.cvtColor(np_faceid_image, cv2.COLOR_RGB2BGR)
103 |
104 | face_info = get_face_keypoints(face_model, image_bgr)
105 | if face_info is None:
106 | padded_image, sub_coord = pad_np_bgr_image(image_bgr)
107 | face_info = get_face_keypoints(face_model, padded_image)
108 | if face_info is None:
109 | print("Warning: No face detected in the image. Continuing processing...")
110 | return None, None
111 | face_kps = face_info['kps']
112 | face_kps -= np.array(sub_coord)
113 | else:
114 | face_kps = face_info['kps']
115 | arcface_embedding = face_info['embedding']
116 | # print(face_kps)
117 | norm_face = face_align.norm_crop(image_bgr, landmark=face_kps, image_size=224)
118 | align_face = cv2.cvtColor(norm_face, cv2.COLOR_BGR2RGB)
119 |
120 | return align_face, arcface_embedding
121 |
122 | @torch.no_grad()
123 | def inference(face_model, img, device):
124 | img = cv2.resize(img, (112, 112))
125 | img = np.transpose(img, (2, 0, 1))
126 | img = torch.from_numpy(img).unsqueeze(0).float().to(device)
127 | img.div_(255).sub_(0.5).div_(0.5)
128 | embedding = face_model(img).detach().cpu().numpy()[0]
129 | return embedding / np.linalg.norm(embedding)
130 |
131 |
132 | def process_video(video_path, face_arc_model, face_cur_model, fid_model, arcface_image_embedding, cur_image_embedding, real_activations, device):
133 | video_frames = sample_video_frames(video_path, num_frames=16)
134 | #print(video_frames)
135 | # Initialize lists to store the scores
136 | cur_scores = []
137 | arc_scores = []
138 | fid_face = []
139 |
140 | for frame in video_frames:
141 | # Convert to RGB once at the beginning
142 | frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
143 |
144 | # Process the frame for ArcFace embeddings
145 | align_face_frame, arcface_frame_embedding = process_image(face_arc_model, frame_rgb)
146 |
147 | # Skip if alignment fails
148 | if align_face_frame is None:
149 | continue
150 |
151 | # Perform inference for current face model
152 | cur_embedding_frame = inference(face_cur_model, align_face_frame, device)
153 |
154 | # Compute cosine similarity for cur_score and arc_score in a compact manner
155 | cur_score = max(0.0, batch_cosine_similarity(cur_image_embedding, cur_embedding_frame, device=device).item())
156 | arc_score = max(0.0, batch_cosine_similarity(arcface_image_embedding, arcface_frame_embedding, device=device).item())
157 |
158 | # Process FID score
159 | align_face_frame_pil = Image.fromarray(align_face_frame)
160 | fake_image = load_image(align_face_frame_pil).to(device)
161 | fake_activations = get_activations(fake_image, fid_model)
162 | fid_score = calculate_fid(real_activations, fake_activations, device)
163 |
164 | # Collect scores
165 | fid_face.append(fid_score)
166 | cur_scores.append(cur_score)
167 | arc_scores.append(arc_score)
168 |
169 | # Aggregate results with default values for empty lists
170 | avg_cur_score = np.mean(cur_scores) if cur_scores else 0.0
171 | avg_arc_score = np.mean(arc_scores) if arc_scores else 0.0
172 | avg_fid_score = np.mean(fid_face) if fid_face else 0.0
173 |
174 | return avg_cur_score, avg_arc_score, avg_fid_score
175 |
176 |
177 |
178 | def main():
179 | device = "cuda"
180 | # data_path = "data/SkyActor"
181 | # data_path = "data/LivePotraits"
182 | # data_path = "data/Actor-One"
183 | data_path = "data/FollowYourEmoji"
184 | img_path = "/maindata/data/shared/public/rui.wang/act_review/ref_images"
185 | pre_tag = False
186 | mp4_list = os.listdir(data_path)
187 | print(mp4_list)
188 |
189 | img_list = []
190 | video_list = []
191 | for mp4 in mp4_list:
192 | if "mp4" not in mp4:
193 | continue
194 | if pre_tag:
195 | png_path = mp4.split('.')[0].split('-')[0] + ".png"
196 | else:
197 | if "-" in mp4:
198 | png_path = mp4.split('.')[0].split('-')[1] + ".png"
199 | else:
200 | png_path = mp4.split('.')[0].split('_')[1] + ".png"
201 | img_list.append(os.path.join(img_path, png_path))
202 | video_list.append(os.path.join(data_path, mp4))
203 | print(img_list)
204 | print(video_list[0])
205 |
206 | model_path = "eval"
207 | face_arc_path = os.path.join(model_path, "face_encoder")
208 | face_cur_path = os.path.join(face_arc_path, "glint360k_curricular_face_r101_backbone.bin")
209 |
210 | # Initialize FaceEncoder model for face detection and embedding extraction
211 | face_arc_model = FaceAnalysis(root=face_arc_path, providers=['CUDAExecutionProvider'])
212 | face_arc_model.prepare(ctx_id=0, det_size=(320, 320))
213 |
214 | # Load face recognition model
215 | face_cur_model = get_model('IR_101')([112, 112])
216 | face_cur_model.load_state_dict(torch.load(face_cur_path, map_location="cpu"))
217 | face_cur_model = face_cur_model.to(device)
218 | face_cur_model.eval()
219 |
220 | # Load InceptionV3 model for FID calculation
221 | fid_model = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)
222 | fid_model.fc = torch.nn.Identity() # Remove final classification layer
223 | fid_model.eval()
224 | fid_model = fid_model.to(device)
225 |
226 | # Process the single video and image pair
227 | # Extract embeddings and features from the image
228 | cur_list, arc_list, fid_list = [], [], []
229 | for i in range(len(img_list)):
230 | align_face_image, arcface_image_embedding = process_image(face_arc_model, img_list[i])
231 |
232 | cur_image_embedding = inference(face_cur_model, align_face_image, device)
233 | align_face_image_pil = Image.fromarray(align_face_image)
234 | real_image = load_image(align_face_image_pil).to(device)
235 | real_activations = get_activations(real_image, fid_model)
236 |
237 | # Process the video and calculate scores
238 | cur_score, arc_score, fid_score = process_video(
239 | video_list[i], face_arc_model, face_cur_model, fid_model,
240 | arcface_image_embedding, cur_image_embedding, real_activations, device
241 | )
242 | print(cur_score, arc_score, fid_score)
243 | cur_list.append(cur_score)
244 | arc_list.append(arc_score)
245 | fid_list.append(fid_score)
246 | # break
247 | print("cur", sum(cur_list)/ len(cur_list))
248 | print("arc", sum(arc_list)/ len(arc_list))
249 | print("fid", sum(fid_list)/ len(fid_list))
250 |
251 |
252 |
253 | if __name__ == "__main__": main()
254 |
--------------------------------------------------------------------------------
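Note: calculate_fid above computes the Fréchet distance between Gaussians fitted to the two activation sets, FID = ||mu1 - mu2||^2 + Tr(Sigma1 + Sigma2 - 2*(Sigma1*Sigma2)^(1/2)), with the square root taken by the eigh-based matrix_sqrt. A small self-contained check of that formula on random activations (illustrative only; the script itself uses InceptionV3 pooled features):

    import torch

    def matrix_sqrt(matrix):
        # eigh-based square root, mirroring matrix_sqrt in arc_score.py
        eigenvalues, eigenvectors = torch.linalg.eigh(matrix)
        return (eigenvectors * torch.sqrt(torch.clamp(eigenvalues, min=0))).mm(eigenvectors.T)

    real = torch.randn(64, 16)            # 64 "real" activations of dim 16
    fake = torch.randn(64, 16) + 0.5      # shifted "fake" activations
    mu1, mu2 = real.mean(0), fake.mean(0)
    sigma1, sigma2 = torch.cov(real.T), torch.cov(fake.T)
    fid = torch.sum((mu1 - mu2) ** 2) + torch.trace(sigma1 + sigma2 - 2 * matrix_sqrt(sigma1.mm(sigma2)))
    print(float(fid))
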
/eval/curricularface/__init__.py:
--------------------------------------------------------------------------------
1 | # The implementation is adapted from TFace, made publicly available under the Apache-2.0 license at
2 | # https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone
3 | from .model_irse import IR_18, IR_34, IR_50, IR_101, IR_152, IR_200, IR_SE_50, IR_SE_101, IR_SE_152, IR_SE_200
4 | from .model_resnet import ResNet_50, ResNet_101, ResNet_152
5 |
6 |
7 | _model_dict = {
8 | 'ResNet_50': ResNet_50,
9 | 'ResNet_101': ResNet_101,
10 | 'ResNet_152': ResNet_152,
11 | 'IR_18': IR_18,
12 | 'IR_34': IR_34,
13 | 'IR_50': IR_50,
14 | 'IR_101': IR_101,
15 | 'IR_152': IR_152,
16 | 'IR_200': IR_200,
17 | 'IR_SE_50': IR_SE_50,
18 | 'IR_SE_101': IR_SE_101,
19 | 'IR_SE_152': IR_SE_152,
20 | 'IR_SE_200': IR_SE_200
21 | }
22 |
23 |
24 | def get_model(key):
25 | """ Get different backbone network by key,
26 | support ResNet50, ResNet_101, ResNet_152
27 | IR_18, IR_34, IR_50, IR_101, IR_152, IR_200,
28 | IR_SE_50, IR_SE_101, IR_SE_152, IR_SE_200.
29 | """
30 | if key in _model_dict:
31 | return _model_dict[key]
32 | else:
33 | raise KeyError('unsupported model: {}'.format(key))
34 |
--------------------------------------------------------------------------------
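Note: get_model returns the backbone constructor, which is then called with the input size, exactly as arc_score.py does. A minimal usage sketch, assuming the eval directory is on PYTHONPATH and the CurricularFace checkpoint has been downloaded to the path the eval scripts expect:

    import torch
    from curricularface import get_model

    backbone = get_model('IR_101')([112, 112])       # factory returns the class; instantiate at 112x112
    state = torch.load('eval/face_encoder/glint360k_curricular_face_r101_backbone.bin', map_location='cpu')
    backbone.load_state_dict(state)
    backbone.eval()
    with torch.no_grad():
        emb = backbone(torch.randn(1, 3, 112, 112))  # one aligned 112x112 face crop
    print(emb.shape)                                 # torch.Size([1, 512])
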
/eval/curricularface/common.py:
--------------------------------------------------------------------------------
1 | # The implementation is adapted from TFace, made publicly available under the Apache-2.0 license at
2 | # https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone/common.py
3 | import torch.nn as nn
4 | from torch.nn import Conv2d, Module, ReLU, Sigmoid
5 |
6 |
7 | def initialize_weights(modules):
8 | """ Weight initilize, conv2d and linear is initialized with kaiming_normal
9 | """
10 | for m in modules:
11 | if isinstance(m, nn.Conv2d):
12 | nn.init.kaiming_normal_(
13 | m.weight, mode='fan_out', nonlinearity='relu')
14 | if m.bias is not None:
15 | m.bias.data.zero_()
16 | elif isinstance(m, nn.BatchNorm2d):
17 | m.weight.data.fill_(1)
18 | m.bias.data.zero_()
19 | elif isinstance(m, nn.Linear):
20 | nn.init.kaiming_normal_(
21 | m.weight, mode='fan_out', nonlinearity='relu')
22 | if m.bias is not None:
23 | m.bias.data.zero_()
24 |
25 |
26 | class Flatten(Module):
27 | """ Flat tensor
28 | """
29 |
30 | def forward(self, input):
31 | return input.view(input.size(0), -1)
32 |
33 |
34 | class SEModule(Module):
35 | """ SE block
36 | """
37 |
38 | def __init__(self, channels, reduction):
39 | super(SEModule, self).__init__()
40 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
41 | self.fc1 = Conv2d(
42 | channels,
43 | channels // reduction,
44 | kernel_size=1,
45 | padding=0,
46 | bias=False)
47 |
48 | nn.init.xavier_uniform_(self.fc1.weight.data)
49 |
50 | self.relu = ReLU(inplace=True)
51 | self.fc2 = Conv2d(
52 | channels // reduction,
53 | channels,
54 | kernel_size=1,
55 | padding=0,
56 | bias=False)
57 |
58 | self.sigmoid = Sigmoid()
59 |
60 | def forward(self, x):
61 | module_input = x
62 | x = self.avg_pool(x)
63 | x = self.fc1(x)
64 | x = self.relu(x)
65 | x = self.fc2(x)
66 | x = self.sigmoid(x)
67 |
68 | return module_input * x
69 |
--------------------------------------------------------------------------------
/eval/curricularface/model_irse.py:
--------------------------------------------------------------------------------
1 | # The implementation is adapted from TFace, made publicly available under the Apache-2.0 license at
2 | # https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone/model_irse.py
3 | from collections import namedtuple
4 |
5 | from torch.nn import BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear, MaxPool2d, Module, PReLU, Sequential
6 |
7 | from .common import Flatten, SEModule, initialize_weights
8 |
9 |
10 | class BasicBlockIR(Module):
11 | """ BasicBlock for IRNet
12 | """
13 |
14 | def __init__(self, in_channel, depth, stride):
15 | super(BasicBlockIR, self).__init__()
16 | if in_channel == depth:
17 | self.shortcut_layer = MaxPool2d(1, stride)
18 | else:
19 | self.shortcut_layer = Sequential(
20 | Conv2d(in_channel, depth, (1, 1), stride, bias=False),
21 | BatchNorm2d(depth))
22 | self.res_layer = Sequential(
23 | BatchNorm2d(in_channel),
24 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
25 | BatchNorm2d(depth), PReLU(depth),
26 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
27 | BatchNorm2d(depth))
28 |
29 | def forward(self, x):
30 | shortcut = self.shortcut_layer(x)
31 | res = self.res_layer(x)
32 |
33 | return res + shortcut
34 |
35 |
36 | class BottleneckIR(Module):
37 | """ BasicBlock with bottleneck for IRNet
38 | """
39 |
40 | def __init__(self, in_channel, depth, stride):
41 | super(BottleneckIR, self).__init__()
42 | reduction_channel = depth // 4
43 | if in_channel == depth:
44 | self.shortcut_layer = MaxPool2d(1, stride)
45 | else:
46 | self.shortcut_layer = Sequential(
47 | Conv2d(in_channel, depth, (1, 1), stride, bias=False),
48 | BatchNorm2d(depth))
49 | self.res_layer = Sequential(
50 | BatchNorm2d(in_channel),
51 | Conv2d(
52 | in_channel, reduction_channel, (1, 1), (1, 1), 0, bias=False),
53 | BatchNorm2d(reduction_channel), PReLU(reduction_channel),
54 | Conv2d(
55 | reduction_channel,
56 | reduction_channel, (3, 3), (1, 1),
57 | 1,
58 | bias=False), BatchNorm2d(reduction_channel),
59 | PReLU(reduction_channel),
60 | Conv2d(reduction_channel, depth, (1, 1), stride, 0, bias=False),
61 | BatchNorm2d(depth))
62 |
63 | def forward(self, x):
64 | shortcut = self.shortcut_layer(x)
65 | res = self.res_layer(x)
66 |
67 | return res + shortcut
68 |
69 |
70 | class BasicBlockIRSE(BasicBlockIR):
71 |
72 | def __init__(self, in_channel, depth, stride):
73 | super(BasicBlockIRSE, self).__init__(in_channel, depth, stride)
74 | self.res_layer.add_module('se_block', SEModule(depth, 16))
75 |
76 |
77 | class BottleneckIRSE(BottleneckIR):
78 |
79 | def __init__(self, in_channel, depth, stride):
80 | super(BottleneckIRSE, self).__init__(in_channel, depth, stride)
81 | self.res_layer.add_module('se_block', SEModule(depth, 16))
82 |
83 |
84 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
85 | '''A named tuple describing a ResNet block.'''
86 |
87 |
88 | def get_block(in_channel, depth, num_units, stride=2):
89 | return [Bottleneck(in_channel, depth, stride)] + \
90 | [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
91 |
92 |
93 | def get_blocks(num_layers):
94 | if num_layers == 18:
95 | blocks = [
96 | get_block(in_channel=64, depth=64, num_units=2),
97 | get_block(in_channel=64, depth=128, num_units=2),
98 | get_block(in_channel=128, depth=256, num_units=2),
99 | get_block(in_channel=256, depth=512, num_units=2)
100 | ]
101 | elif num_layers == 34:
102 | blocks = [
103 | get_block(in_channel=64, depth=64, num_units=3),
104 | get_block(in_channel=64, depth=128, num_units=4),
105 | get_block(in_channel=128, depth=256, num_units=6),
106 | get_block(in_channel=256, depth=512, num_units=3)
107 | ]
108 | elif num_layers == 50:
109 | blocks = [
110 | get_block(in_channel=64, depth=64, num_units=3),
111 | get_block(in_channel=64, depth=128, num_units=4),
112 | get_block(in_channel=128, depth=256, num_units=14),
113 | get_block(in_channel=256, depth=512, num_units=3)
114 | ]
115 | elif num_layers == 100:
116 | blocks = [
117 | get_block(in_channel=64, depth=64, num_units=3),
118 | get_block(in_channel=64, depth=128, num_units=13),
119 | get_block(in_channel=128, depth=256, num_units=30),
120 | get_block(in_channel=256, depth=512, num_units=3)
121 | ]
122 | elif num_layers == 152:
123 | blocks = [
124 | get_block(in_channel=64, depth=256, num_units=3),
125 | get_block(in_channel=256, depth=512, num_units=8),
126 | get_block(in_channel=512, depth=1024, num_units=36),
127 | get_block(in_channel=1024, depth=2048, num_units=3)
128 | ]
129 | elif num_layers == 200:
130 | blocks = [
131 | get_block(in_channel=64, depth=256, num_units=3),
132 | get_block(in_channel=256, depth=512, num_units=24),
133 | get_block(in_channel=512, depth=1024, num_units=36),
134 | get_block(in_channel=1024, depth=2048, num_units=3)
135 | ]
136 |
137 | return blocks
138 |
139 |
140 | class Backbone(Module):
141 |
142 | def __init__(self, input_size, num_layers, mode='ir'):
143 | """ Args:
144 | input_size: input_size of backbone
145 | num_layers: num_layers of backbone
146 | mode: support ir or irse
147 | """
148 | super(Backbone, self).__init__()
149 | assert input_size[0] in [112, 224], \
150 | 'input_size should be [112, 112] or [224, 224]'
151 | assert num_layers in [18, 34, 50, 100, 152, 200], \
152 | 'num_layers should be 18, 34, 50, 100, 152 or 200'
153 | assert mode in ['ir', 'ir_se'], \
154 | 'mode should be ir or ir_se'
155 | self.input_layer = Sequential(
156 | Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64),
157 | PReLU(64))
158 | blocks = get_blocks(num_layers)
159 | if num_layers <= 100:
160 | if mode == 'ir':
161 | unit_module = BasicBlockIR
162 | elif mode == 'ir_se':
163 | unit_module = BasicBlockIRSE
164 | output_channel = 512
165 | else:
166 | if mode == 'ir':
167 | unit_module = BottleneckIR
168 | elif mode == 'ir_se':
169 | unit_module = BottleneckIRSE
170 | output_channel = 2048
171 |
172 | if input_size[0] == 112:
173 | self.output_layer = Sequential(
174 | BatchNorm2d(output_channel), Dropout(0.4), Flatten(),
175 | Linear(output_channel * 7 * 7, 512),
176 | BatchNorm1d(512, affine=False))
177 | else:
178 | self.output_layer = Sequential(
179 | BatchNorm2d(output_channel), Dropout(0.4), Flatten(),
180 | Linear(output_channel * 14 * 14, 512),
181 | BatchNorm1d(512, affine=False))
182 |
183 | modules = []
184 | mid_layer_indices = [] # [2, 15, 45, 48], total 49 layers for IR101
185 | for block in blocks:
186 | if len(mid_layer_indices) == 0:
187 | mid_layer_indices.append(len(block) - 1)
188 | else:
189 | mid_layer_indices.append(len(block) + mid_layer_indices[-1])
190 | for bottleneck in block:
191 | modules.append(
192 | unit_module(bottleneck.in_channel, bottleneck.depth,
193 | bottleneck.stride))
194 | self.body = Sequential(*modules)
195 | self.mid_layer_indices = mid_layer_indices[-4:]
196 |
197 | # self.dtype = next(self.parameters()).dtype
198 | initialize_weights(self.modules())
199 |
200 | def device(self):
201 | return next(self.parameters()).device
202 |
203 | def dtype(self):
204 | return next(self.parameters()).dtype
205 |
206 | def forward(self, x, return_mid_feats=False):
207 | x = self.input_layer(x)
208 | if not return_mid_feats:
209 | x = self.body(x)
210 | x = self.output_layer(x)
211 | return x
212 | else:
213 | out_feats = []
214 | for idx, module in enumerate(self.body):
215 | x = module(x)
216 | if idx in self.mid_layer_indices:
217 | out_feats.append(x)
218 | x = self.output_layer(x)
219 | return x, out_feats
220 |
221 |
222 | def IR_18(input_size):
223 | """ Constructs a ir-18 model.
224 | """
225 | model = Backbone(input_size, 18, 'ir')
226 |
227 | return model
228 |
229 |
230 | def IR_34(input_size):
231 | """ Constructs a ir-34 model.
232 | """
233 | model = Backbone(input_size, 34, 'ir')
234 |
235 | return model
236 |
237 |
238 | def IR_50(input_size):
239 | """ Constructs a ir-50 model.
240 | """
241 | model = Backbone(input_size, 50, 'ir')
242 |
243 | return model
244 |
245 |
246 | def IR_101(input_size):
247 | """ Constructs a ir-101 model.
248 | """
249 | model = Backbone(input_size, 100, 'ir')
250 |
251 | return model
252 |
253 |
254 | def IR_152(input_size):
255 | """ Constructs a ir-152 model.
256 | """
257 | model = Backbone(input_size, 152, 'ir')
258 |
259 | return model
260 |
261 |
262 | def IR_200(input_size):
263 | """ Constructs a ir-200 model.
264 | """
265 | model = Backbone(input_size, 200, 'ir')
266 |
267 | return model
268 |
269 |
270 | def IR_SE_50(input_size):
271 | """ Constructs a ir_se-50 model.
272 | """
273 | model = Backbone(input_size, 50, 'ir_se')
274 |
275 | return model
276 |
277 |
278 | def IR_SE_101(input_size):
279 | """ Constructs a ir_se-101 model.
280 | """
281 | model = Backbone(input_size, 100, 'ir_se')
282 |
283 | return model
284 |
285 |
286 | def IR_SE_152(input_size):
287 | """ Constructs a ir_se-152 model.
288 | """
289 | model = Backbone(input_size, 152, 'ir_se')
290 |
291 | return model
292 |
293 |
294 | def IR_SE_200(input_size):
295 | """ Constructs a ir_se-200 model.
296 | """
297 | model = Backbone(input_size, 200, 'ir_se')
298 |
299 | return model
300 |
--------------------------------------------------------------------------------
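Note: besides the 512-d embedding, Backbone.forward can also return the four stage-level feature maps via return_mid_feats=True (unused by the eval scripts). A quick sketch with a randomly initialized IR_101, assuming the eval directory is on PYTHONPATH:

    import torch
    from curricularface.model_irse import IR_101

    model = IR_101([112, 112]).eval()
    with torch.no_grad():
        emb, feats = model(torch.randn(1, 3, 112, 112), return_mid_feats=True)
    print(emb.shape)                          # torch.Size([1, 512])
    print([tuple(f.shape) for f in feats])    # four intermediate feature maps, one per stage
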
/eval/curricularface/model_resnet.py:
--------------------------------------------------------------------------------
1 | # The implementation is adapted from TFace, made publicly available under the Apache-2.0 license at
2 | # https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone/model_resnet.py
3 | import torch.nn as nn
4 | from torch.nn import BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear, MaxPool2d, Module, ReLU, Sequential
5 |
6 | from .common import initialize_weights
7 |
8 |
9 | def conv3x3(in_planes, out_planes, stride=1):
10 | """ 3x3 convolution with padding
11 | """
12 | return Conv2d(
13 | in_planes,
14 | out_planes,
15 | kernel_size=3,
16 | stride=stride,
17 | padding=1,
18 | bias=False)
19 |
20 |
21 | def conv1x1(in_planes, out_planes, stride=1):
22 | """ 1x1 convolution
23 | """
24 | return Conv2d(
25 | in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
26 |
27 |
28 | class Bottleneck(Module):
29 | expansion = 4
30 |
31 | def __init__(self, inplanes, planes, stride=1, downsample=None):
32 | super(Bottleneck, self).__init__()
33 | self.conv1 = conv1x1(inplanes, planes)
34 | self.bn1 = BatchNorm2d(planes)
35 | self.conv2 = conv3x3(planes, planes, stride)
36 | self.bn2 = BatchNorm2d(planes)
37 | self.conv3 = conv1x1(planes, planes * self.expansion)
38 | self.bn3 = BatchNorm2d(planes * self.expansion)
39 | self.relu = ReLU(inplace=True)
40 | self.downsample = downsample
41 | self.stride = stride
42 |
43 | def forward(self, x):
44 | identity = x
45 |
46 | out = self.conv1(x)
47 | out = self.bn1(out)
48 | out = self.relu(out)
49 |
50 | out = self.conv2(out)
51 | out = self.bn2(out)
52 | out = self.relu(out)
53 |
54 | out = self.conv3(out)
55 | out = self.bn3(out)
56 |
57 | if self.downsample is not None:
58 | identity = self.downsample(x)
59 |
60 | out += identity
61 | out = self.relu(out)
62 |
63 | return out
64 |
65 |
66 | class ResNet(Module):
67 | """ ResNet backbone
68 | """
69 |
70 | def __init__(self, input_size, block, layers, zero_init_residual=True):
71 | """ Args:
72 | input_size: input_size of backbone
73 | block: block function
74 | layers: layers in each block
75 | """
76 | super(ResNet, self).__init__()
77 | assert input_size[0] in [112, 224], \
78 | 'input_size should be [112, 112] or [224, 224]'
79 | self.inplanes = 64
80 | self.conv1 = Conv2d(
81 | 3, 64, kernel_size=7, stride=2, padding=3, bias=False)
82 | self.bn1 = BatchNorm2d(64)
83 | self.relu = ReLU(inplace=True)
84 | self.maxpool = MaxPool2d(kernel_size=3, stride=2, padding=1)
85 | self.layer1 = self._make_layer(block, 64, layers[0])
86 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
87 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
88 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
89 |
90 | self.bn_o1 = BatchNorm2d(2048)
91 | self.dropout = Dropout()
92 | if input_size[0] == 112:
93 | self.fc = Linear(2048 * 4 * 4, 512)
94 | else:
95 | self.fc = Linear(2048 * 7 * 7, 512)
96 | self.bn_o2 = BatchNorm1d(512)
97 |
98 | initialize_weights(self.modules())
99 | if zero_init_residual:
100 | for m in self.modules():
101 | if isinstance(m, Bottleneck):
102 | nn.init.constant_(m.bn3.weight, 0)
103 |
104 | def _make_layer(self, block, planes, blocks, stride=1):
105 | downsample = None
106 | if stride != 1 or self.inplanes != planes * block.expansion:
107 | downsample = Sequential(
108 | conv1x1(self.inplanes, planes * block.expansion, stride),
109 | BatchNorm2d(planes * block.expansion),
110 | )
111 |
112 | layers = []
113 | layers.append(block(self.inplanes, planes, stride, downsample))
114 | self.inplanes = planes * block.expansion
115 | for _ in range(1, blocks):
116 | layers.append(block(self.inplanes, planes))
117 |
118 | return Sequential(*layers)
119 |
120 | def forward(self, x):
121 | x = self.conv1(x)
122 | x = self.bn1(x)
123 | x = self.relu(x)
124 | x = self.maxpool(x)
125 |
126 | x = self.layer1(x)
127 | x = self.layer2(x)
128 | x = self.layer3(x)
129 | x = self.layer4(x)
130 |
131 | x = self.bn_o1(x)
132 | x = self.dropout(x)
133 | x = x.view(x.size(0), -1)
134 | x = self.fc(x)
135 | x = self.bn_o2(x)
136 |
137 | return x
138 |
139 |
140 | def ResNet_50(input_size, **kwargs):
141 | """ Constructs a ResNet-50 model.
142 | """
143 | model = ResNet(input_size, Bottleneck, [3, 4, 6, 3], **kwargs)
144 |
145 | return model
146 |
147 |
148 | def ResNet_101(input_size, **kwargs):
149 | """ Constructs a ResNet-101 model.
150 | """
151 | model = ResNet(input_size, Bottleneck, [3, 4, 23, 3], **kwargs)
152 |
153 | return model
154 |
155 |
156 | def ResNet_152(input_size, **kwargs):
157 | """ Constructs a ResNet-152 model.
158 | """
159 | model = ResNet(input_size, Bottleneck, [3, 8, 36, 3], **kwargs)
160 |
161 | return model
162 |
--------------------------------------------------------------------------------
/eval/expression_score.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from insightface.app import FaceAnalysis
5 | from insightface.utils import face_align
6 | from PIL import Image
7 | from torchvision import models, transforms
8 | from curricularface import get_model
9 | import cv2
10 | import numpy as np
11 | import numpy
12 |
13 | def pad_np_bgr_image(np_image, scale=1.25):
14 | assert scale >= 1.0, "scale should be >= 1.0"
15 | pad_scale = scale - 1.0
16 | h, w = np_image.shape[:2]
17 | top = bottom = int(h * pad_scale)
18 | left = right = int(w * pad_scale)
19 | return cv2.copyMakeBorder(np_image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(128, 128, 128)), (left, top)
20 |
21 |
22 | def sample_video_frames(video_path,):
23 | cap = cv2.VideoCapture(video_path)
24 | total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
25 | frame_indices = np.linspace(0, total_frames - 1, total_frames, dtype=int)
26 |
27 | frames = []
28 | for idx in frame_indices:
29 | cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
30 | ret, frame = cap.read()
31 | if ret:
32 | # if frame.shape[1] > 1024:
33 | # frame = frame[:, 1440:, :]
34 | # print(frame.shape)
35 | frame = cv2.resize(frame, (720, 480))
36 | # print(frame.shape)
37 | frames.append(frame)
38 | cap.release()
39 | return frames
40 |
41 |
42 | def get_face_keypoints(face_model, image_bgr):
43 | face_info = face_model.get(image_bgr)
44 | if len(face_info) > 0:
45 | return sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]
46 | return None
47 |
48 | def process_image(face_model, image_path):
49 | if isinstance(image_path, str):
50 | np_faceid_image = np.array(Image.open(image_path).convert("RGB"))
51 | elif isinstance(image_path, numpy.ndarray):
52 | np_faceid_image = image_path
53 | else:
54 | raise TypeError("image_path should be a file path string or a numpy.ndarray")
55 |
56 | image_bgr = cv2.cvtColor(np_faceid_image, cv2.COLOR_RGB2BGR)
57 |
58 | face_info = get_face_keypoints(face_model, image_bgr)
59 | if face_info is None:
60 | padded_image, sub_coord = pad_np_bgr_image(image_bgr)
61 | face_info = get_face_keypoints(face_model, padded_image)
62 | if face_info is None:
63 | print("Warning: No face detected in the image. Continuing processing...")
64 | return None
65 | face_kps = face_info['kps']
66 | face_kps -= np.array(sub_coord)
67 | else:
68 | face_kps = face_info['kps']
69 | return face_kps
70 |
71 | def process_video(video_path, face_arc_model):
72 | video_frames = sample_video_frames(video_path,)
73 | print(len(video_frames))
74 | kps_list = []
75 | for frame in video_frames:
76 | # Convert to RGB once at the beginning
77 | frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
78 | kps = process_image(face_arc_model, frame_rgb)
79 | if kps is None:
80 | return None
81 | # print(kps)
82 | kps_list.append(kps)
83 | return kps_list
84 |
85 |
86 | def calculate_l1_distance(list1, list2):
87 | """
88 | 计算两个列表的 L1 距离
89 | :param list1: 第一个列表,形状为 (5, 2)
90 | :param list2: 第二个列表,形状为 (5, 2)
91 | :return: L1 距离
92 | """
93 | # 将列表转换为 NumPy 数组
94 | list1 = np.array(list1)
95 | list2 = np.array(list2)
96 |
97 | # 计算每对点的 L1 距离
98 | l1_distances = np.abs(list1 - list2).sum(axis=1)
99 |
100 | # 返回所有点的 L1 距离之和
101 | return l1_distances.sum()
102 |
103 |
104 | def calculate_kps(list1, list2):
105 | distance_list = []
106 | for kps1 in list1:
107 | min_dis = (480 + 720) * 5 + 1  # upper bound: 5 keypoints, each at most 480 + 720 in L1 distance
108 | for kps2 in list2:
109 | min_dis = min(min_dis, calculate_l1_distance(kps1, kps2))
110 | distance_list.append(min_dis/(480+720)/10)
111 | return sum(distance_list)/len(distance_list)
112 |
113 |
114 | def main():
115 | device = "cuda"
116 | # data_path = "data/SkyActor"
117 | # data_path = "data/LivePotraits"
118 | # data_path = "data/Actor-One"
119 | data_path = "data/FollowYourEmoji"
120 | img_path = "/maindata/data/shared/public/rui.wang/act_review/driving_video"
121 | pre_tag = False
122 | mp4_list = os.listdir(data_path)
123 | print(mp4_list)
124 |
125 | img_list = []
126 | video_list = []
127 | for mp4 in mp4_list:
128 | if "mp4" not in mp4:
129 | continue
130 | if pre_tag:
131 | png_path = mp4.split('.')[0].split('--')[1] + ".mp4"
132 | else:
133 | if "-" in mp4:
134 | png_path = mp4.split('.')[0].split('-')[0] + ".mp4"
135 | else:
136 | png_path = mp4.split('.')[0].split('_')[0] + ".mp4"
137 | img_list.append(os.path.join(img_path, png_path))
138 | video_list.append(os.path.join(data_path, mp4))
139 | print(img_list)
140 | print(video_list[0])
141 |
142 | model_path = "eval"
143 | face_arc_path = os.path.join(model_path, "face_encoder")
144 | face_cur_path = os.path.join(face_arc_path, "glint360k_curricular_face_r101_backbone.bin")
145 |
146 | # Initialize FaceEncoder model for face detection and embedding extraction
147 | face_arc_model = FaceAnalysis(root=face_arc_path, providers=['CUDAExecutionProvider'])
148 | face_arc_model.prepare(ctx_id=0, det_size=(320, 320))
149 |
150 | expression_list = []
151 | for i in range(len(img_list)):
152 | print("number: ", str(i), " total: ", len(img_list), data_path)
153 | kps_1 = process_video(video_list[i], face_arc_model)
154 | kps_2 = process_video(img_list[i], face_arc_model)
155 | if kps_1 is None or kps_2 is None:
156 | continue
157 |
158 | dis = calculate_kps(kps_1, kps_2)
159 | print(dis)
160 | expression_list.append(dis)
161 | # break
162 |
163 | print("kps", sum(expression_list)/ len(expression_list))
164 |
165 |
166 |
167 | if __name__ == "__main__": main()
168 |
--------------------------------------------------------------------------------
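Note: calculate_kps matches each keypoint set from one video to its closest set in the other by L1 distance, then divides by (480 + 720) and by 10, since frames are resized to 720x480 earlier in the script. A tiny worked example with made-up landmarks:

    import numpy as np

    kps_a = np.array([[100, 200], [150, 210], [125, 260], [110, 300], [140, 305]], dtype=float)
    kps_b = kps_a + np.array([3.0, -2.0])          # every landmark shifted by (3, -2)

    l1 = np.abs(kps_a - kps_b).sum(axis=1).sum()   # 5 points * (|3| + |-2|) = 25.0
    normalized = l1 / (480 + 720) / 10             # ~0.00208, the per-frame value appended in calculate_kps
    print(l1, normalized)
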
/inference.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import numpy as np
4 | from PIL import Image
5 | import glob
6 | import insightface
7 | import cv2
8 | import subprocess
9 | import argparse
10 | from decord import VideoReader
11 | from moviepy.editor import ImageSequenceClip, AudioFileClip, VideoFileClip
12 | from facexlib.parsing import init_parsing_model
13 | from facexlib.utils.face_restoration_helper import FaceRestoreHelper
14 | from insightface.app import FaceAnalysis
15 |
16 | from diffusers.models import AutoencoderKLCogVideoX
17 | from diffusers.utils import export_to_video, load_image
18 | from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel
19 | from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
20 |
21 | from skyreels_a1.models.transformer3d import CogVideoXTransformer3DModel
22 | from skyreels_a1.skyreels_a1_i2v_pipeline import SkyReelsA1ImagePoseToVideoPipeline
23 | from skyreels_a1.pre_process_lmk3d import FaceAnimationProcessor
24 | from skyreels_a1.src.media_pipe.mp_utils import LMKExtractor
25 | from skyreels_a1.src.media_pipe.draw_util_2d import FaceMeshVisualizer2d
26 | from skyreels_a1.src.frame_interpolation import init_frame_interpolation_model, batch_images_interpolation_tool
27 | from skyreels_a1.src.multi_fps import multi_fps_tool
28 |
29 | def crop_and_resize(image, height, width):
30 | image = np.array(image)
31 | image_height, image_width, _ = image.shape
32 | if image_height / image_width < height / width:
33 | croped_width = int(image_height / height * width)
34 | left = (image_width - croped_width) // 2
35 | image = image[:, left: left+croped_width]
36 | image = Image.fromarray(image).resize((width, height))
37 | else:
38 | pad = int((((width / height) * image_height) - image_width) / 2.)
39 | padded_image = np.zeros((image_height, image_width + pad * 2, 3), dtype=np.uint8)
40 | padded_image[:, pad:pad+image_width] = image
41 | image = Image.fromarray(padded_image).resize((width, height))
42 | return image
43 |
44 | def write_mp4(video_path, samples, fps=12, audio_bitrate="192k"):
45 | clip = ImageSequenceClip(samples, fps=fps)
46 | clip.write_videofile(video_path, audio_codec="aac", audio_bitrate=audio_bitrate,
47 | ffmpeg_params=["-crf", "18", "-preset", "slow"])
48 |
49 | def parse_video(driving_video_path, max_frame_num):
50 | vr = VideoReader(driving_video_path)
51 | fps = vr.get_avg_fps()
52 | video_length = len(vr)
53 |
54 | duration = video_length / fps
55 | target_times = np.arange(0, duration, 1/12)
56 | frame_indices = (target_times * fps).astype(np.int32)
57 |
58 | frame_indices = frame_indices[frame_indices < video_length]
59 | control_frames = vr.get_batch(frame_indices).asnumpy()[:(max_frame_num-1)]
60 |
61 | out_frames = len(control_frames) - 1
62 | if len(control_frames) < max_frame_num - 1:
63 | video_length_add = max_frame_num - len(control_frames) - 1
64 | control_frames = np.concatenate(([control_frames[0]]*2, control_frames[1:len(control_frames)-1], [control_frames[-1]] * video_length_add), axis=0)
65 | else:
66 | control_frames = np.concatenate(([control_frames[0]]*2, control_frames[1:len(control_frames)-1]), axis=0)
67 |
68 | return control_frames
69 |
70 | def exec_cmd(cmd):
71 | return subprocess.run(cmd, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
72 |
73 | def add_audio_to_video(silent_video_path: str, audio_video_path: str, output_video_path: str):
74 | cmd = [
75 | 'ffmpeg',
76 | '-y',
77 | '-i', f'"{silent_video_path}"',
78 | '-i', f'"{audio_video_path}"',
79 | '-map', '0:v',
80 | '-map', '1:a',
81 | '-c:v', 'copy',
82 | '-shortest',
83 | f'"{output_video_path}"'
84 | ]
85 |
86 | try:
87 | exec_cmd(' '.join(cmd))
88 | print(f"Video with audio generated successfully: {output_video_path}")
89 | except subprocess.CalledProcessError as e:
90 | print(f"Error occurred: {e}")
91 |
92 |
93 | if __name__ == "__main__":
94 | parser = argparse.ArgumentParser(description="Process video and image for face animation.")
95 | parser.add_argument('--image_path', type=str, default="assets/ref_images/1.png", help='Path to the source image.')
96 | parser.add_argument('--driving_video_path', type=str, default="assets/driving_video/1.mp4", help='Path to the driving video.')
97 | parser.add_argument('--output_path', type=str, default="outputs", help='Path to save the output video.')
98 | args = parser.parse_args()
99 |
100 | guidance_scale = 3.0
101 | seed = 43
102 | num_inference_steps = 10
103 | sample_size = [480, 720]
104 | max_frame_num = 49
105 | target_fps = 12 # recommended fps: 12 (native), 24, 36, 48, 60; other values such as 25 or 30 may cause unstable frame rates
106 | weight_dtype = torch.bfloat16
107 | save_path = args.output_path
108 | generator = torch.Generator(device="cuda").manual_seed(seed)
109 | model_name = "pretrained_models/SkyReels-A1-5B/"
110 | siglip_name = "pretrained_models/SkyReels-A1-5B/siglip-so400m-patch14-384"
111 |
112 | lmk_extractor = LMKExtractor()
113 | processor = FaceAnimationProcessor(checkpoint='pretrained_models/smirk/SMIRK_em1.pt')
114 | vis = FaceMeshVisualizer2d(forehead_edge=False, draw_head=False, draw_iris=False,)
115 | face_helper = FaceRestoreHelper(upscale_factor=1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', device="cuda",)
116 |
117 | # siglip visual encoder
118 | siglip = SiglipVisionModel.from_pretrained(siglip_name)
119 | siglip_normalize = SiglipImageProcessor.from_pretrained(siglip_name)
120 |
121 | # frame interpolation model
122 | if target_fps != 12:
123 | frame_inter_model = init_frame_interpolation_model('pretrained_models/film_net/film_net_fp16.pt', device="cuda")
124 |
125 | # skyreels a1 model
126 | transformer = CogVideoXTransformer3DModel.from_pretrained(
127 | model_name,
128 | subfolder="transformer"
129 | ).to(weight_dtype)
130 |
131 | vae = AutoencoderKLCogVideoX.from_pretrained(
132 | model_name,
133 | subfolder="vae"
134 | ).to(weight_dtype)
135 |
136 | lmk_encoder = AutoencoderKLCogVideoX.from_pretrained(
137 | model_name,
138 | subfolder="pose_guider",
139 | ).to(weight_dtype)
140 |
141 | pipe = SkyReelsA1ImagePoseToVideoPipeline.from_pretrained(
142 | model_name,
143 | transformer = transformer,
144 | vae = vae,
145 | lmk_encoder = lmk_encoder,
146 | image_encoder = siglip,
147 | feature_extractor = siglip_normalize,
148 | torch_dtype=torch.bfloat16
149 | )
150 |
151 | pipe.to("cuda")
152 | pipe.enable_model_cpu_offload()
153 | pipe.vae.enable_tiling()
154 |
155 | control_frames = parse_video(args.driving_video_path, max_frame_num)
156 |
157 | # driving video crop face
158 | driving_video_crop = []
159 | for control_frame in control_frames:
160 | frame, _, _ = processor.face_crop(control_frame)
161 | driving_video_crop.append(frame)
162 |
163 | image = load_image(image=args.image_path)
164 | image = processor.crop_and_resize(image, sample_size[0], sample_size[1])
165 |
166 | # ref image crop face
167 | ref_image, x1, y1 = processor.face_crop(np.array(image))
168 | face_h, face_w, _ = ref_image.shape
169 | source_image = ref_image
170 | driving_video = driving_video_crop
171 | out_frames = processor.preprocess_lmk3d(source_image, driving_video)
172 |
173 | rescale_motions = np.zeros_like(image)[np.newaxis, :].repeat(48, axis=0)
174 | for ii in range(rescale_motions.shape[0]):
175 | rescale_motions[ii][y1:y1+face_h, x1:x1+face_w] = out_frames[ii]
176 | ref_image = cv2.resize(ref_image, (512, 512))
177 | ref_lmk = lmk_extractor(ref_image[:, :, ::-1])
178 |
179 | ref_img = vis.draw_landmarks_v3((512, 512), (face_w, face_h), ref_lmk['lmks'].astype(np.float32), normed=True)
180 |
181 | first_motion = np.zeros_like(np.array(image))
182 | first_motion[y1:y1+face_h, x1:x1+face_w] = ref_img
183 | first_motion = first_motion[np.newaxis, :]
184 |
185 | motions = np.concatenate([first_motion, rescale_motions])
186 | input_video = motions[:max_frame_num]
187 |
188 | face_helper.clean_all()
189 | face_helper.read_image(np.array(image)[:, :, ::-1])
190 | face_helper.get_face_landmarks_5(only_center_face=True)
191 | face_helper.align_warp_face()
192 | align_face = face_helper.cropped_faces[0]
193 | image_face = align_face[:, :, ::-1]
194 |
195 | input_video = input_video[:max_frame_num]
196 | motions = np.array(input_video)
197 |
198 | # [F, H, W, C]
199 | input_video = torch.from_numpy(np.array(input_video)).permute([3, 0, 1, 2]).unsqueeze(0)
200 | input_video = input_video / 255
201 |
202 | out_samples = []
203 |
204 | with torch.no_grad():
205 | sample = pipe(
206 | image=image,
207 | image_face=image_face,
208 | control_video = input_video,
209 | prompt = "",
210 | negative_prompt = "",
211 | height = sample_size[0],
212 | width = sample_size[1],
213 | num_frames = 49,
214 | generator = generator,
215 | guidance_scale = guidance_scale,
216 | num_inference_steps = num_inference_steps,
217 | )
218 | out_samples.extend(sample.frames[0])
219 | out_samples = out_samples[2:]
220 |
221 | save_path_name = os.path.basename(args.image_path).split(".")[0] + "-" + os.path.basename(args.driving_video_path).split(".")[0]+ ".mp4"
222 |
223 | if not os.path.exists(save_path):
224 | os.makedirs(save_path, exist_ok=True)
225 | video_path = os.path.join(save_path, save_path_name.split(".")[0] + "_output.mp4")
226 |
227 | if target_fps != 12:
228 | out_samples = multi_fps_tool(out_samples, frame_inter_model, target_fps)
229 |
230 | export_to_video(out_samples, video_path, fps=target_fps)
231 | add_audio_to_video(video_path, args.driving_video_path, video_path.split(".")[0] + "_audio.mp4")
232 |
233 | if target_fps == 12:
234 | target_h, target_w = sample_size[0], sample_size[1]
235 | final_images = []
236 | final_images2 =[]
237 | rescale_motions = rescale_motions[1:]
238 | control_frames = control_frames[1:]
239 | for q in range(len(out_samples)):
240 | frame1 = image
241 | frame2 = crop_and_resize(Image.fromarray(np.array(control_frames[q])).convert("RGB"), target_h, target_w)
242 | frame3 = Image.fromarray(np.array(out_samples[q])).convert("RGB")
243 |
244 | result = Image.new('RGB', (target_w * 3, target_h))
245 | result.paste(frame1, (0, 0))
246 | result.paste(frame2, (target_w, 0))
247 | result.paste(frame3, (target_w * 2, 0))
248 | final_images.append(np.array(result))
249 |
250 | video_out_path = os.path.join(save_path, save_path_name.split(".")[0]+"_merge.mp4")
251 | write_mp4(video_out_path, final_images, fps=12)
252 |
253 | add_audio_to_video(video_out_path, args.driving_video_path, video_out_path.split(".")[0] + f"_audio.mp4")
254 |
--------------------------------------------------------------------------------
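Note: parse_video retimes the driving clip to the model's native 12 fps by sampling source frames at 1/12 s intervals before trimming/padding to the 48 control frames (max_frame_num - 1). A small self-contained illustration of the index math, assuming a 25 fps source purely for illustration:

    import numpy as np

    fps, video_length = 25, 100                        # a 4 s driving clip at 25 fps
    target_times = np.arange(0, video_length / fps, 1 / 12)
    frame_indices = (target_times * fps).astype(np.int32)
    frame_indices = frame_indices[frame_indices < video_length]
    print(len(frame_indices), frame_indices[:6])       # 48 indices: [ 0  2  4  6  8 10]
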
/inference_audio.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import numpy as np
4 | from PIL import Image
5 | import glob
6 | import insightface
7 | import cv2
8 | import subprocess
9 | import argparse
10 | from decord import VideoReader
11 | from moviepy.editor import ImageSequenceClip, AudioFileClip, VideoFileClip
12 | from facexlib.parsing import init_parsing_model
13 | from facexlib.utils.face_restoration_helper import FaceRestoreHelper
14 | from insightface.app import FaceAnalysis
15 |
16 | from diffusers.models import AutoencoderKLCogVideoX
17 | from diffusers.utils import export_to_video, load_image
18 | from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel
19 | from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
20 |
21 | from skyreels_a1.models.transformer3d import CogVideoXTransformer3DModel
22 | from skyreels_a1.skyreels_a1_i2v_pipeline import SkyReelsA1ImagePoseToVideoPipeline
23 | from skyreels_a1.pre_process_lmk3d import FaceAnimationProcessor
24 | from skyreels_a1.src.media_pipe.mp_utils import LMKExtractor
25 | from skyreels_a1.src.media_pipe.draw_util_2d import FaceMeshVisualizer2d
26 | from skyreels_a1.src.frame_interpolation import init_frame_interpolation_model, batch_images_interpolation_tool
27 | from skyreels_a1.src.multi_fps import multi_fps_tool
28 |
29 | import moviepy.editor as mp
30 | from diffposetalk.diffposetalk import DiffPoseTalk
31 |
32 |
33 | def crop_and_resize(image, height, width):
34 | image = np.array(image)
35 | image_height, image_width, _ = image.shape
36 | if image_height / image_width < height / width:
37 | croped_width = int(image_height / height * width)
38 | left = (image_width - croped_width) // 2
39 | image = image[:, left: left+croped_width]
40 | image = Image.fromarray(image).resize((width, height))
41 | else:
42 | pad = int((((width / height) * image_height) - image_width) / 2.)
43 | padded_image = np.zeros((image_height, image_width + pad * 2, 3), dtype=np.uint8)
44 | padded_image[:, pad:pad+image_width] = image
45 | image = Image.fromarray(padded_image).resize((width, height))
46 | return image
47 |
48 | def write_mp4(video_path, samples, fps=12, audio_bitrate="192k"):
49 | clip = ImageSequenceClip(samples, fps=fps)
50 | clip.write_videofile(video_path, audio_codec="aac", audio_bitrate=audio_bitrate,
51 | ffmpeg_params=["-crf", "18", "-preset", "slow"])
52 |
53 |
54 | def parse_video(driving_frames, max_frame_num, fps=25):
55 |
56 | video_length = len(driving_frames)
57 |
58 | duration = video_length / fps
59 | target_times = np.arange(0, duration, 1/12)
60 | frame_indices = (target_times * fps).astype(np.int32)
61 |
62 | frame_indices = frame_indices[frame_indices < video_length]
63 | new_driving_frames = []
64 | for idx in frame_indices:
65 | new_driving_frames.append(driving_frames[idx])
66 | if len(new_driving_frames) >= max_frame_num - 1:
67 | break
68 |
69 | video_length_add = max_frame_num - len(new_driving_frames) - 1
70 | new_driving_frames = [new_driving_frames[0]]*2 + new_driving_frames[1:len(new_driving_frames)-1] + [new_driving_frames[-1]] * video_length_add
71 | return new_driving_frames
72 |
73 | def exec_cmd(cmd):
74 | return subprocess.run(cmd, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
75 |
76 | def add_audio_to_video(silent_video_path, audio_video_path, output_video_path):
77 | cmd = [
78 | 'ffmpeg',
79 | '-y',
80 | '-i', f'"{silent_video_path}"',
81 | '-i', f'"{audio_video_path}"',
82 | '-map', '0:v',
83 | '-map', '1:a',
84 | '-c:v', 'copy',
85 | '-shortest',
86 | f'"{output_video_path}"'
87 | ]
88 |
89 | try:
90 | exec_cmd(' '.join(cmd))
91 | print(f"Video with audio generated successfully: {output_video_path}")
92 | except subprocess.CalledProcessError as e:
93 | print(f"Error occurred: {e}")
94 |
95 | if __name__ == "__main__":
96 | parser = argparse.ArgumentParser(description="Process video and image for face animation.")
97 | parser.add_argument('--image_path', type=str, default="assets/ref_images/1.png", help='Path to the source image.')
98 | parser.add_argument('--driving_audio_path', type=str, default="assets/driving_audio/1.wav", help='Path to the driving audio file.')
99 | parser.add_argument('--output_path', type=str, default="outputs_audio", help='Path to save the output video.')
100 | args = parser.parse_args()
101 |
102 | guidance_scale = 3.0
103 | seed = 43
104 | num_inference_steps = 10
105 | sample_size = [480, 720]
106 | max_frame_num = 49
107 | target_fps = 12 # recommended fps: 12 (native), 24, 36, 48, 60; other values such as 25 or 30 may cause unstable frame rates
108 | weight_dtype = torch.bfloat16
109 | save_path = args.output_path
110 | generator = torch.Generator(device="cuda").manual_seed(seed)
111 | model_name = "pretrained_models/SkyReels-A1-5B/"
112 | siglip_name = "pretrained_models/SkyReels-A1-5B/siglip-so400m-patch14-384"
113 |
114 | lmk_extractor = LMKExtractor()
115 | processor = FaceAnimationProcessor(checkpoint='pretrained_models/smirk/SMIRK_em1.pt')
116 | vis = FaceMeshVisualizer2d(forehead_edge=False, draw_head=False, draw_iris=False,)
117 | face_helper = FaceRestoreHelper(upscale_factor=1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', device="cuda",)
118 |
119 | # siglip visual encoder
120 | siglip = SiglipVisionModel.from_pretrained(siglip_name)
121 | siglip_normalize = SiglipImageProcessor.from_pretrained(siglip_name)
122 |
123 | # frame interpolation model
124 | if target_fps != 12:
125 | frame_inter_model = init_frame_interpolation_model('pretrained_models/film_net/film_net_fp16.pt', device="cuda")
126 |
127 | # diffposetalk
128 | diffposetalk = DiffPoseTalk()
129 |
130 | # skyreels a1 model
131 | transformer = CogVideoXTransformer3DModel.from_pretrained(
132 | model_name,
133 | subfolder="transformer"
134 | ).to(weight_dtype)
135 |
136 | vae = AutoencoderKLCogVideoX.from_pretrained(
137 | model_name,
138 | subfolder="vae"
139 | ).to(weight_dtype)
140 |
141 | lmk_encoder = AutoencoderKLCogVideoX.from_pretrained(
142 | model_name,
143 | subfolder="pose_guider",
144 | ).to(weight_dtype)
145 |
146 | pipe = SkyReelsA1ImagePoseToVideoPipeline.from_pretrained(
147 | model_name,
148 | transformer = transformer,
149 | vae = vae,
150 | lmk_encoder = lmk_encoder,
151 | image_encoder = siglip,
152 | feature_extractor = siglip_normalize,
153 | torch_dtype=torch.bfloat16
154 | )
155 |
156 | pipe.to("cuda")
157 | pipe.enable_model_cpu_offload()
158 | pipe.vae.enable_tiling()
159 |
160 | image = load_image(image=args.image_path)
161 | image = processor.crop_and_resize(image, sample_size[0], sample_size[1])
162 |
163 | # ref image crop face
164 | ref_image, x1, y1 = processor.face_crop(np.array(image))
165 | face_h, face_w, _ = ref_image.shape
166 | source_image = ref_image
167 |
168 | source_outputs, source_tform, image_original = processor.process_source_image(source_image)
169 | driving_outputs = diffposetalk.infer_from_file(args.driving_audio_path, source_outputs["shape_params"].view(-1)[:100].detach().cpu().numpy())
170 | out_frames = processor.preprocess_lmk3d_from_coef(source_outputs, source_tform, image_original.shape, driving_outputs)
171 | out_frames = parse_video(out_frames, max_frame_num)
172 |
173 | rescale_motions = np.zeros_like(image)[np.newaxis, :].repeat(48, axis=0)
174 | for ii in range(rescale_motions.shape[0]):
175 | rescale_motions[ii][y1:y1+face_h, x1:x1+face_w] = out_frames[ii]
176 | ref_image = cv2.resize(ref_image, (512, 512))
177 | ref_lmk = lmk_extractor(ref_image[:, :, ::-1])
178 |
179 | ref_img = vis.draw_landmarks_v3((512, 512), (face_w, face_h), ref_lmk['lmks'].astype(np.float32), normed=True)
180 |
181 | first_motion = np.zeros_like(np.array(image))
182 | first_motion[y1:y1+face_h, x1:x1+face_w] = ref_img
183 | first_motion = first_motion[np.newaxis, :]
184 |
185 | motions = np.concatenate([first_motion, rescale_motions])
186 | input_video = motions[:max_frame_num]
187 |
188 | face_helper.clean_all()
189 | face_helper.read_image(np.array(image)[:, :, ::-1])
190 | face_helper.get_face_landmarks_5(only_center_face=True)
191 | face_helper.align_warp_face()
192 | align_face = face_helper.cropped_faces[0]
193 | image_face = align_face[:, :, ::-1]
194 |
195 | input_video = input_video[:max_frame_num]
196 | motions = np.array(input_video)
197 |
198 | # [F, H, W, C]
199 | input_video = torch.from_numpy(np.array(input_video)).permute([3, 0, 1, 2]).unsqueeze(0)
200 | input_video = input_video / 255
201 |
202 | out_samples = []
203 |
204 | with torch.no_grad():
205 | sample = pipe(
206 | image=image,
207 | image_face=image_face,
208 | control_video = input_video,
209 | prompt = "",
210 | negative_prompt = "",
211 | height = sample_size[0],
212 | width = sample_size[1],
213 | num_frames = 49,
214 | generator = generator,
215 | guidance_scale = guidance_scale,
216 | num_inference_steps = num_inference_steps,
217 | )
218 | out_samples.extend(sample.frames[0])
219 | out_samples = out_samples[2:]
220 |
221 | save_path_name = os.path.basename(args.image_path).split(".")[0] + "-" + os.path.basename(args.driving_audio_path).split(".")[0]+ ".mp4"
222 |
223 | if not os.path.exists(save_path):
224 | os.makedirs(save_path, exist_ok=True)
225 | video_path = os.path.join(save_path, save_path_name.split(".")[0] + "_output.mp4")
226 |
227 | if target_fps != 12:
228 | out_samples = multi_fps_tool(out_samples, frame_inter_model, target_fps)
229 |
230 | export_to_video(out_samples, video_path, fps=target_fps)
231 | add_audio_to_video(video_path, args.driving_audio_path, video_path.split(".")[0] + "_audio.mp4")
232 |
233 | if target_fps == 12:
234 | target_h, target_w = sample_size[0], sample_size[1]
235 | final_images = []
236 | final_images2 =[]
237 | rescale_motions = rescale_motions[1:]
238 | control_frames = out_frames[1:]
239 | for q in range(len(out_samples)):
240 | frame1 = image
241 | frame2 = Image.fromarray(np.array(out_samples[q])).convert("RGB")
242 |
243 | result = Image.new('RGB', (target_w * 2, target_h))
244 | result.paste(frame1, (0, 0))
245 | result.paste(frame2, (target_w, 0))
246 | final_images.append(np.array(result))
247 |
248 | video_out_path = os.path.join(save_path, save_path_name.split(".")[0]+"_merge.mp4")
249 | write_mp4(video_out_path, final_images, fps=12)
250 |
251 | add_audio_to_video(video_out_path, args.driving_audio_path, video_out_path.split(".")[0] + "_audio.mp4")
252 |
253 |
--------------------------------------------------------------------------------
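Note: the audio-driven path above retimes the DiffPoseTalk-generated frames (assumed 25 fps) to 12 fps and then pads/duplicates them so that exactly 48 control frames remain; with the reference landmark frame prepended this gives the 49-frame input, and the first two generated frames are dropped afterwards. A tiny sketch of the padding rule in parse_video, with an illustrative frame count:

    max_frame_num = 49
    frames = list(range(40))                   # e.g. only 40 retimed driving frames are available
    pad = max_frame_num - len(frames) - 1      # 8 extra copies of the last frame
    frames = [frames[0]] * 2 + frames[1:len(frames) - 1] + [frames[-1]] * pad
    print(len(frames))                         # 48; prepending the reference landmark frame gives 49
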
/inference_audio_long_video.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import numpy as np
4 | from PIL import Image
5 | import glob
6 | import insightface
7 | import cv2
8 | import subprocess
9 | import argparse
10 | import math
11 | import time
12 | from decord import VideoReader
13 | from moviepy.editor import ImageSequenceClip, AudioFileClip, VideoFileClip
14 | from facexlib.parsing import init_parsing_model
15 | from facexlib.utils.face_restoration_helper import FaceRestoreHelper
16 | from insightface.app import FaceAnalysis
17 | import moviepy.editor as mp
18 |
19 | from diffusers.models import AutoencoderKLCogVideoX
20 | from diffusers.utils import export_to_video, load_image
21 | from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel
22 | from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
23 |
24 | from skyreels_a1.models.transformer3d import CogVideoXTransformer3DModel
25 | from skyreels_a1.skyreels_a1_i2v_long_pipeline import SkyReelsA1ImagePoseToVideoPipeline
26 | from skyreels_a1.pre_process_lmk3d import FaceAnimationProcessor
27 | from skyreels_a1.src.media_pipe.mp_utils import LMKExtractor
28 | from skyreels_a1.src.media_pipe.draw_util_2d import FaceMeshVisualizer2d
29 | from skyreels_a1.src.frame_interpolation import init_frame_interpolation_model, batch_images_interpolation_tool
30 | from skyreels_a1.src.multi_fps import multi_fps_tool
31 |
32 | from diffusers.video_processor import VideoProcessor
33 | from diffposetalk.diffposetalk import DiffPoseTalk
34 |
35 |
36 | def crop_and_resize(image, height, width):
37 | image = np.array(image)
38 | image_height, image_width, _ = image.shape
39 | if image_height / image_width < height / width:
40 | croped_width = int(image_height / height * width)
41 | left = (image_width - croped_width) // 2
42 | image = image[:, left: left+croped_width]
43 | image = Image.fromarray(image).resize((width, height))
44 | else:
45 | pad = int((((width / height) * image_height) - image_width) / 2.)
46 | padded_image = np.zeros((image_height, image_width + pad * 2, 3), dtype=np.uint8)
47 | padded_image[:, pad:pad+image_width] = image
48 | image = Image.fromarray(padded_image).resize((width, height))
49 | return image
50 |
51 | def write_mp4(video_path, samples, fps=12, audio_bitrate="192k"):
52 | clip = ImageSequenceClip(samples, fps=fps)
53 | clip.write_videofile(video_path, audio_codec="aac", audio_bitrate=audio_bitrate,
54 | ffmpeg_params=["-crf", "18", "-preset", "slow"])
55 |
56 | def parse_video(driving_frames, fps=25):
57 | video_length = len(driving_frames)
58 |
59 | duration = video_length / fps
60 | target_times = np.arange(0, duration, 1/12)
61 | frame_indices = (target_times * fps).astype(np.int32)
62 |
63 | frame_indices = frame_indices[frame_indices < video_length]
64 | new_driving_frames = []
65 | for idx in frame_indices:
66 | new_driving_frames.append(driving_frames[idx])
67 |
68 | return new_driving_frames
69 |
70 |
71 | def smooth_video_transition(frames1, frames2, smooth_frame_num, frame_inter_model, inter_frames=2):
72 |
73 | frames1_np = np.array([np.array(frame) for frame in frames1])
74 | frames2_np = np.array([np.array(frame) for frame in frames2])
75 | frames1_tensor = torch.from_numpy(frames1_np).permute(3,0,1,2).unsqueeze(0) / 255.0
76 | frames2_tensor = torch.from_numpy(frames2_np).permute(3,0,1,2).unsqueeze(0) / 255.0
77 | video = torch.cat([frames1_tensor[:,:,-smooth_frame_num:], frames2_tensor[:,:,:smooth_frame_num]], dim=2)
78 |
79 | video = batch_images_interpolation_tool(video, frame_inter_model, inter_frames=inter_frames)
80 |
81 | index = [1, 4, 5, 8] if inter_frames == 2 else [2, 5, 7, 10]
82 | video = video[:, :, index]
83 | video = video.squeeze(0)
84 | video = video.permute(1, 2, 3, 0)
85 | video = (video * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
86 | mid_frames = [Image.fromarray(frame) for frame in video]
87 |
88 | out_frames = frames1[:-smooth_frame_num] + mid_frames + frames2[smooth_frame_num:]
89 |
90 | return out_frames
91 |
92 |
93 | def exec_cmd(cmd):
94 | return subprocess.run(cmd, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
95 |
96 | def add_audio_to_video(silent_video_path, audio_video_path, output_video_path):
97 | cmd = [
98 | 'ffmpeg',
99 | '-y',
100 | '-i', f'"{silent_video_path}"',
101 | '-i', f'"{audio_video_path}"',
102 | '-map', '0:v',
103 | '-map', '1:a',
104 | '-c:v', 'copy',
105 | '-shortest',
106 | f'"{output_video_path}"'
107 | ]
108 |
109 | try:
110 | exec_cmd(' '.join(cmd))
111 | print(f"Video with audio generated successfully: {output_video_path}")
112 | except subprocess.CalledProcessError as e:
113 | print(f"Error occurred: {e}")
114 |
115 |
116 | if __name__ == "__main__":
117 | parser = argparse.ArgumentParser(description="Process video and image for face animation.")
118 | parser.add_argument('--image_path', type=str, default="assets/ref_images/19.png", help='Path to the source image.')
119 |     parser.add_argument('--driving_audio_path', type=str, default="assets/driving_audio/2.wav", help='Path to the driving audio.')
120 | parser.add_argument('--output_path', type=str, default="outputs_audio", help='Path to save the output video.')
121 | args = parser.parse_args()
122 |
123 | guidance_scale = 3.0
124 | seed = 43
125 | num_inference_steps = 10
126 | sample_size = [480, 720]
127 | max_frame_num = 10000
128 | frame_num_per_batch = 49
129 | overlap_frame_num = 8
130 | fusion_interval = [3, 8]
131 | use_interpolation = True
132 |     target_fps = 12  # recommended fps: 12 (native), 24, 36, 48, 60; other values such as 25 or 30 may cause unstable frame rates
133 | weight_dtype = torch.bfloat16
134 | save_path = args.output_path
135 | generator = torch.Generator(device="cuda").manual_seed(seed)
136 | model_name = "pretrained_models/SkyReels-A1-5B/"
137 | siglip_name = "pretrained_models/SkyReels-A1-5B/siglip-so400m-patch14-384"
138 |
139 | lmk_extractor = LMKExtractor()
140 | processor = FaceAnimationProcessor(checkpoint='pretrained_models/smirk/SMIRK_em1.pt')
141 | vis = FaceMeshVisualizer2d(forehead_edge=False, draw_head=False, draw_iris=False)
142 | face_helper = FaceRestoreHelper(upscale_factor=1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', device="cuda")
143 |
144 | # siglip visual encoder
145 | siglip = SiglipVisionModel.from_pretrained(siglip_name)
146 | siglip_normalize = SiglipImageProcessor.from_pretrained(siglip_name)
147 |
148 | # frame interpolation model
149 | if use_interpolation or target_fps != 12:
150 | frame_inter_model = init_frame_interpolation_model('pretrained_models/film_net/film_net_fp16.pt', device="cuda")
151 |
152 | # diffposetalk
153 | diffposetalk = DiffPoseTalk()
154 |
155 | # skyreels a1 model
156 | transformer = CogVideoXTransformer3DModel.from_pretrained(
157 | model_name,
158 | subfolder="transformer"
159 | ).to(weight_dtype)
160 |
161 | vae = AutoencoderKLCogVideoX.from_pretrained(
162 | model_name,
163 | subfolder="vae"
164 | ).to(weight_dtype)
165 |
166 | lmk_encoder = AutoencoderKLCogVideoX.from_pretrained(
167 | model_name,
168 | subfolder="pose_guider",
169 | ).to(weight_dtype)
170 |
171 | pipe = SkyReelsA1ImagePoseToVideoPipeline.from_pretrained(
172 | model_name,
173 | transformer=transformer,
174 | vae=vae,
175 | lmk_encoder=lmk_encoder,
176 | image_encoder=siglip,
177 | feature_extractor=siglip_normalize,
178 | torch_dtype=torch.bfloat16
179 | )
180 |
181 | pipe.to("cuda")
182 | pipe.enable_model_cpu_offload()
183 | pipe.vae.enable_tiling()
184 |
185 | image = load_image(image=args.image_path)
186 | image = processor.crop_and_resize(image, sample_size[0], sample_size[1])
187 |
188 | # ref image crop face
189 | ref_image, x1, y1 = processor.face_crop(np.array(image))
190 | face_h, face_w, _ = ref_image.shape
191 | source_image = ref_image
192 |
193 | source_outputs, source_tform, image_original = processor.process_source_image(source_image)
194 | driving_outputs = diffposetalk.infer_from_file(args.driving_audio_path, source_outputs["shape_params"].view(-1)[:100].detach().cpu().numpy())
195 |
196 | out_frames = processor.preprocess_lmk3d_from_coef(source_outputs, source_tform, image_original.shape, driving_outputs)
197 | out_frames = parse_video(out_frames)
198 |
199 | rescale_motions = np.zeros_like(image)[np.newaxis, :].repeat(len(out_frames), axis=0)
200 | for ii in range(rescale_motions.shape[0]):
201 | rescale_motions[ii][y1:y1+face_h, x1:x1+face_w] = out_frames[ii]
202 | ref_image = cv2.resize(ref_image, (512, 512))
203 | ref_lmk = lmk_extractor(ref_image[:, :, ::-1])
204 |
205 | ref_img = vis.draw_landmarks_v3((512, 512), (face_w, face_h), ref_lmk['lmks'].astype(np.float32), normed=True)
206 |
207 | first_motion = np.zeros_like(np.array(image))
208 | first_motion[y1:y1+face_h, x1:x1+face_w] = ref_img
209 | first_motion = first_motion[np.newaxis, :]
210 |
211 | input_video = rescale_motions[:max_frame_num]
212 |
213 | video_length = len(input_video)
214 |     print(f"original video length: {video_length}")
215 |
216 | face_helper.clean_all()
217 | face_helper.read_image(np.array(image)[:, :, ::-1])
218 | face_helper.get_face_landmarks_5(only_center_face=True)
219 | face_helper.align_warp_face()
220 | align_face = face_helper.cropped_faces[0]
221 | image_face = align_face[:, :, ::-1]
222 |
223 | # [F, H, W, C]
224 | input_video = torch.from_numpy(np.array(input_video)).permute([3, 0, 1, 2]).unsqueeze(0) # [B, C, F, H, W]
225 | input_video = input_video / 255
226 | input_video_all = input_video
227 |
228 | first_motion = torch.from_numpy(np.array(first_motion)).permute([3, 0, 1, 2]).unsqueeze(0) # [B, C, 1, H, W]
229 | first_motion = first_motion / 255
230 |
231 | out_samples = []
232 | padding_frame_num = None
233 | latents_cache = []
234 |
235 | time_start = time.time()
236 |
237 | for i in range(0, video_length, frame_num_per_batch-1-overlap_frame_num):
238 | is_first_batch = (i == 0)
239 | is_last_batch = (i + frame_num_per_batch - 1 >= video_length)
240 |
241 | input_video = input_video_all[:, :, i:i+frame_num_per_batch-1]
242 |
243 | if input_video.shape[2] != frame_num_per_batch-1:
244 | padding_frame_num = frame_num_per_batch-1 - input_video.shape[2]
245 | print(f"padding_frame_num: {padding_frame_num}")
246 | input_video = torch.cat([input_video, torch.repeat_interleave(input_video[:, :, -1].unsqueeze(2), padding_frame_num, dim=2)], dim=2)
247 |
248 | input_video = torch.cat([first_motion, input_video[:, :, 0:1], input_video], dim=2)
249 |
250 | with torch.no_grad():
251 | sample, latents_cache = pipe(
252 | image=image,
253 | image_face=image_face,
254 | control_video=input_video,
255 | prompt="",
256 | negative_prompt="",
257 | height=sample_size[0],
258 | width=sample_size[1],
259 | num_frames=frame_num_per_batch,
260 | generator=generator,
261 | guidance_scale=guidance_scale,
262 | num_inference_steps=num_inference_steps,
263 | is_last_batch=is_last_batch,
264 | overlap_frame_num=overlap_frame_num,
265 | fusion_interval=fusion_interval,
266 | latents_cache=latents_cache
267 | )
268 | if use_interpolation:
269 | if is_first_batch:
270 | out_samples = sample.frames[0][1:]
271 | else:
272 | out_samples = smooth_video_transition(out_samples, sample.frames[0][1+overlap_frame_num:], 2, frame_inter_model)
273 | else:
274 | out_sample = sample.frames[0][1:] if is_first_batch else sample.frames[0][1+overlap_frame_num:]
275 | out_samples.extend(out_sample)
276 | print(f"out_samples len: {len(out_samples)}")
277 |
278 | if is_last_batch:
279 | break
280 |
281 | if padding_frame_num is not None:
282 | out_samples = out_samples[:-padding_frame_num]
283 |
284 | print(f"output video length: {len(out_samples)}")
285 |
286 | time_end = time.time()
287 | print(f"time cost: {time_end - time_start} seconds")
288 |
289 | save_path_name = os.path.basename(args.image_path).split(".")[0] + "-" + os.path.basename(args.driving_audio_path).split(".")[0]+ ".mp4"
290 |
291 | if not os.path.exists(save_path):
292 | os.makedirs(save_path, exist_ok=True)
293 | video_path = os.path.join(save_path, save_path_name.split(".")[0] + "_output.mp4")
294 |
295 | if target_fps != 12:
296 | out_samples = multi_fps_tool(out_samples, frame_inter_model, target_fps)
297 |
298 | export_to_video(out_samples, video_path, fps=target_fps)
299 | add_audio_to_video(video_path, args.driving_audio_path, video_path.split(".")[0] + "_audio.mp4")
300 |
301 | if target_fps == 12:
302 | target_h, target_w = sample_size[0], sample_size[1]
303 | final_images = []
304 |         final_images2 = []
305 | rescale_motions = rescale_motions[1:]
306 | control_frames = out_frames[1:]
307 | for q in range(len(out_samples)):
308 | frame1 = image
309 | frame2 = Image.fromarray(np.array(out_samples[q])).convert("RGB")
310 |
311 | result = Image.new('RGB', (target_w * 2, target_h))
312 | result.paste(frame1, (0, 0))
313 | result.paste(frame2, (target_w, 0))
314 | final_images.append(np.array(result))
315 |
316 | video_out_path = os.path.join(save_path, save_path_name.split(".")[0]+"_merge.mp4")
317 | write_mp4(video_out_path, final_images, fps=12)
318 |
319 | add_audio_to_video(video_out_path, args.driving_audio_path, video_out_path.split(".")[0] + "_audio.mp4")
320 |
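
The loop above tiles the long 12 fps control sequence into windows of frame_num_per_batch - 1 control frames that overlap by overlap_frame_num frames, so each new batch shares context with the previous one and the seam can either be dropped or smoothed with the interpolation model. A minimal sketch of just that windowing arithmetic, with an assumed 130-frame sequence (illustrative only, not part of the script):

frame_num_per_batch, overlap_frame_num = 49, 8
video_length = 130  # assumed length of the control sequence
step = frame_num_per_batch - 1 - overlap_frame_num  # 40 new frames per batch
for i in range(0, video_length, step):
    end = min(i + frame_num_per_batch - 1, video_length)
    print(f"batch at {i}: control frames [{i}, {end})")
    if i + frame_num_per_batch - 1 >= video_length:  # same is_last_batch test as above
        break

In the script itself, a short final window is padded by repeating its last control frame, and the padded frames are trimmed from out_samples after the loop.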
--------------------------------------------------------------------------------
/inference_long_video.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import numpy as np
4 | from PIL import Image
5 | import glob
6 | import insightface
7 | import cv2
8 | import subprocess
9 | import argparse
10 | import time
11 | from decord import VideoReader
12 | from moviepy.editor import ImageSequenceClip, AudioFileClip, VideoFileClip
13 | from facexlib.parsing import init_parsing_model
14 | from facexlib.utils.face_restoration_helper import FaceRestoreHelper
15 | from insightface.app import FaceAnalysis
16 |
17 | from diffusers.models import AutoencoderKLCogVideoX
18 | from diffusers.utils import export_to_video, load_image
19 | from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel
20 | from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
21 |
22 | from skyreels_a1.models.transformer3d import CogVideoXTransformer3DModel
23 | from skyreels_a1.skyreels_a1_i2v_long_pipeline import SkyReelsA1ImagePoseToVideoPipeline
24 | from skyreels_a1.pre_process_lmk3d import FaceAnimationProcessor
25 | from skyreels_a1.src.media_pipe.mp_utils import LMKExtractor
26 | from skyreels_a1.src.media_pipe.draw_util_2d import FaceMeshVisualizer2d
27 | from skyreels_a1.src.frame_interpolation import init_frame_interpolation_model, batch_images_interpolation_tool
28 | from skyreels_a1.src.multi_fps import multi_fps_tool
29 |
30 | from diffusers.video_processor import VideoProcessor
31 |
32 |
33 | def crop_and_resize(image, height, width):
34 | image = np.array(image)
35 | image_height, image_width, _ = image.shape
36 | if image_height / image_width < height / width:
37 | croped_width = int(image_height / height * width)
38 | left = (image_width - croped_width) // 2
39 | image = image[:, left: left+croped_width]
40 | image = Image.fromarray(image).resize((width, height))
41 | else:
42 | pad = int((((width / height) * image_height) - image_width) / 2.)
43 | padded_image = np.zeros((image_height, image_width + pad * 2, 3), dtype=np.uint8)
44 | padded_image[:, pad:pad+image_width] = image
45 | image = Image.fromarray(padded_image).resize((width, height))
46 | return image
47 |
48 |
49 | def write_mp4(video_path, samples, fps=12, audio_bitrate="192k"):
50 | clip = ImageSequenceClip(samples, fps=fps)
51 | clip.write_videofile(video_path, audio_codec="aac", audio_bitrate=audio_bitrate,
52 | ffmpeg_params=["-crf", "18", "-preset", "slow"])
53 |
54 |
55 | def parse_video(driving_video_path, max_frame_num=10000):
56 | vr = VideoReader(driving_video_path)
57 | fps = vr.get_avg_fps()
58 | video_length = len(vr)
59 |
60 | duration = video_length / fps
61 | target_times = np.arange(0, duration, 1/12)
62 | frame_indices = (target_times * fps).astype(np.int32)
63 |
64 | frame_indices = frame_indices[frame_indices < video_length]
65 | control_frames = vr.get_batch(frame_indices).asnumpy()[:max_frame_num]
66 |
67 | return control_frames
68 |
69 |
70 | def exec_cmd(cmd):
71 | return subprocess.run(cmd, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
72 |
73 |
74 | def add_audio_to_video(silent_video_path, audio_video_path, output_video_path):
75 | cmd = [
76 | 'ffmpeg',
77 | '-y',
78 | '-i', f'"{silent_video_path}"',
79 | '-i', f'"{audio_video_path}"',
80 | '-map', '0:v',
81 | '-map', '1:a',
82 | '-c:v', 'copy',
83 | '-shortest',
84 | f'"{output_video_path}"'
85 | ]
86 |
87 | try:
88 | exec_cmd(' '.join(cmd))
89 | print(f"Video with audio generated successfully: {output_video_path}")
90 | except subprocess.CalledProcessError as e:
91 | print(f"Error occurred: {e}")
92 |
93 |
94 | def smooth_video_transition(frames1, frames2, smooth_frame_num, frame_inter_model, inter_frames=2):
95 |
96 | frames1_np = np.array([np.array(frame) for frame in frames1])
97 | frames2_np = np.array([np.array(frame) for frame in frames2])
98 | frames1_tensor = torch.from_numpy(frames1_np).permute(3,0,1,2).unsqueeze(0) / 255.0
99 | frames2_tensor = torch.from_numpy(frames2_np).permute(3,0,1,2).unsqueeze(0) / 255.0
100 | video = torch.cat([frames1_tensor[:,:,-smooth_frame_num:], frames2_tensor[:,:,:smooth_frame_num]], dim=2)
101 |
102 | video = batch_images_interpolation_tool(video, frame_inter_model, inter_frames=inter_frames)
103 |
104 | index = [1, 4, 5, 8] if inter_frames == 2 else [2, 5, 7, 10]
105 | video = video[:, :, index]
106 | video = video.squeeze(0)
107 | video = video.permute(1, 2, 3, 0)
108 | video = (video * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
109 | mid_frames = [Image.fromarray(frame) for frame in video]
110 |
111 | out_frames = frames1[:-smooth_frame_num] + mid_frames + frames2[smooth_frame_num:]
112 |
113 | return out_frames
114 |
115 |
116 | if __name__ == "__main__":
117 | parser = argparse.ArgumentParser(description="Process video and image for face animation.")
118 | parser.add_argument('--image_path', type=str, default="assets/ref_images/1.png", help='Path to the source image.')
119 | parser.add_argument('--driving_video_path', type=str, default="assets/driving_video/6.mp4", help='Path to the driving video.')
120 | parser.add_argument('--output_path', type=str, default="outputs", help='Path to save the output video.')
121 | args = parser.parse_args()
122 |
123 | guidance_scale = 3.0
124 | seed = 43
125 | num_inference_steps = 10
126 | sample_size = [480, 720]
127 | max_frame_num = 10000
128 | frame_num_per_batch = 49
129 | overlap_frame_num = 8
130 | fusion_interval = [3, 8]
131 | use_interpolation = True
132 |     target_fps = 12  # recommended fps: 12 (native), 24, 36, 48, 60; other values such as 25 or 30 may cause unstable frame rates
133 | weight_dtype = torch.bfloat16
134 | save_path = args.output_path
135 | generator = torch.Generator(device="cuda").manual_seed(seed)
136 | model_name = "pretrained_models/SkyReels-A1-5B/"
137 | siglip_name = "pretrained_models/SkyReels-A1-5B/siglip-so400m-patch14-384"
138 |
139 | lmk_extractor = LMKExtractor()
140 | processor = FaceAnimationProcessor(checkpoint='pretrained_models/smirk/SMIRK_em1.pt')
141 | vis = FaceMeshVisualizer2d(forehead_edge=False, draw_head=False, draw_iris=False)
142 | face_helper = FaceRestoreHelper(upscale_factor=1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', device="cuda")
143 |
144 | # siglip visual encoder
145 | siglip = SiglipVisionModel.from_pretrained(siglip_name)
146 | siglip_normalize = SiglipImageProcessor.from_pretrained(siglip_name)
147 |
148 | # frame interpolation model
149 | if use_interpolation or target_fps != 12:
150 | frame_inter_model = init_frame_interpolation_model('pretrained_models/film_net/film_net_fp16.pt', device="cuda")
151 |
152 | # skyreels a1 model
153 | transformer = CogVideoXTransformer3DModel.from_pretrained(
154 | model_name,
155 | subfolder="transformer"
156 | ).to(weight_dtype)
157 |
158 | vae = AutoencoderKLCogVideoX.from_pretrained(
159 | model_name,
160 | subfolder="vae"
161 | ).to(weight_dtype)
162 |
163 | lmk_encoder = AutoencoderKLCogVideoX.from_pretrained(
164 | model_name,
165 | subfolder="pose_guider",
166 | ).to(weight_dtype)
167 |
168 | pipe = SkyReelsA1ImagePoseToVideoPipeline.from_pretrained(
169 | model_name,
170 | transformer=transformer,
171 | vae=vae,
172 | lmk_encoder=lmk_encoder,
173 | image_encoder=siglip,
174 | feature_extractor=siglip_normalize,
175 | torch_dtype=torch.bfloat16
176 | )
177 |
178 | pipe.to("cuda")
179 | pipe.enable_model_cpu_offload()
180 | pipe.vae.enable_tiling()
181 |
182 | control_frames = parse_video(args.driving_video_path, max_frame_num)
183 |
184 | # driving video crop face
185 | driving_video_crop = []
186 | empty_index = []
187 | from tqdm import tqdm
188 | for i, control_frame in enumerate(tqdm(control_frames, desc="Face crop")):
189 | frame, _, _ = processor.face_crop(control_frame)
190 | if frame is None:
191 | print(f'Warning: No face detected in the driving video frame {i}')
192 | empty_index.append(i)
193 | else:
194 | driving_video_crop.append(frame)
195 |
196 | control_frames = np.delete(control_frames, empty_index, axis=0)
197 |
198 |     video_length = len(driving_video_crop) # original video length
199 |
200 |     print(f"original video length: {video_length}")
201 |
202 | image = load_image(image=args.image_path)
203 | image = processor.crop_and_resize(image, sample_size[0], sample_size[1])
204 |
205 | # ref image crop face
206 | ref_image, x1, y1 = processor.face_crop(np.array(image))
207 | face_h, face_w, _ = ref_image.shape
208 | source_image = ref_image
209 | driving_video = driving_video_crop
210 | out_frames = processor.preprocess_lmk3d(source_image, driving_video)
211 |
212 | rescale_motions = np.zeros_like(image)[np.newaxis, :].repeat(len(out_frames), axis=0)
213 | for ii in range(rescale_motions.shape[0]):
214 | rescale_motions[ii][y1:y1+face_h, x1:x1+face_w] = out_frames[ii]
215 | ref_image = cv2.resize(ref_image, (512, 512))
216 | ref_lmk = lmk_extractor(ref_image[:, :, ::-1])
217 |
218 | ref_img = vis.draw_landmarks_v3((512, 512), (face_w, face_h), ref_lmk['lmks'].astype(np.float32), normed=True)
219 |
220 | first_motion = np.zeros_like(np.array(image))
221 | first_motion[y1:y1+face_h, x1:x1+face_w] = ref_img
222 | first_motion = first_motion[np.newaxis, :]
223 |
224 | input_video = rescale_motions[:max_frame_num]
225 |
226 | face_helper.clean_all()
227 | face_helper.read_image(np.array(image)[:, :, ::-1])
228 | face_helper.get_face_landmarks_5(only_center_face=True)
229 | face_helper.align_warp_face()
230 | align_face = face_helper.cropped_faces[0]
231 | image_face = align_face[:, :, ::-1]
232 |
233 | # [F, H, W, C]
234 | input_video = torch.from_numpy(np.array(input_video)).permute([3, 0, 1, 2]).unsqueeze(0) # [B, C, F, H, W]
235 | input_video = input_video / 255
236 | input_video_all = input_video
237 |
238 | first_motion = torch.from_numpy(np.array(first_motion)).permute([3, 0, 1, 2]).unsqueeze(0) # [B, C, 1, H, W]
239 | first_motion = first_motion / 255
240 |
241 | out_samples = []
242 | padding_frame_num = None
243 | latents_cache = []
244 |
245 | time_start = time.time()
246 |
247 | for i in range(0, video_length, frame_num_per_batch-1-overlap_frame_num):
248 | is_first_batch = (i == 0)
249 | is_last_batch = (i + frame_num_per_batch - 1 >= video_length)
250 |
251 | input_video = input_video_all[:, :, i:i+frame_num_per_batch-1]
252 |
253 | if input_video.shape[2] != frame_num_per_batch-1:
254 | padding_frame_num = frame_num_per_batch-1 - input_video.shape[2]
255 | print(f"padding_frame_num: {padding_frame_num}")
256 | input_video = torch.cat([input_video, torch.repeat_interleave(input_video[:, :, -1].unsqueeze(2), padding_frame_num, dim=2)], dim=2)
257 |
258 | input_video = torch.cat([first_motion, input_video], dim=2)
259 |
260 | with torch.no_grad():
261 | sample, latents_cache = pipe(
262 | image=image,
263 | image_face=image_face,
264 | control_video=input_video,
265 | prompt="",
266 | negative_prompt="",
267 | height=sample_size[0],
268 | width=sample_size[1],
269 | num_frames=frame_num_per_batch,
270 | generator=generator,
271 | guidance_scale=guidance_scale,
272 | num_inference_steps=num_inference_steps,
273 | is_last_batch=is_last_batch,
274 | overlap_frame_num=overlap_frame_num,
275 | fusion_interval=fusion_interval,
276 | latents_cache=latents_cache
277 | )
278 | if use_interpolation:
279 | if is_first_batch:
280 | out_samples = sample.frames[0][1:]
281 | else:
282 | out_samples = smooth_video_transition(out_samples, sample.frames[0][1+overlap_frame_num:], 2, frame_inter_model)
283 | else:
284 | out_sample = sample.frames[0][1:] if is_first_batch else sample.frames[0][1+overlap_frame_num:]
285 | out_samples.extend(out_sample)
286 | print(f"out_samples len: {len(out_samples)}")
287 |
288 | if is_last_batch:
289 | break
290 |
291 | if padding_frame_num is not None:
292 | out_samples = out_samples[:-padding_frame_num]
293 |
294 | print(f"output video length: {len(out_samples)}")
295 |
296 | time_end = time.time()
297 | print(f"time cost: {time_end - time_start} seconds")
298 |
299 | save_path_name = os.path.basename(args.image_path).split(".")[0] + "-" + os.path.basename(args.driving_video_path).split(".")[0]+ ".mp4"
300 |
301 | if not os.path.exists(save_path):
302 | os.makedirs(save_path, exist_ok=True)
303 | video_path = os.path.join(save_path, save_path_name.split(".")[0] + "_output.mp4")
304 |
305 | if target_fps != 12:
306 | out_samples = multi_fps_tool(out_samples, frame_inter_model, target_fps)
307 |
308 | export_to_video(out_samples, video_path, fps=target_fps)
309 | add_audio_to_video(video_path, args.driving_video_path, video_path.split(".")[0] + "_audio.mp4")
310 |
311 | if target_fps == 12:
312 | target_h, target_w = sample_size[0], sample_size[1]
313 | final_images = []
314 | for q in range(len(out_samples)):
315 | frame1 = image
316 | frame2 = crop_and_resize(Image.fromarray(np.array(control_frames[q])).convert("RGB"), target_h, target_w)
317 | frame3 = Image.fromarray(np.array(out_samples[q])).convert("RGB")
318 |
319 | result = Image.new('RGB', (target_w * 3, target_h))
320 | result.paste(frame1, (0, 0))
321 | result.paste(frame2, (target_w, 0))
322 | result.paste(frame3, (target_w * 2, 0))
323 | final_images.append(np.array(result))
324 |
325 | video_out_path = os.path.join(save_path, save_path_name.split(".")[0]+"_merge.mp4")
326 | write_mp4(video_out_path, final_images, fps=12)
327 |
328 | add_audio_to_video(video_out_path, args.driving_video_path, video_out_path.split(".")[0] + f"_audio.mp4")
329 |
330 |
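
parse_video above resamples the driving video from its native frame rate down to the model's native 12 fps by sampling one frame every 1/12 s. The index computation can be checked in isolation; the 25 fps, 2-second figures below are assumed purely for illustration:

import numpy as np

fps, video_length = 25.0, 50  # e.g. a 2-second driving clip at 25 fps
duration = video_length / fps
target_times = np.arange(0, duration, 1 / 12)
frame_indices = (target_times * fps).astype(np.int32)
frame_indices = frame_indices[frame_indices < video_length]
print(len(frame_indices))  # 24 sampled frames, i.e. ~12 fps over 2 seconds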
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | chumpy==0.70
2 | decord==0.6.0
3 | diffusers==0.32.2
4 | einops==0.8.1
5 | facexlib==0.3.0
6 | gradio==5.16.0
7 | insightface==0.7.3
8 | moviepy==1.0.3
9 | numpy==1.26.4
10 | opencv_contrib_python==4.10.0.84
11 | opencv_python==4.10.0.84
12 | opencv_python_headless==4.10.0.84
13 | Pillow==11.1.0
14 | pytorch3d==0.7.8
15 | safetensors==0.5.2
16 | scikit-image==0.24.0
17 | timm==0.6.13
18 | torch==2.2.2+cu118
19 | tqdm==4.66.2
20 | transformers==4.37.2
21 | mediapipe==0.10.21
22 | librosa==0.10.2.post1
23 | onnxruntime-gpu==1.16.3
24 | accelerate==1.6.0
25 |
26 | # if torch or pytorch3d installation fails, try:
27 | # pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu118
28 | # pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable"
29 |
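
A quick way to confirm that the pinned CUDA build of torch and the pytorch3d wheel imported correctly before downloading any checkpoints (an illustrative check, not part of the repository):

import torch
import pytorch3d

print(torch.__version__)          # expected: 2.2.2+cu118
print(torch.cuda.is_available())  # should be True on a CUDA 11.8 machine
print(pytorch3d.__version__)      # expected: 0.7.8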
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/scripts/__init__.py
--------------------------------------------------------------------------------
/scripts/demo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import numpy as np
4 | from PIL import Image
5 | import glob
6 | import insightface
7 | import cv2
8 | import subprocess
9 | import argparse
10 | from decord import VideoReader
11 | from moviepy.editor import ImageSequenceClip, AudioFileClip, VideoFileClip
12 | from facexlib.parsing import init_parsing_model
13 | from facexlib.utils.face_restoration_helper import FaceRestoreHelper
14 | from insightface.app import FaceAnalysis
15 |
16 | from diffusers.models import AutoencoderKLCogVideoX
17 | from diffusers.utils import export_to_video, load_image
18 | from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel
19 | from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
20 |
21 | from skyreels_a1.models.transformer3d import CogVideoXTransformer3DModel
22 | from skyreels_a1.skyreels_a1_i2v_pipeline import SkyReelsA1ImagePoseToVideoPipeline
23 | from skyreels_a1.pre_process_lmk3d import FaceAnimationProcessor
24 | from skyreels_a1.src.media_pipe.mp_utils import LMKExtractor
25 | from skyreels_a1.src.media_pipe.draw_util_2d import FaceMeshVisualizer2d
26 |
27 |
28 | def crop_and_resize(image, height, width):
29 | image = np.array(image)
30 | image_height, image_width, _ = image.shape
31 | if image_height / image_width < height / width:
32 | croped_width = int(image_height / height * width)
33 | left = (image_width - croped_width) // 2
34 | image = image[:, left: left+croped_width]
35 | image = Image.fromarray(image).resize((width, height))
36 | else:
37 | pad = int((((width / height) * image_height) - image_width) / 2.)
38 | padded_image = np.zeros((image_height, image_width + pad * 2, 3), dtype=np.uint8)
39 | padded_image[:, pad:pad+image_width] = image
40 | image = Image.fromarray(padded_image).resize((width, height))
41 | return image
42 |
43 | def write_mp4(video_path, samples, fps=14, audio_bitrate="192k"):
44 | clip = ImageSequenceClip(samples, fps=fps)
45 | clip.write_videofile(video_path, audio_codec="aac", audio_bitrate=audio_bitrate,
46 | ffmpeg_params=["-crf", "18", "-preset", "slow"])
47 |
48 | def init_model(
49 | model_name: str = "pretrained_models/SkyReels-A1-5B/",
50 | subfolder: str = "outputs/",
51 | siglip_path: str = "pretrained_models/SkyReels-A1-5B/siglip-so400m-patch14-384",
52 | weight_dtype=torch.bfloat16,
53 | ):
54 |
55 | lmk_extractor = LMKExtractor()
56 | vis = FaceMeshVisualizer2d(forehead_edge=False, draw_head=False, draw_iris=False,)
57 | processor = FaceAnimationProcessor(checkpoint='pretrained_models/smirk/SMIRK_em1.pt')
58 |
59 | face_helper = FaceRestoreHelper(
60 | upscale_factor=1,
61 | face_size=512,
62 | crop_ratio=(1, 1),
63 | det_model='retinaface_resnet50',
64 | save_ext='png',
65 | device="cuda",
66 | )
67 |
68 | siglip = SiglipVisionModel.from_pretrained(siglip_path)
69 | siglip_normalize = SiglipImageProcessor.from_pretrained(siglip_path)
70 |
71 | transformer = CogVideoXTransformer3DModel.from_pretrained(
72 | model_name,
73 | subfolder="transformer",
74 | ).to(weight_dtype)
75 |
76 | vae = AutoencoderKLCogVideoX.from_pretrained(
77 | model_name,
78 | subfolder="vae"
79 | ).to(weight_dtype)
80 |
81 | lmk_encoder = AutoencoderKLCogVideoX.from_pretrained(
82 | model_name,
83 | subfolder="pose_guider"
84 | ).to(weight_dtype)
85 |
86 | pipe = SkyReelsA1ImagePoseToVideoPipeline.from_pretrained(
87 | model_name,
88 | transformer = transformer,
89 | vae = vae,
90 | lmk_encoder = lmk_encoder,
91 | image_encoder = siglip,
92 | feature_extractor = siglip_normalize,
93 | torch_dtype=weight_dtype)
94 | pipe.to("cuda")
95 | pipe.enable_model_cpu_offload()
96 | pipe.vae.enable_tiling()
97 |
98 | return pipe, face_helper, processor, lmk_extractor, vis
99 |
100 |
101 |
102 | def generate_video(
103 | pipe,
104 | face_helper,
105 | processor,
106 | lmk_extractor,
107 | vis,
108 | control_video_path: str = None,
109 | image_path: str = None,
110 | save_path: str = None,
111 | guidance_scale=3.0,
112 | seed=43,
113 | num_inference_steps=10,
114 | sample_size=[480, 720],
115 | max_frame_num=49,
116 | weight_dtype=torch.bfloat16,
117 | ):
118 |
119 | vr = VideoReader(control_video_path)
120 | fps = vr.get_avg_fps()
121 | video_length = len(vr)
122 |
123 | duration = video_length / fps
124 | target_times = np.arange(0, duration, 1/12)
125 | frame_indices = (target_times * fps).astype(np.int32)
126 |
127 | frame_indices = frame_indices[frame_indices < video_length]
128 | control_frames = vr.get_batch(frame_indices).asnumpy()[:(max_frame_num-1)]
129 |
130 | out_frames = len(control_frames) - 1
131 | if len(control_frames) < max_frame_num:
132 |         video_length_add = max_frame_num - len(control_frames)
133 |         control_frames = np.concatenate(([control_frames[0]]*2, control_frames[1:len(control_frames)-2], [control_frames[-1]] * video_length_add), axis=0)
134 |
135 | # driving video crop face
136 | driving_video_crop = []
137 | for control_frame in control_frames:
138 | frame, _, _ = processor.face_crop(control_frame)
139 | driving_video_crop.append(frame)
140 |
141 | image = load_image(image=image_path)
142 | image = crop_and_resize(image, sample_size[0], sample_size[1])
143 |
144 | with torch.no_grad():
145 | face_helper.clean_all()
146 | face_helper.read_image(np.array(image)[:, :, ::-1])
147 | face_helper.get_face_landmarks_5(only_center_face=True)
148 | face_helper.align_warp_face()
149 | if len(face_helper.cropped_faces) == 0:
150 | return
151 | align_face = face_helper.cropped_faces[0]
152 | image_face = align_face[:, :, ::-1]
153 |
154 | # ref image crop face
155 | ref_image, x1, y1 = processor.face_crop(np.array(image))
156 | face_h, face_w, _, = ref_image.shape
157 | source_image = ref_image
158 | driving_video = driving_video_crop
159 | out_frames = processor.preprocess_lmk3d(source_image, driving_video)
160 |
161 | rescale_motions = np.zeros_like(image)[np.newaxis, :].repeat(48, axis=0)
162 | for ii in range(rescale_motions.shape[0]):
163 | rescale_motions[ii][y1:y1+face_h, x1:x1+face_w] = out_frames[ii]
164 | ref_image = cv2.resize(ref_image, (512, 512))
165 | ref_lmk = lmk_extractor(ref_image[:, :, ::-1])
166 |
167 | ref_img = vis.draw_landmarks_v3((512, 512), (face_w, face_h), ref_lmk['lmks'].astype(np.float32), normed=True)
168 |
169 | first_motion = np.zeros_like(np.array(image))
170 | first_motion[y1:y1+face_h, x1:x1+face_w] = ref_img
171 | first_motion = first_motion[np.newaxis, :]
172 |
173 | motions = np.concatenate([first_motion, rescale_motions])
174 | input_video = motions[:max_frame_num]
175 |
176 | input_video = input_video[:max_frame_num]
177 | motions = np.array(input_video)
178 |
179 | # [F, H, W, C]
180 | input_video = torch.from_numpy(np.array(input_video)).permute([3, 0, 1, 2]).unsqueeze(0)
181 | input_video = input_video / 255
182 |
183 | out_samples = []
184 |
185 | generator = torch.Generator(device="cuda").manual_seed(seed)
186 | with torch.no_grad():
187 | sample = pipe(
188 | image=image,
189 | image_face=image_face,
190 | control_video = input_video,
191 | height = sample_size[0],
192 | width = sample_size[1],
193 | num_frames = 49,
194 | generator = generator,
195 | guidance_scale = guidance_scale,
196 | num_inference_steps = num_inference_steps,
197 | )
198 | out_samples.extend(sample.frames[0][2:])
199 |
200 | # export_to_video(out_samples, save_path, fps=12)
201 | control_frames = control_frames[1:]
202 | target_h, target_w = sample_size
203 | final_images = []
204 | for i in range(len(out_samples)):
205 | frame1 = image
206 | frame2 = crop_and_resize(Image.fromarray(np.array(control_frames[i])).convert("RGB"), target_h, target_w)
207 | frame3 = Image.fromarray(np.array(out_samples[i])).convert("RGB")
208 | result = Image.new('RGB', (target_w * 3, target_h))
209 | result.paste(frame1, (0, 0))
210 | result.paste(frame2, (target_w, 0))
211 | result.paste(frame3, (target_w * 2, 0))
212 | final_images.append(np.array(result))
213 |
214 | write_mp4(save_path, final_images, fps=12)
215 |
216 |
217 |
218 | if __name__ == "__main__":
219 | control_video_zip = glob.glob("assets/driving_video/*.mp4")
220 | image_path_zip = glob.glob("assets/ref_images/*.png")
221 |
222 | guidance_scale = 3.0
223 | seed = 43
224 | num_inference_steps = 10
225 | sample_size = [480, 720]
226 | max_frame_num = 49
227 | weight_dtype = torch.bfloat16
228 |
229 | save_path = "outputs"
230 |
231 | # init model
232 | pipe, face_helper, processor, lmk_extractor, vis = init_model()
233 |
234 | for i in range(len(control_video_zip)):
235 | for j in range(len(image_path_zip)):
236 | generate_video(
237 | pipe,
238 | face_helper,
239 | processor,
240 | lmk_extractor,
241 | vis,
242 | control_video_path=control_video_zip[i],
243 | image_path=image_path_zip[j],
244 | save_path=save_path,
245 | guidance_scale=guidance_scale,
246 | seed=seed,
247 | num_inference_steps=num_inference_steps,
248 | sample_size=sample_size,
249 | max_frame_num=max_frame_num,
250 | weight_dtype=weight_dtype,
251 | )
252 |
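
The __main__ block above sweeps every driving video against every reference image. A single pair can be run through the same helpers directly; a minimal sketch using the bundled assets, assuming the outputs/ directory already exists (note that generate_video passes save_path straight to write_mp4, so it should point at an .mp4 file path):

pipe, face_helper, processor, lmk_extractor, vis = init_model()
generate_video(
    pipe, face_helper, processor, lmk_extractor, vis,
    control_video_path="assets/driving_video/1.mp4",
    image_path="assets/ref_images/1.png",
    save_path="outputs/1-1_merge.mp4",
)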
--------------------------------------------------------------------------------
/skyreels_a1/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/skyreels_a1/__init__.py
--------------------------------------------------------------------------------
/skyreels_a1/ddim_solver.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | def append_dims(x, target_dims):
5 | """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
6 | dims_to_append = target_dims - x.ndim
7 | if dims_to_append < 0:
8 | raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
9 | return x[(...,) + (None,) * dims_to_append]
10 |
11 |
12 | # From LCMScheduler.get_scalings_for_boundary_condition_discrete
13 | def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
14 | scaled_timestep = timestep_scaling * timestep
15 | c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
16 | c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
17 | return c_skip, c_out
18 |
19 |
20 | def extract_into_tensor(a, t, x_shape):
21 | b, *_ = t.shape
22 | out = a.gather(-1, t)
23 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
24 |
25 |
26 |
27 | class DDIMSolver:
28 | def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
29 | # DDIM sampling parameters
30 | step_ratio = timesteps // ddim_timesteps
31 |
32 | self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
33 | self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
34 | self.ddim_alpha_cumprods_prev = np.asarray(
35 | [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
36 | )
37 | # convert to torch tensors
38 | self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
39 | self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
40 | self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)
41 |
42 | def to(self, device):
43 | self.ddim_timesteps = self.ddim_timesteps.to(device)
44 | self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
45 | self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
46 | return self
47 |
48 | def ddim_step(self, pred_x0, pred_noise, timestep_index):
49 | alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
50 | dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
51 | x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
52 | return x_prev
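
A minimal sketch of driving DDIMSolver, assuming the repository root is on PYTHONPATH and using a simple linear-beta schedule (the beta range and latent shape below are illustrative, not values taken from the repository):

import numpy as np
import torch
from skyreels_a1.ddim_solver import DDIMSolver

betas = np.linspace(1e-4, 2e-2, 1000)     # assumed beta schedule
alphas_cumprod = np.cumprod(1.0 - betas)  # cumulative product of (1 - beta_t)
solver = DDIMSolver(alphas_cumprod, timesteps=1000, ddim_timesteps=50)

pred_x0 = torch.randn(1, 16, 4, 60, 90)   # dummy predicted clean latents
pred_noise = torch.randn_like(pred_x0)    # dummy predicted noise
timestep_index = torch.tensor([10])       # index into the 50 DDIM steps
x_prev = solver.ddim_step(pred_x0, pred_noise, timestep_index)
print(x_prev.shape)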
--------------------------------------------------------------------------------
/skyreels_a1/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/skyreels_a1/models/__init__.py
--------------------------------------------------------------------------------
/skyreels_a1/pipeline_output.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | import torch
4 |
5 | from diffusers.utils import BaseOutput
6 |
7 |
8 | @dataclass
9 | class CogVideoXPipelineOutput(BaseOutput):
10 | r"""
11 | Output class for CogVideo pipelines.
12 |
13 | Args:
14 | frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
15 | List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
16 | denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
17 | `(batch_size, num_frames, channels, height, width)`.
18 | """
19 |
20 | frames: torch.Tensor
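
The SkyReels-A1 pipelines return their decoded frames wrapped in this dataclass, which is why the inference scripts above read sample.frames[0]. A trivial illustration with a made-up tensor shape:

import torch

out = CogVideoXPipelineOutput(frames=torch.zeros(1, 49, 3, 480, 720))
print(out.frames.shape)  # torch.Size([1, 49, 3, 480, 720])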
--------------------------------------------------------------------------------
/skyreels_a1/src/FLAME/FLAME.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4 | # holder of all proprietary rights on this computer program.
5 | # Using this computer program means that you agree to the terms
6 | # in the LICENSE file included with this software distribution.
7 | # Any use not explicitly granted by the LICENSE is prohibited.
8 | #
9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
11 | # for Intelligent Systems. All rights reserved.
12 | #
13 | # For comments or questions, please email us at deca@tue.mpg.de
14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de
15 |
16 | import torch
17 | import torch.nn as nn
18 | import numpy as np
19 | np.bool = np.bool_
20 | np.int = np.int_
21 | np.float = np.float_
22 | np.complex = np.complex_
23 | np.object = np.object_
24 | np.unicode = np.unicode_
25 | np.str = np.str_
26 | import pickle
27 | import torch.nn.functional as F
28 |
29 | from .lbs import lbs, batch_rodrigues, vertices2landmarks, rot_mat_to_euler
30 |
31 | def to_tensor(array, dtype=torch.float32):
32 |     # return the input (cast to dtype) if it is already a torch tensor, otherwise convert it
33 |     return array.to(dtype) if torch.is_tensor(array) else torch.tensor(array, dtype=dtype)
34 | def to_np(array, dtype=np.float32):
35 | if 'scipy.sparse' in str(type(array)):
36 | array = array.todense()
37 | return np.array(array, dtype=dtype)
38 |
39 | class Struct(object):
40 | def __init__(self, **kwargs):
41 | for key, val in kwargs.items():
42 | setattr(self, key, val)
43 |
44 | class FLAME(nn.Module):
45 | """
46 | borrowed from https://github.com/soubhiksanyal/FLAME_PyTorch/blob/master/FLAME.py
47 | Given flame parameters this class generates a differentiable FLAME function
48 |     which outputs a mesh and 2D/3D facial landmarks
49 | """
50 | def __init__(self, flame_model_path='pretrained_models/FLAME/generic_model.pkl',
51 | flame_lmk_embedding_path='pretrained_models/FLAME/landmark_embedding.npy', n_shape=300, n_exp=50):
52 | super(FLAME, self).__init__()
53 |
54 | with open(flame_model_path, 'rb') as f:
55 | ss = pickle.load(f, encoding='latin1')
56 | flame_model = Struct(**ss)
57 |
58 | self.n_shape = n_shape
59 | self.n_exp = n_exp
60 | self.dtype = torch.float32
61 | self.register_buffer('faces_tensor', to_tensor(to_np(flame_model.f, dtype=np.int64), dtype=torch.long))
62 | # The vertices of the template model
63 | print('Using generic FLAME model')
64 | self.register_buffer('v_template', to_tensor(to_np(flame_model.v_template), dtype=self.dtype))
65 |
66 | # The shape components and expression
67 | shapedirs = to_tensor(to_np(flame_model.shapedirs), dtype=self.dtype)
68 | shapedirs = torch.cat([shapedirs[:,:,:n_shape], shapedirs[:,:,300:300+n_exp]], 2)
69 | self.register_buffer('shapedirs', shapedirs)
70 | # The pose components
71 | num_pose_basis = flame_model.posedirs.shape[-1]
72 | posedirs = np.reshape(flame_model.posedirs, [-1, num_pose_basis]).T
73 | self.register_buffer('posedirs', to_tensor(to_np(posedirs), dtype=self.dtype))
74 | #
75 | self.register_buffer('J_regressor', to_tensor(to_np(flame_model.J_regressor), dtype=self.dtype))
76 | parents = to_tensor(to_np(flame_model.kintree_table[0])).long(); parents[0] = -1
77 | self.register_buffer('parents', parents)
78 | self.register_buffer('lbs_weights', to_tensor(to_np(flame_model.weights), dtype=self.dtype))
79 |
80 |
81 | self.register_buffer('l_eyelid', torch.from_numpy(np.load(f'pretrained_models/smirk/l_eyelid.npy')).to(self.dtype)[None])
82 | self.register_buffer('r_eyelid', torch.from_numpy(np.load(f'pretrained_models/smirk/r_eyelid.npy')).to(self.dtype)[None])
83 | # import pdb;pdb.set_trace()
84 |
85 | # Fixing Eyeball and neck rotation
86 | default_eyball_pose = torch.zeros([1, 6], dtype=self.dtype, requires_grad=False)
87 | self.register_parameter('eye_pose', nn.Parameter(default_eyball_pose,
88 | requires_grad=False))
89 | default_neck_pose = torch.zeros([1, 3], dtype=self.dtype, requires_grad=False)
90 | self.register_parameter('neck_pose', nn.Parameter(default_neck_pose,
91 | requires_grad=False))
92 |
93 | # Static and Dynamic Landmark embeddings for FLAME
94 | lmk_embeddings = np.load(flame_lmk_embedding_path, allow_pickle=True, encoding='latin1')
95 | lmk_embeddings = lmk_embeddings[()]
96 | self.register_buffer('lmk_faces_idx', torch.from_numpy(lmk_embeddings['static_lmk_faces_idx']).long())
97 | self.register_buffer('lmk_bary_coords', torch.from_numpy(lmk_embeddings['static_lmk_bary_coords']).to(self.dtype))
98 | self.register_buffer('dynamic_lmk_faces_idx', lmk_embeddings['dynamic_lmk_faces_idx'].long())
99 | self.register_buffer('dynamic_lmk_bary_coords', lmk_embeddings['dynamic_lmk_bary_coords'].to(self.dtype))
100 | self.register_buffer('full_lmk_faces_idx', torch.from_numpy(lmk_embeddings['full_lmk_faces_idx']).long())
101 | self.register_buffer('full_lmk_bary_coords', torch.from_numpy(lmk_embeddings['full_lmk_bary_coords']).to(self.dtype))
102 |
103 | neck_kin_chain = []; NECK_IDX=1
104 | curr_idx = torch.tensor(NECK_IDX, dtype=torch.long)
105 | while curr_idx != -1:
106 | neck_kin_chain.append(curr_idx)
107 | curr_idx = self.parents[curr_idx]
108 | self.register_buffer('neck_kin_chain', torch.stack(neck_kin_chain))
109 |
110 | lmk_embeddings_mp = np.load("pretrained_models/smirk/mediapipe_landmark_embedding.npz")
111 | self.register_buffer('mp_lmk_faces_idx', torch.from_numpy(lmk_embeddings_mp['lmk_face_idx'].astype('int32')).long())
112 | self.register_buffer('mp_lmk_bary_coords', torch.from_numpy(lmk_embeddings_mp['lmk_b_coords']).to(self.dtype))
113 |
114 | def _find_dynamic_lmk_idx_and_bcoords(self, pose, dynamic_lmk_faces_idx,
115 | dynamic_lmk_b_coords,
116 | neck_kin_chain, dtype=torch.float32):
117 | """
118 |         Selects the face contour depending on the relative position of the head
119 | Input:
120 | vertices: N X num_of_vertices X 3
121 | pose: N X full pose
122 | dynamic_lmk_faces_idx: The list of contour face indexes
123 | dynamic_lmk_b_coords: The list of contour barycentric weights
124 | neck_kin_chain: The tree to consider for the relative rotation
125 | dtype: Data type
126 | return:
127 | The contour face indexes and the corresponding barycentric weights
128 | """
129 |
130 | batch_size = pose.shape[0]
131 |
132 | aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1,
133 | neck_kin_chain)
134 | rot_mats = batch_rodrigues(
135 | aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3)
136 |
137 | rel_rot_mat = torch.eye(3, device=pose.device,
138 | dtype=dtype).unsqueeze_(dim=0).expand(batch_size, -1, -1)
139 | for idx in range(len(neck_kin_chain)):
140 | rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat)
141 |
142 | y_rot_angle = torch.round(
143 | torch.clamp(rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi,
144 | max=39)).to(dtype=torch.long)
145 |
146 | neg_mask = y_rot_angle.lt(0).to(dtype=torch.long)
147 | mask = y_rot_angle.lt(-39).to(dtype=torch.long)
148 | neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle)
149 | y_rot_angle = (neg_mask * neg_vals +
150 | (1 - neg_mask) * y_rot_angle)
151 |
152 | dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx,
153 | 0, y_rot_angle)
154 | dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords,
155 | 0, y_rot_angle)
156 | return dyn_lmk_faces_idx, dyn_lmk_b_coords
157 |
158 | def _vertices2landmarks(self, vertices, faces, lmk_faces_idx, lmk_bary_coords):
159 | """
160 | Calculates landmarks by barycentric interpolation
161 | Input:
162 | vertices: torch.tensor NxVx3, dtype = torch.float32
163 | The tensor of input vertices
164 | faces: torch.tensor (N*F)x3, dtype = torch.long
165 | The faces of the mesh
166 | lmk_faces_idx: torch.tensor N X L, dtype = torch.long
167 | The tensor with the indices of the faces used to calculate the
168 | landmarks.
169 | lmk_bary_coords: torch.tensor N X L X 3, dtype = torch.float32
170 | The tensor of barycentric coordinates that are used to interpolate
171 | the landmarks
172 |
173 | Returns:
174 | landmarks: torch.tensor NxLx3, dtype = torch.float32
175 | The coordinates of the landmarks for each mesh in the batch
176 | """
177 | # Extract the indices of the vertices for each face
178 | # NxLx3
179 |         batch_size, num_verts = vertices.shape[:2]
180 | lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view(
181 | 1, -1, 3).view(batch_size, lmk_faces_idx.shape[1], -1)
182 |
183 | lmk_faces += torch.arange(batch_size, dtype=torch.long).view(-1, 1, 1).to(
184 | device=vertices.device) * num_verts
185 |
186 | lmk_vertices = vertices.view(-1, 3)[lmk_faces]
187 | landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords])
188 | return landmarks
189 |
190 | def seletec_3d68(self, vertices):
191 | landmarks3d = vertices2landmarks(vertices, self.faces_tensor,
192 | self.full_lmk_faces_idx.repeat(vertices.shape[0], 1),
193 | self.full_lmk_bary_coords.repeat(vertices.shape[0], 1, 1))
194 | return landmarks3d
195 |
196 | def get_landmarks(self, vertices):
197 |         """
198 |         Input:
199 |             vertices: N X V X 3
200 |                 FLAME mesh vertices to compute landmarks for
201 |         return:
202 |             vertices: N X V X 3
203 |             landmarks2d: N X number of 2D landmarks X 3
204 |             landmarks3d: N X number of 3D landmarks X 3
205 |         """
206 | batch_size = vertices.shape[0]
207 | template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1)
208 |
209 | lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1)
210 | lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand(batch_size, -1, -1)
211 |
212 | dyn_lmk_faces_idx, dyn_lmk_bary_coords = self._find_dynamic_lmk_idx_and_bcoords(
213 | full_pose, self.dynamic_lmk_faces_idx,
214 | self.dynamic_lmk_bary_coords,
215 | self.neck_kin_chain, dtype=self.dtype)
216 | lmk_faces_idx = torch.cat([dyn_lmk_faces_idx, lmk_faces_idx], 1)
217 | lmk_bary_coords = torch.cat([dyn_lmk_bary_coords, lmk_bary_coords], 1)
218 |
219 | landmarks2d = vertices2landmarks(vertices, self.faces_tensor,
220 | lmk_faces_idx,
221 | lmk_bary_coords)
222 | bz = vertices.shape[0]
223 | landmarks3d = vertices2landmarks(vertices, self.faces_tensor,
224 | self.full_lmk_faces_idx.repeat(bz, 1),
225 | self.full_lmk_bary_coords.repeat(bz, 1, 1))
226 | return vertices, landmarks2d, landmarks3d
227 |
228 |
229 | def forward(self, param_dictionary, zero_expression=False, zero_shape=False, zero_pose=False):
230 | shape_params = param_dictionary['shape_params']
231 | expression_params = param_dictionary['expression_params']
232 | pose_params = param_dictionary.get('pose_params', None)
233 | jaw_params = param_dictionary.get('jaw_params', None)
234 | eye_pose_params = param_dictionary.get('eye_pose_params', None)
235 | neck_pose_params = param_dictionary.get('neck_pose_params', None)
236 | eyelid_params = param_dictionary.get('eyelid_params', None)
237 |
238 | batch_size = shape_params.shape[0]
239 |
240 | # Adjust expression params size if needed
241 | if expression_params.shape[1] < self.n_exp:
242 | expression_params = torch.cat([expression_params, torch.zeros(expression_params.shape[0], self.n_exp - expression_params.shape[1]).to(shape_params.device)], dim=1)
243 |
244 | if shape_params.shape[1] < self.n_shape:
245 | shape_params = torch.cat([shape_params, torch.zeros(shape_params.shape[0], self.n_shape - shape_params.shape[1]).to(shape_params.device)], dim=1)
246 |
247 | # Zero out the expression and pose parameters if needed
248 | if zero_expression:
249 | expression_params = torch.zeros_like(expression_params).to(shape_params.device)
250 | jaw_params = torch.zeros_like(jaw_params).to(shape_params.device)
251 |
252 | if zero_shape:
253 | shape_params = torch.zeros_like(shape_params).to(shape_params.device)
254 |
255 |
256 | if zero_pose:
257 | pose_params = torch.zeros_like(pose_params).to(shape_params.device)
258 | pose_params[...,0] = 0.2
259 | pose_params[...,1] = -0.7
260 |
261 | if pose_params is None:
262 | pose_params = self.pose_params.expand(batch_size, -1)
263 |
264 | if eye_pose_params is None:
265 | eye_pose_params = self.eye_pose.expand(batch_size, -1)
266 |
267 | if neck_pose_params is None:
268 | neck_pose_params = self.neck_pose.expand(batch_size, -1)
269 |
270 |
271 | betas = torch.cat([shape_params, expression_params], dim=1)
272 | full_pose = torch.cat([pose_params, neck_pose_params, jaw_params, eye_pose_params], dim=1)
273 | # import pdb;pdb.set_trace()
274 | template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1)
275 |
276 | vertices, _ = lbs(betas, full_pose, template_vertices,
277 | self.shapedirs, self.posedirs,
278 | self.J_regressor, self.parents,
279 | self.lbs_weights, dtype=self.dtype)
280 | # import pdb;pdb.set_trace()
281 | if eyelid_params is not None:
282 | vertices = vertices + self.r_eyelid.expand(batch_size, -1, -1) * eyelid_params[:, 1:2, None]
283 | vertices = vertices + self.l_eyelid.expand(batch_size, -1, -1) * eyelid_params[:, 0:1, None]
284 |
285 | lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1)
286 | lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand(batch_size, -1, -1)
287 |
288 | dyn_lmk_faces_idx, dyn_lmk_bary_coords = self._find_dynamic_lmk_idx_and_bcoords(
289 | full_pose, self.dynamic_lmk_faces_idx,
290 | self.dynamic_lmk_bary_coords,
291 | self.neck_kin_chain, dtype=self.dtype)
292 | lmk_faces_idx = torch.cat([dyn_lmk_faces_idx, lmk_faces_idx], 1)
293 | lmk_bary_coords = torch.cat([dyn_lmk_bary_coords, lmk_bary_coords], 1)
294 |
295 | landmarks2d = vertices2landmarks(vertices, self.faces_tensor,
296 | lmk_faces_idx,
297 | lmk_bary_coords)
298 | bz = vertices.shape[0]
299 | landmarks3d = vertices2landmarks(vertices, self.faces_tensor,
300 | self.full_lmk_faces_idx.repeat(bz, 1),
301 | self.full_lmk_bary_coords.repeat(bz, 1, 1))
302 |
303 | landmarksmp = vertices2landmarks(vertices, self.faces_tensor,
304 | self.mp_lmk_faces_idx.repeat(vertices.shape[0], 1),
305 | self.mp_lmk_bary_coords.repeat(vertices.shape[0], 1, 1))
306 |
307 | return {
308 | 'vertices': vertices,
309 | 'landmarks_fan': landmarks2d,
310 | 'landmarks_fan_3d': landmarks3d,
311 | 'landmarks_mp': landmarksmp
312 | }
313 |
314 |
315 |
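
A minimal sketch of calling the forward pass above with neutral parameters; the parameter names and sizes follow how forward() concatenates them (global rotation, neck, jaw and eyes), and it assumes the FLAME/SMIRK assets are present under pretrained_models/:

import torch
from skyreels_a1.src.FLAME.FLAME import FLAME

flame = FLAME()  # loads pretrained_models/FLAME/*.pkl and the SMIRK eyelid/landmark files
params = {
    "shape_params": torch.zeros(1, 300),
    "expression_params": torch.zeros(1, 50),
    "pose_params": torch.zeros(1, 3),   # global rotation, axis-angle
    "jaw_params": torch.zeros(1, 3),
    "eyelid_params": torch.zeros(1, 2),
}
out = flame(params)
print(out["vertices"].shape, out["landmarks_mp"].shape)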
--------------------------------------------------------------------------------
/skyreels_a1/src/FLAME/lbs.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4 | # holder of all proprietary rights on this computer program.
5 | # You can only use this computer program if you have closed
6 | # a license agreement with MPG or you get the right to use the computer
7 | # program from someone who is authorized to grant you that right.
8 | # Any use of the computer program without a valid license is prohibited and
9 | # liable to prosecution.
10 | #
11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13 | # for Intelligent Systems. All rights reserved.
14 | #
15 | # Contact: ps-license@tuebingen.mpg.de
16 |
17 | from __future__ import absolute_import
18 | from __future__ import print_function
19 | from __future__ import division
20 |
21 | import numpy as np
22 |
23 | import torch
24 | import torch.nn.functional as F
25 |
26 | def rot_mat_to_euler(rot_mats):
27 |     # Converts a rotation matrix to Euler angles
28 |     # Careful for extreme cases of Euler angles like [0.0, pi, 0.0]
29 |
30 | sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] +
31 | rot_mats[:, 1, 0] * rot_mats[:, 1, 0])
32 | return torch.atan2(-rot_mats[:, 2, 0], sy)
33 |
34 | def find_dynamic_lmk_idx_and_bcoords(vertices, pose, dynamic_lmk_faces_idx,
35 | dynamic_lmk_b_coords,
36 | neck_kin_chain, dtype=torch.float32):
37 | ''' Compute the faces, barycentric coordinates for the dynamic landmarks
38 |
39 |
40 | To do so, we first compute the rotation of the neck around the y-axis
41 | and then use a pre-computed look-up table to find the faces and the
42 | barycentric coordinates that will be used.
43 |
44 | Special thanks to Soubhik Sanyal (soubhik.sanyal@tuebingen.mpg.de)
45 | for providing the original TensorFlow implementation and for the LUT.
46 |
47 | Parameters
48 | ----------
49 | vertices: torch.tensor BxVx3, dtype = torch.float32
50 | The tensor of input vertices
51 | pose: torch.tensor Bx(Jx3), dtype = torch.float32
52 | The current pose of the body model
53 | dynamic_lmk_faces_idx: torch.tensor L, dtype = torch.long
54 | The look-up table from neck rotation to faces
55 | dynamic_lmk_b_coords: torch.tensor Lx3, dtype = torch.float32
56 | The look-up table from neck rotation to barycentric coordinates
57 | neck_kin_chain: list
58 | A python list that contains the indices of the joints that form the
59 | kinematic chain of the neck.
60 | dtype: torch.dtype, optional
61 |
62 | Returns
63 | -------
64 | dyn_lmk_faces_idx: torch.tensor, dtype = torch.long
65 | A tensor of size BxL that contains the indices of the faces that
66 | will be used to compute the current dynamic landmarks.
67 | dyn_lmk_b_coords: torch.tensor, dtype = torch.float32
68 | A tensor of size BxL that contains the indices of the faces that
69 | will be used to compute the current dynamic landmarks.
70 | '''
71 |
72 | batch_size = vertices.shape[0]
73 |
74 | aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1,
75 | neck_kin_chain)
76 | rot_mats = batch_rodrigues(
77 | aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3)
78 |
79 |     rel_rot_mat = torch.eye(3, device=vertices.device,
80 |                             dtype=dtype).unsqueeze_(dim=0).expand(batch_size, -1, -1)
81 | for idx in range(len(neck_kin_chain)):
82 | rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat)
83 |
84 | y_rot_angle = torch.round(
85 | torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi,
86 | max=39)).to(dtype=torch.long)
87 | neg_mask = y_rot_angle.lt(0).to(dtype=torch.long)
88 | mask = y_rot_angle.lt(-39).to(dtype=torch.long)
89 | neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle)
90 | y_rot_angle = (neg_mask * neg_vals +
91 | (1 - neg_mask) * y_rot_angle)
92 |
93 | dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx,
94 | 0, y_rot_angle)
95 | dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords,
96 | 0, y_rot_angle)
97 |
98 | return dyn_lmk_faces_idx, dyn_lmk_b_coords
99 |
100 |
101 | def vertices2landmarks(vertices, faces, lmk_faces_idx, lmk_bary_coords):
102 | ''' Calculates landmarks by barycentric interpolation
103 |
104 | Parameters
105 | ----------
106 | vertices: torch.tensor BxVx3, dtype = torch.float32
107 | The tensor of input vertices
108 | faces: torch.tensor Fx3, dtype = torch.long
109 | The faces of the mesh
110 | lmk_faces_idx: torch.tensor L, dtype = torch.long
111 | The tensor with the indices of the faces used to calculate the
112 | landmarks.
113 | lmk_bary_coords: torch.tensor Lx3, dtype = torch.float32
114 | The tensor of barycentric coordinates that are used to interpolate
115 | the landmarks
116 |
117 | Returns
118 | -------
119 | landmarks: torch.tensor BxLx3, dtype = torch.float32
120 | The coordinates of the landmarks for each mesh in the batch
121 | '''
122 | # Extract the indices of the vertices for each face
123 | # BxLx3
124 | batch_size, num_verts = vertices.shape[:2]
125 | device = vertices.device
126 |
127 | lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view(
128 | batch_size, -1, 3)
129 |
130 | lmk_faces += torch.arange(
131 | batch_size, dtype=torch.long, device=device).view(-1, 1, 1) * num_verts
132 |
133 | lmk_vertices = vertices.view(-1, 3)[lmk_faces].view(
134 | batch_size, -1, 3, 3)
135 |
136 | landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords])
137 | return landmarks
138 |
139 |
140 | def lbs(betas, pose, v_template, shapedirs, posedirs, J_regressor, parents,
141 | lbs_weights, pose2rot=True, dtype=torch.float32):
142 | ''' Performs Linear Blend Skinning with the given shape and pose parameters
143 |
144 | Parameters
145 | ----------
146 | betas : torch.tensor BxNB
147 | The tensor of shape parameters
148 | pose : torch.tensor Bx(J + 1) * 3
149 | The pose parameters in axis-angle format
150 | v_template : torch.tensor BxVx3
151 | The template mesh that will be deformed
152 | shapedirs : torch.tensor Vx3xNB
153 | The tensor of PCA shape displacements
154 | posedirs : torch.tensor Px(V * 3)
155 | The pose blend shape basis
156 | J_regressor : torch.tensor JxV
157 | The regressor array that is used to calculate the joints from
158 | the position of the vertices
159 | parents: torch.tensor J
160 | The array that describes the kinematic tree for the model
161 | lbs_weights: torch.tensor V x (J + 1)
162 | The linear blend skinning weights that represent how much the
163 | rotation matrix of each part affects each vertex
164 | pose2rot: bool, optional
165 | Flag on whether to convert the input pose tensor to rotation
166 | matrices. The default value is True. If False, then the pose tensor
167 | should already contain rotation matrices and have a size of
168 | Bx(J + 1)x9
169 | dtype: torch.dtype, optional
170 |
171 | Returns
172 | -------
173 | verts: torch.tensor BxVx3
174 | The vertices of the mesh after applying the shape and pose
175 | displacements.
176 | joints: torch.tensor BxJx3
177 | The joints of the model
178 | '''
179 |
180 | batch_size = max(betas.shape[0], pose.shape[0])
181 | device = betas.device
182 |
183 | # Add shape contribution
184 | v_shaped = v_template + blend_shapes(betas, shapedirs)
185 |
186 | # Get the joints
187 | # NxJx3 array
188 | J = vertices2joints(J_regressor, v_shaped)
189 |
190 | # 3. Add pose blend shapes
191 | # N x J x 3 x 3
192 | ident = torch.eye(3, dtype=dtype, device=device)
193 | if pose2rot:
194 | rot_mats = batch_rodrigues(
195 | pose.view(-1, 3), dtype=dtype).view([batch_size, -1, 3, 3])
196 |
197 | pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1])
198 | # (N x P) x (P, V * 3) -> N x V x 3
199 | pose_offsets = torch.matmul(pose_feature, posedirs) \
200 | .view(batch_size, -1, 3)
201 | else:
202 | pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident
203 | rot_mats = pose.view(batch_size, -1, 3, 3)
204 |
205 | pose_offsets = torch.matmul(pose_feature.view(batch_size, -1),
206 | posedirs).view(batch_size, -1, 3)
207 |
208 | v_posed = pose_offsets + v_shaped
209 | # 4. Get the global joint location
210 | J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype)
211 |
212 | # 5. Do skinning:
213 | # W is N x V x (J + 1)
214 | W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1])
215 | # (N x V x (J + 1)) x (N x (J + 1) x 16)
216 | num_joints = J_regressor.shape[0]
217 | T = torch.matmul(W, A.view(batch_size, num_joints, 16)) \
218 | .view(batch_size, -1, 4, 4)
219 |
220 | homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1],
221 | dtype=dtype, device=device)
222 | v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2)
223 | v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1))
224 |
225 | verts = v_homo[:, :, :3, 0]
226 |
227 | return verts, J_transformed
228 |
229 |
230 | def vertices2joints(J_regressor, vertices):
231 | ''' Calculates the 3D joint locations from the vertices
232 |
233 | Parameters
234 | ----------
235 | J_regressor : torch.tensor JxV
236 | The regressor array that is used to calculate the joints from the
237 | position of the vertices
238 | vertices : torch.tensor BxVx3
239 | The tensor of mesh vertices
240 |
241 | Returns
242 | -------
243 | torch.tensor BxJx3
244 | The location of the joints
245 | '''
246 |
247 | return torch.einsum('bik,ji->bjk', [vertices, J_regressor])
248 |
249 |
250 | def blend_shapes(betas, shape_disps):
251 | ''' Calculates the per vertex displacement due to the blend shapes
252 |
253 |
254 | Parameters
255 | ----------
256 | betas : torch.tensor Bx(num_betas)
257 | Blend shape coefficients
258 | shape_disps: torch.tensor Vx3x(num_betas)
259 | Blend shapes
260 |
261 | Returns
262 | -------
263 | torch.tensor BxVx3
264 | The per-vertex displacement due to shape deformation
265 | '''
266 |
267 | # Displacement[b, m, k] = sum_{l} betas[b, l] * shape_disps[m, k, l]
268 | # i.e. Multiply each shape displacement by its corresponding beta and
269 | # then sum them.
270 | blend_shape = torch.einsum('bl,mkl->bmk', [betas, shape_disps])
271 | return blend_shape
272 |
273 |
274 | def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32):
275 | ''' Calculates the rotation matrices for a batch of rotation vectors
276 | Parameters
277 | ----------
278 | rot_vecs: torch.tensor Nx3
279 | array of N axis-angle vectors
280 | Returns
281 | -------
282 | R: torch.tensor Nx3x3
283 | The rotation matrices for the given axis-angle parameters
284 | '''
285 |
286 | batch_size = rot_vecs.shape[0]
287 | device = rot_vecs.device
288 |
289 | angle = torch.norm(rot_vecs + epsilon, dim=1, keepdim=True)
290 | rot_dir = rot_vecs / angle
291 |
292 | cos = torch.unsqueeze(torch.cos(angle), dim=1)
293 | sin = torch.unsqueeze(torch.sin(angle), dim=1)
294 |
295 | # Bx1 arrays
296 | rx, ry, rz = torch.split(rot_dir, 1, dim=1)
297 | K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device)
298 |
299 | zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
300 | K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \
301 | .view((batch_size, 3, 3))
302 |
303 | ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
304 | rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
305 | return rot_mat
306 |
307 |
308 | def transform_mat(R, t):
309 | ''' Creates a batch of transformation matrices
310 | Args:
311 | - R: Bx3x3 array of a batch of rotation matrices
312 | - t: Bx3x1 array of a batch of translation vectors
313 | Returns:
314 | - T: Bx4x4 Transformation matrix
315 | '''
316 | # No padding left or right, only add an extra row
317 | return torch.cat([F.pad(R, [0, 0, 0, 1]),
318 | F.pad(t, [0, 0, 0, 1], value=1)], dim=2)
319 |
320 |
321 | def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32):
322 | """
323 | Applies a batch of rigid transformations to the joints
324 |
325 | Parameters
326 | ----------
327 | rot_mats : torch.tensor BxNx3x3
328 | Tensor of rotation matrices
329 | joints : torch.tensor BxNx3
330 | Locations of joints
331 | parents : torch.tensor BxN
332 | The kinematic tree of each object
333 | dtype : torch.dtype, optional:
334 | The data type of the created tensors, the default is torch.float32
335 |
336 | Returns
337 | -------
338 | posed_joints : torch.tensor BxNx3
339 | The locations of the joints after applying the pose rotations
340 | rel_transforms : torch.tensor BxNx4x4
341 | The relative (with respect to the root joint) rigid transformations
342 | for all the joints
343 | """
344 |
345 | joints = torch.unsqueeze(joints, dim=-1)
346 |
347 | rel_joints = joints.clone()
348 | rel_joints[:, 1:] -= joints[:, parents[1:]]
349 |
350 | # transforms_mat = transform_mat(
351 | # rot_mats.view(-1, 3, 3),
352 | # rel_joints.view(-1, 3, 1)).view(-1, joints.shape[1], 4, 4)
353 | transforms_mat = transform_mat(
354 | rot_mats.view(-1, 3, 3),
355 | rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4)
356 |
357 | transform_chain = [transforms_mat[:, 0]]
358 | for i in range(1, parents.shape[0]):
359 | # Subtract the joint location at the rest pose
360 | # No need for rotation, since it's identity when at rest
361 | curr_res = torch.matmul(transform_chain[parents[i]],
362 | transforms_mat[:, i])
363 | transform_chain.append(curr_res)
364 |
365 | transforms = torch.stack(transform_chain, dim=1)
366 |
367 | # The last column of the transformations contains the posed joints
368 | posed_joints = transforms[:, :, :3, 3]
369 |
370 | joints_homogen = F.pad(joints, [0, 0, 0, 1])
371 |
372 | rel_transforms = transforms - F.pad(
373 | torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0])
374 |
375 | return posed_joints, rel_transforms
--------------------------------------------------------------------------------
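A minimal usage sketch for the lbs() routine above (not part of the repository). The batch size, joint count, zero-valued inputs, and the import path `skyreels_a1.src.FLAME.lbs` are illustrative assumptions chosen to match FLAME-like tensor shapes.

    import torch
    from skyreels_a1.src.FLAME.lbs import lbs  # assumed import path for this sketch

    # Illustrative FLAME-like sizes: batch, vertices, joints (incl. root), shape coefficients.
    B, V, J, NB = 2, 5023, 5, 100

    betas       = torch.zeros(B, NB)
    pose        = torch.zeros(B, J * 3)              # axis-angle, one 3-vector per joint
    v_template  = torch.zeros(B, V, 3)
    shapedirs   = torch.zeros(V, 3, NB)
    posedirs    = torch.zeros((J - 1) * 9, V * 3)    # pose blend shape basis
    J_regressor = torch.zeros(J, V)
    parents     = torch.tensor([-1, 0, 1, 1, 1])     # kinematic tree, root first
    lbs_weights = torch.ones(V, J) / J

    verts, joints = lbs(betas, pose, v_template, shapedirs, posedirs,
                        J_regressor, parents, lbs_weights, pose2rot=True)
    print(verts.shape, joints.shape)                 # (B, V, 3), (B, J, 3)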
/skyreels_a1/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/skyreels_a1/src/__init__.py
--------------------------------------------------------------------------------
/skyreels_a1/src/frame_interpolation.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/dajes/frame-interpolation-pytorch
2 | import os
3 | import cv2
4 | import numpy as np
5 | import torch
6 | import bisect
7 | import shutil
8 | import pdb
9 | from tqdm import tqdm
10 |
11 | def init_frame_interpolation_model(checkpoint_name, device="cuda"):
12 | print(f"Initializing frame interpolation model from {checkpoint_name}")
13 |
14 | model = torch.jit.load(checkpoint_name, map_location='cpu')
15 | model.eval()
16 | model = model.half()
17 | model = model.to(device=device)
18 | return model
19 |
20 |
21 | def batch_images_interpolation_tool(input_tensor, model, inter_frames=1):
22 |
23 | video_tensor = []
24 | frame_num = input_tensor.shape[2] # bs, channel, frame, height, width
25 |
26 | for idx in tqdm(range(frame_num-1)):
27 | image1 = input_tensor[:,:,idx]
28 | image2 = input_tensor[:,:,idx+1]
29 |
30 | results = [image1, image2]
31 |
32 | inter_frames = int(inter_frames)
33 | idxes = [0, inter_frames + 1]
34 | remains = list(range(1, inter_frames + 1))
35 |
36 | splits = torch.linspace(0, 1, inter_frames + 2)
37 |
38 | for _ in range(len(remains)):
39 | starts = splits[idxes[:-1]]
40 | ends = splits[idxes[1:]]
41 | distances = ((splits[None, remains] - starts[:, None]) / (ends[:, None] - starts[:, None]) - .5).abs()
42 | matrix = torch.argmin(distances).item()
43 | start_i, step = np.unravel_index(matrix, distances.shape)
44 | end_i = start_i + 1
45 |
46 | x0 = results[start_i]
47 | x1 = results[end_i]
48 |
49 | x0 = x0.half()
50 | x1 = x1.half()
51 | x0 = x0.cuda()
52 | x1 = x1.cuda()
53 |
54 | dt = x0.new_full((1, 1), (splits[remains[step]] - splits[idxes[start_i]])) / (splits[idxes[end_i]] - splits[idxes[start_i]])
55 |
56 | with torch.no_grad():
57 | prediction = model(x0, x1, dt)
58 | insert_position = bisect.bisect_left(idxes, remains[step])
59 | idxes.insert(insert_position, remains[step])
60 | results.insert(insert_position, prediction.clamp(0, 1).cpu().float())
61 | del remains[step]
62 |
63 | for sub_idx in range(len(results)-1):
64 | video_tensor.append(results[sub_idx].unsqueeze(2))
65 |
66 | video_tensor.append(input_tensor[:,:,-1].unsqueeze(2))
67 | video_tensor = torch.cat(video_tensor, dim=2)
68 | return video_tensor
--------------------------------------------------------------------------------
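A hedged usage sketch for the two helpers above. It assumes a locally available TorchScript FILM checkpoint (the path below is a placeholder, not a file shipped with the repository) and a CUDA device, since the interpolation runs in half precision on the GPU.

    import torch
    from skyreels_a1.src.frame_interpolation import (
        init_frame_interpolation_model, batch_images_interpolation_tool)

    # Placeholder checkpoint path; point this at your local TorchScript FILM model.
    model = init_frame_interpolation_model("pretrained_models/film_net_fp16.pt")

    # Dummy clip: batch=1, 3 channels, 4 frames, 256x256, values in [0, 1].
    clip = torch.rand(1, 3, 4, 256, 256)

    # Insert one synthesized frame between each pair of consecutive frames (4 -> 7).
    out = batch_images_interpolation_tool(clip, model, inter_frames=1)
    print(out.shape)  # (1, 3, 7, 256, 256)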
/skyreels_a1/src/media_pipe/draw_util_2d.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import mediapipe as mp
3 | import numpy as np
4 | from mediapipe.framework.formats import landmark_pb2
5 |
6 | class FaceMeshVisualizer2d:
7 | # def __init__(self,
8 | # forehead_edge=False,
9 | # upface_only=False,
10 | # draw_eye=True,
11 | # draw_head=False,
12 | # draw_iris=True,
13 | # draw_eyebrow=True,
14 | # draw_mouse=True,
15 | # draw_nose=True,
16 | # draw_pupil=True
17 | # ):
18 | def __init__(self,
19 | forehead_edge=True,
20 | upface_only=True,
21 | draw_eye=True,
22 | draw_head=True,
23 | draw_iris=True,
24 | draw_eyebrow=True,
25 | draw_mouse=True,
26 | draw_nose=True,
27 | draw_pupil=True
28 | ):
29 | self.mp_drawing = mp.solutions.drawing_utils
30 | mp_face_mesh = mp.solutions.face_mesh
31 | self.mp_face_mesh = mp_face_mesh
32 | self.forehead_edge = forehead_edge
33 |
34 | DrawingSpec = mp.solutions.drawing_styles.DrawingSpec
35 | f_thick = 1
36 | f_rad = 1
37 | right_iris_draw = DrawingSpec(color=(10, 200, 250), thickness=f_thick, circle_radius=f_rad)
38 | right_eye_draw = DrawingSpec(color=(10, 200, 180), thickness=f_thick, circle_radius=f_rad)
39 | right_eyebrow_draw = DrawingSpec(color=(10, 220, 180), thickness=f_thick, circle_radius=f_rad)
40 | left_iris_draw = DrawingSpec(color=(250, 200, 10), thickness=f_thick, circle_radius=f_rad)
41 | left_eye_draw = DrawingSpec(color=(180, 200, 10), thickness=f_thick, circle_radius=f_rad)
42 | left_eyebrow_draw = DrawingSpec(color=(180, 220, 10), thickness=f_thick, circle_radius=f_rad)
43 | head_draw = DrawingSpec(color=(10, 200, 10), thickness=f_thick, circle_radius=f_rad)
44 | nose_draw = DrawingSpec(color=(200, 200, 200), thickness=f_thick, circle_radius=f_rad)
45 |
46 | mouth_draw_obl = DrawingSpec(color=(10, 180, 20), thickness=f_thick, circle_radius=f_rad)
47 | mouth_draw_obr = DrawingSpec(color=(20, 10, 180), thickness=f_thick, circle_radius=f_rad)
48 |
49 | mouth_draw_ibl = DrawingSpec(color=(100, 100, 30), thickness=f_thick, circle_radius=f_rad)
50 | mouth_draw_ibr = DrawingSpec(color=(100, 150, 50), thickness=f_thick, circle_radius=f_rad)
51 |
52 | mouth_draw_otl = DrawingSpec(color=(20, 80, 100), thickness=f_thick, circle_radius=f_rad)
53 | mouth_draw_otr = DrawingSpec(color=(80, 100, 20), thickness=f_thick, circle_radius=f_rad)
54 |
55 | mouth_draw_itl = DrawingSpec(color=(120, 100, 200), thickness=f_thick, circle_radius=f_rad)
56 | mouth_draw_itr = DrawingSpec(color=(150, 120, 100), thickness=f_thick, circle_radius=f_rad)
57 |
58 | FACEMESH_LIPS_OUTER_BOTTOM_LEFT = [(61,146),(146,91),(91,181),(181,84),(84,17)]
59 | FACEMESH_LIPS_OUTER_BOTTOM_RIGHT = [(17,314),(314,405),(405,321),(321,375),(375,291)]
60 |
61 | FACEMESH_LIPS_INNER_BOTTOM_LEFT = [(78,95),(95,88),(88,178),(178,87),(87,14)]
62 | FACEMESH_LIPS_INNER_BOTTOM_RIGHT = [(14,317),(317,402),(402,318),(318,324),(324,308)]
63 |
64 | FACEMESH_LIPS_OUTER_TOP_LEFT = [(61,185),(185,40),(40,39),(39,37),(37,0)]
65 | FACEMESH_LIPS_OUTER_TOP_RIGHT = [(0,267),(267,269),(269,270),(270,409),(409,291)]
66 |
67 | FACEMESH_LIPS_INNER_TOP_LEFT = [(78,191),(191,80),(80,81),(81,82),(82,13)]
68 | FACEMESH_LIPS_INNER_TOP_RIGHT = [(13,312),(312,311),(311,310),(310,415),(415,308)]
69 |
70 | FACEMESH_CUSTOM_FACE_OVAL = [(176, 149), (150, 136), (356, 454), (58, 132), (152, 148), (361, 288), (251, 389), (132, 93), (389, 356), (400, 377), (136, 172), (377, 152), (323, 361), (172, 58), (454, 323), (365, 379), (379, 378), (148, 176), (93, 234), (397, 365), (149, 150), (288, 397), (234, 127), (378, 400), (127, 162), (162, 21)]
71 |
72 | # mp_face_mesh.FACEMESH_CONTOURS has all the items we care about.
73 | face_connection_spec = {}
74 |
75 | #from IPython import embed
76 | #embed()
77 | if self.forehead_edge:
78 | for edge in mp_face_mesh.FACEMESH_FACE_OVAL:
79 | face_connection_spec[edge] = head_draw
80 | else:
81 | if draw_head:
82 | FACEMESH_CUSTOM_FACE_OVAL_sorted = sorted(FACEMESH_CUSTOM_FACE_OVAL)
83 | if upface_only:
84 | for edge in [FACEMESH_CUSTOM_FACE_OVAL_sorted[edge_idx] for edge_idx in [1,2,9,12,13,16,22,25]]:
85 | face_connection_spec[edge] = head_draw
86 | else:
87 | for edge in FACEMESH_CUSTOM_FACE_OVAL_sorted:
88 | face_connection_spec[edge] = head_draw
89 |
90 | if draw_eye:
91 | for edge in mp_face_mesh.FACEMESH_LEFT_EYE:
92 | face_connection_spec[edge] = left_eye_draw
93 | for edge in mp_face_mesh.FACEMESH_RIGHT_EYE:
94 | face_connection_spec[edge] = right_eye_draw
95 |
96 | if draw_eyebrow:
97 | for edge in mp_face_mesh.FACEMESH_LEFT_EYEBROW:
98 | face_connection_spec[edge] = left_eyebrow_draw
99 |
100 | for edge in mp_face_mesh.FACEMESH_RIGHT_EYEBROW:
101 | face_connection_spec[edge] = right_eyebrow_draw
102 |
103 | if draw_iris:
104 | for edge in mp_face_mesh.FACEMESH_LEFT_IRIS:
105 | face_connection_spec[edge] = left_iris_draw
106 | for edge in mp_face_mesh.FACEMESH_RIGHT_IRIS:
107 | face_connection_spec[edge] = right_iris_draw
108 |
109 | #for edge in mp_face_mesh.FACEMESH_RIGHT_EYEBROW:
110 | # face_connection_spec[edge] = right_eyebrow_draw
111 | # for edge in mp_face_mesh.FACEMESH_RIGHT_IRIS:
112 | # face_connection_spec[edge] = right_iris_draw
113 |
114 | # for edge in mp_face_mesh.FACEMESH_LIPS:
115 | # face_connection_spec[edge] = mouth_draw
116 |
117 | if draw_mouse:
118 | for edge in FACEMESH_LIPS_OUTER_BOTTOM_LEFT:
119 | face_connection_spec[edge] = mouth_draw_obl
120 | for edge in FACEMESH_LIPS_OUTER_BOTTOM_RIGHT:
121 | face_connection_spec[edge] = mouth_draw_obr
122 | for edge in FACEMESH_LIPS_INNER_BOTTOM_LEFT:
123 | face_connection_spec[edge] = mouth_draw_ibl
124 | for edge in FACEMESH_LIPS_INNER_BOTTOM_RIGHT:
125 | face_connection_spec[edge] = mouth_draw_ibr
126 | for edge in FACEMESH_LIPS_OUTER_TOP_LEFT:
127 | face_connection_spec[edge] = mouth_draw_otl
128 | for edge in FACEMESH_LIPS_OUTER_TOP_RIGHT:
129 | face_connection_spec[edge] = mouth_draw_otr
130 | for edge in FACEMESH_LIPS_INNER_TOP_LEFT:
131 | face_connection_spec[edge] = mouth_draw_itl
132 | for edge in FACEMESH_LIPS_INNER_TOP_RIGHT:
133 | face_connection_spec[edge] = mouth_draw_itr
134 |
135 | self.face_connection_spec = face_connection_spec
136 |
137 | self.pupil_landmark_spec = {468: right_iris_draw, 473: left_iris_draw}
138 | self.nose_landmark_spec = {4: nose_draw}
139 |
140 | self.draw_pupil = draw_pupil
141 | self.draw_nose = draw_nose
142 |
143 | def draw_points(self, image, landmark_list, drawing_spec, halfwidth: int = 2):
144 | """We have a custom function to draw the pupils because the mp.draw_landmarks method requires a parameter for all
145 | landmarks. Until our PR is merged into mediapipe, we need this separate method."""
146 | if len(image.shape) != 3:
147 | raise ValueError("Input image must be H,W,C.")
148 | image_rows, image_cols, image_channels = image.shape
149 | if image_channels != 3: # BGR channels
150 | raise ValueError('Input image must contain three channel bgr data.')
151 | for idx, landmark in enumerate(landmark_list.landmark):
152 | if idx not in drawing_spec:
153 | continue
154 |
155 | if (
156 | (landmark.HasField('visibility') and landmark.visibility < 0.9) or
157 | (landmark.HasField('presence') and landmark.presence < 0.5)
158 | ):
159 | continue
160 | if landmark.x >= 1.0 or landmark.x < 0 or landmark.y >= 1.0 or landmark.y < 0:
161 | continue
162 |
163 | image_x = int(image_cols * landmark.x)
164 | image_y = int(image_rows * landmark.y)
165 |
166 | draw_color = drawing_spec[idx].color
167 | image[image_y - halfwidth : image_y + halfwidth, image_x - halfwidth : image_x + halfwidth, :] = draw_color
168 |
169 |
170 | def draw_landmarks(self, image_size, keypoints, normed=False):
171 | ini_size = [512, 512]
172 | image = np.zeros([ini_size[1], ini_size[0], 3], dtype=np.uint8)
173 | if keypoints is not None:
174 | new_landmarks = landmark_pb2.NormalizedLandmarkList()
175 | for i in range(keypoints.shape[0]):
176 | landmark = new_landmarks.landmark.add()
177 | if normed:
178 | landmark.x = keypoints[i, 0]
179 | landmark.y = keypoints[i, 1]
180 | else:
181 | landmark.x = keypoints[i, 0] / image_size[0]
182 | landmark.y = keypoints[i, 1] / image_size[1]
183 | landmark.z = 1.0
184 |
185 | self.mp_drawing.draw_landmarks(
186 | image=image,
187 | landmark_list=new_landmarks,
188 | connections=self.face_connection_spec.keys(),
189 | landmark_drawing_spec=None,
190 | connection_drawing_spec=self.face_connection_spec
191 | )
192 |
193 | # if self.draw_pupil:
194 | # self.draw_points(image, new_landmarks, self.pupil_landmark_spec, 2)
195 |
196 | # if self.draw_nose:
197 | # self.draw_points(image, new_landmarks, self.nose_landmark_spec, 2)
198 |
199 | if self.draw_pupil:
200 | self.draw_points(image, new_landmarks, self.pupil_landmark_spec, 1)
201 |
202 | if self.draw_nose:
203 | self.draw_points(image, new_landmarks, self.nose_landmark_spec, 2)
204 |
205 | image = cv2.resize(image, (image_size[0], image_size[1]))
206 |
207 | return image
208 |
209 | def draw_landmarks_v2(self, image_size, keypoints, normed=False):
210 | # ini_size = [512, 512]
211 | ini_size = [image_size[0], image_size[1]]
212 | image = np.zeros([ini_size[1], ini_size[0], 3], dtype=np.uint8)
213 | if keypoints is not None:
214 | new_landmarks = landmark_pb2.NormalizedLandmarkList()
215 | for i in range(keypoints.shape[0]):
216 | landmark = new_landmarks.landmark.add()
217 | if normed:
218 | landmark.x = keypoints[i, 0]
219 | landmark.y = keypoints[i, 1]
220 | else:
221 | landmark.x = keypoints[i, 0] / image_size[0]
222 | landmark.y = keypoints[i, 1] / image_size[1]
223 | landmark.z = 1.0
224 |
225 | self.mp_drawing.draw_landmarks(
226 | image=image,
227 | landmark_list=new_landmarks,
228 | connections=self.face_connection_spec.keys(),
229 | landmark_drawing_spec=None,
230 | connection_drawing_spec=self.face_connection_spec
231 | )
232 |
233 | # if self.draw_pupil:
234 | # self.draw_points(image, new_landmarks, self.pupil_landmark_spec, 2)
235 |
236 | # if self.draw_nose:
237 | # self.draw_points(image, new_landmarks, self.nose_landmark_spec, 2)
238 |
239 | if self.draw_pupil:
240 | self.draw_points(image, new_landmarks, self.pupil_landmark_spec, 1)
241 |
242 | if self.draw_nose:
243 | self.draw_points(image, new_landmarks, self.nose_landmark_spec, 2)
244 |
245 | image = cv2.resize(image, (image_size[0], image_size[1]))
246 |
247 | return image
248 |
249 | def draw_landmarks_v3(self, image_size, resize_size, keypoints, normed=False):
250 | # ini_size = [512, 512]
251 | ini_size = [resize_size[0], resize_size[1]]
252 | image = np.zeros([ini_size[1], ini_size[0], 3], dtype=np.uint8)
253 | if keypoints is not None:
254 | new_landmarks = landmark_pb2.NormalizedLandmarkList()
255 | for i in range(keypoints.shape[0]):
256 | landmark = new_landmarks.landmark.add()
257 | if normed:
258 | # landmark.x = keypoints[i, 0] * resize_size[0] / image_size[0]
259 | # landmark.y = keypoints[i, 1] * resize_size[1] / image_size[1]
260 | landmark.x = keypoints[i, 0]
261 | landmark.y = keypoints[i, 1]
262 | else:
263 | landmark.x = keypoints[i, 0] / image_size[0]
264 | landmark.y = keypoints[i, 1] / image_size[1]
265 | landmark.z = 1.0
266 |
267 | self.mp_drawing.draw_landmarks(
268 | image=image,
269 | landmark_list=new_landmarks,
270 | connections=self.face_connection_spec.keys(),
271 | landmark_drawing_spec=None,
272 | connection_drawing_spec=self.face_connection_spec
273 | )
274 |
275 | # if self.draw_pupil:
276 | # self.draw_points(image, new_landmarks, self.pupil_landmark_spec, 2)
277 |
278 | # if self.draw_nose:
279 | # self.draw_points(image, new_landmarks, self.nose_landmark_spec, 2)
280 |
281 | if self.draw_pupil:
282 | self.draw_points(image, new_landmarks, self.pupil_landmark_spec, 1)
283 |
284 | if self.draw_nose:
285 | self.draw_points(image, new_landmarks, self.nose_landmark_spec, 2)
286 |
287 | image = cv2.resize(image, (resize_size[0], resize_size[1]))
288 |
289 | return image
290 |
--------------------------------------------------------------------------------
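A small sketch of how the visualizer above can be driven with pre-computed keypoints. The random 478-point array is a stand-in for real MediaPipe face-mesh output and only exercises the drawing path; the default constructor arguments are used.

    import numpy as np
    from skyreels_a1.src.media_pipe.draw_util_2d import FaceMeshVisualizer2d

    vis = FaceMeshVisualizer2d()

    # 478 face-mesh points with normalized (x, y) in [0, 1); random values only
    # exercise the connection/pupil drawing, they do not look like a real face.
    keypoints = np.random.rand(478, 2).astype(np.float32)

    canvas = vis.draw_landmarks(image_size=(512, 512), keypoints=keypoints, normed=True)
    print(canvas.shape)  # (512, 512, 3) uint8 BGR image with the colored connections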
/skyreels_a1/src/media_pipe/mp_models/blaze_face_short_range.tflite:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/skyreels_a1/src/media_pipe/mp_models/blaze_face_short_range.tflite
--------------------------------------------------------------------------------
/skyreels_a1/src/media_pipe/mp_models/face_landmarker_v2_with_blendshapes.task:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/skyreels_a1/src/media_pipe/mp_models/face_landmarker_v2_with_blendshapes.task
--------------------------------------------------------------------------------
/skyreels_a1/src/media_pipe/mp_models/pose_landmarker_heavy.task:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/skyreels_a1/src/media_pipe/mp_models/pose_landmarker_heavy.task
--------------------------------------------------------------------------------
/skyreels_a1/src/media_pipe/mp_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import cv2
4 | import time
5 | from tqdm import tqdm
6 | import multiprocessing
7 | import glob
8 |
9 | import mediapipe as mp
10 | from mediapipe import solutions
11 | from mediapipe.framework.formats import landmark_pb2
12 | from mediapipe.tasks import python
13 | from mediapipe.tasks.python import vision
14 | from . import face_landmark
15 |
16 | CUR_DIR = os.path.dirname(__file__)
17 |
18 |
19 | class LMKExtractor():
20 | def __init__(self, FPS=25):
21 | # Create an FaceLandmarker object.
22 | self.mode = mp.tasks.vision.FaceDetectorOptions.running_mode.IMAGE
23 | base_options = python.BaseOptions(model_asset_path=os.path.join(CUR_DIR, 'mp_models/face_landmarker_v2_with_blendshapes.task'))
24 | base_options.delegate = mp.tasks.BaseOptions.Delegate.CPU
25 | options = vision.FaceLandmarkerOptions(base_options=base_options,
26 | running_mode=self.mode,
27 | # min_face_detection_confidence=0.3,
28 | # min_face_presence_confidence=0.3,
29 | # min_tracking_confidence=0.3,
30 | output_face_blendshapes=True,
31 | output_facial_transformation_matrixes=True,
32 | num_faces=1)
33 | self.detector = face_landmark.FaceLandmarker.create_from_options(options)
34 | self.last_ts = 0
35 | self.frame_ms = int(1000 / FPS)
36 |
37 | det_base_options = python.BaseOptions(model_asset_path=os.path.join(CUR_DIR, 'mp_models/blaze_face_short_range.tflite'))
38 | det_options = vision.FaceDetectorOptions(base_options=det_base_options)
39 | self.det_detector = vision.FaceDetector.create_from_options(det_options)
40 |
41 |
42 | def __call__(self, img):
43 | frame = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
44 | image = mp.Image(image_format=mp.ImageFormat.SRGB, data=frame)
45 | t0 = time.time()
46 | if self.mode == mp.tasks.vision.FaceDetectorOptions.running_mode.VIDEO:
47 | det_result = self.det_detector.detect(image)
48 | if len(det_result.detections) != 1:
49 | return None
50 | self.last_ts += self.frame_ms
51 | try:
52 | detection_result, mesh3d = self.detector.detect_for_video(image, timestamp_ms=self.last_ts)
53 | except:
54 | return None
55 | elif self.mode == mp.tasks.vision.FaceDetectorOptions.running_mode.IMAGE:
56 | # det_result = self.det_detector.detect(image)
57 |
58 | # if len(det_result.detections) != 1:
59 | # return None
60 | try:
61 | detection_result, mesh3d = self.detector.detect(image)
62 | except:
63 | return None
64 |
65 |
66 | bs_list = detection_result.face_blendshapes
67 | if len(bs_list) == 1:
68 | bs = bs_list[0]
69 | bs_values = []
70 | for index in range(len(bs)):
71 | bs_values.append(bs[index].score)
72 | bs_values = bs_values[1:] # remove neutral
73 | trans_mat = detection_result.facial_transformation_matrixes[0]
74 | face_landmarks_list = detection_result.face_landmarks
75 | face_landmarks = face_landmarks_list[0]
76 | lmks = []
77 | for index in range(len(face_landmarks)):
78 | x = face_landmarks[index].x
79 | y = face_landmarks[index].y
80 | z = face_landmarks[index].z
81 | lmks.append([x, y, z])
82 | lmks = np.array(lmks)
83 |
84 | lmks3d = np.array(mesh3d.vertex_buffer)
85 | lmks3d = lmks3d.reshape(-1, 5)[:, :3]
86 | mp_tris = np.array(mesh3d.index_buffer).reshape(-1, 3) + 1
87 |
88 | return {
89 | "lmks": lmks,
90 | 'lmks3d': lmks3d,
91 | "trans_mat": trans_mat,
92 | 'faces': mp_tris,
93 | "bs": bs_values
94 | }
95 | else:
96 | # print('multiple faces in the image: {}'.format(img_path))
97 | return None
98 |
--------------------------------------------------------------------------------
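A usage sketch for LMKExtractor, assuming the bundled mp_models files sit next to the module as in the repository layout. Any BGR face image works; ref_images/1.png is just one of the repo's sample images.

    import cv2
    from skyreels_a1.src.media_pipe.mp_utils import LMKExtractor

    extractor = LMKExtractor()

    frame = cv2.imread("ref_images/1.png")   # any single-face BGR image
    result = extractor(frame)

    if result is not None:
        print(result["lmks"].shape)       # (478, 3) normalized landmarks
        print(result["lmks3d"].shape)     # (N, 3) canonical mesh vertices
        print(len(result["bs"]))          # 51 blendshape scores (neutral removed)
        print(result["trans_mat"].shape)  # (4, 4) facial transformation matrix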
/skyreels_a1/src/media_pipe/readme:
--------------------------------------------------------------------------------
1 | The landmark file defines the barycentric embedding of 105 points of the Mediapipe mesh on the surface of FLAME.
2 | It consists of three arrays: lmk_face_idx, lmk_b_coords, and landmark_indices.
3 | - lmk_face_idx contains, for every landmark, the index of the FLAME triangle into which the landmark is embedded
4 | - lmk_b_coords are the barycentric weights for each vertex of the triangles
5 | - landmark_indices are the indices of the vertices of the Mediapipe mesh
6 |
--------------------------------------------------------------------------------
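A small sketch of how such an embedding is typically used (the arrays below are hypothetical placeholders, not values loaded from the actual landmark file): each landmark is recovered as the barycentric combination of the three vertices of its FLAME triangle, mirroring vertices2landmarks in FLAME/lbs.py.

    import numpy as np

    # Hypothetical arrays with the layout described above.
    L = 105
    lmk_face_idx = np.zeros(L, dtype=np.int64)    # FLAME triangle index per landmark
    lmk_b_coords = np.full((L, 3), 1.0 / 3.0)     # barycentric weights per landmark

    def embedded_landmarks(flame_vertices, flame_faces):
        """Recover the embedded points as weighted sums of each triangle's vertices."""
        tri_verts = flame_vertices[flame_faces[lmk_face_idx]]      # (L, 3, 3)
        return (lmk_b_coords[:, :, None] * tri_verts).sum(axis=1)  # (L, 3)

    # Dummy FLAME-sized geometry (5023 vertices, 9976 triangles).
    verts = np.zeros((5023, 3), dtype=np.float32)
    faces = np.zeros((9976, 3), dtype=np.int64)
    print(embedded_landmarks(verts, faces).shape)  # (105, 3)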
/skyreels_a1/src/multi_fps.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 | import subprocess
6 | from .frame_interpolation import batch_images_interpolation_tool
7 |
8 | def multi_fps_tool(frames, frame_inter_model, target_fps, original_fps=12):
9 | frames_np = np.array([np.array(frame) for frame in frames])
10 |
11 | interpolation_factor = target_fps / original_fps
12 | inter_frames = math.ceil(interpolation_factor) - 1
13 | frames_tensor = torch.from_numpy(frames_np).permute(3, 0, 1, 2).unsqueeze(0) / 255.0
14 |
15 | video = batch_images_interpolation_tool(frames_tensor, frame_inter_model, inter_frames=inter_frames)
16 | video = video.squeeze(0)
17 | video = video.permute(1, 2, 3, 0)
18 | video = (video * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
19 | out_frames = [Image.fromarray(frame) for frame in video]
20 |
21 | if not interpolation_factor.is_integer():
22 | print(f"Warning: target fps {target_fps} is not mulitple of 12, which may cause unstable video rate.")
23 | out_frames = adjust_video_fps(out_frames, target_fps, int(target_fps//12+1)*12)
24 |
25 | return out_frames
26 |
27 | def adjust_video_fps(frames, target_fps, fps):
28 | video_length = len(frames)
29 |
30 | duration = video_length / fps
31 | target_times = np.arange(0, duration, 1/target_fps)
32 | frame_indices = (target_times * fps).astype(np.int32)
33 |
34 | frame_indices = frame_indices[frame_indices < video_length]
35 | new_frames = []
36 | for idx in frame_indices:
37 | new_frames.append(frames[idx])
38 |
39 | return new_frames
--------------------------------------------------------------------------------
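A hedged usage sketch for multi_fps_tool. The checkpoint path is a placeholder and a CUDA device is assumed, since the underlying interpolation model runs in half precision on the GPU.

    import numpy as np
    from PIL import Image
    from skyreels_a1.src.frame_interpolation import init_frame_interpolation_model
    from skyreels_a1.src.multi_fps import multi_fps_tool

    # Placeholder checkpoint path for the TorchScript interpolation model.
    frame_inter_model = init_frame_interpolation_model("pretrained_models/film_net_fp16.pt")

    # Dummy 12 fps clip of 24 black frames; any list of equally sized PIL images works.
    frames = [Image.fromarray(np.zeros((256, 256, 3), dtype=np.uint8)) for _ in range(24)]

    # Upsample to 24 fps: one frame is synthesized between each pair of neighbours.
    out_frames = multi_fps_tool(frames, frame_inter_model, target_fps=24, original_fps=12)
    print(len(out_frames))  # 47: 24 original + 23 synthesized frames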
/skyreels_a1/src/smirk_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 | import timm
5 |
6 |
7 | def create_backbone(backbone_name, pretrained=True):
8 | backbone = timm.create_model(backbone_name,
9 | pretrained=pretrained,
10 | features_only=True)
11 | feature_dim = backbone.feature_info[-1]['num_chs']
12 | return backbone, feature_dim
13 |
14 | class PoseEncoder(nn.Module):
15 | def __init__(self) -> None:
16 | super().__init__()
17 |
18 | self.encoder, feature_dim = create_backbone('tf_mobilenetv3_small_minimal_100')
19 |
20 | self.pose_cam_layers = nn.Sequential(
21 | nn.Linear(feature_dim, 6)
22 | )
23 |
24 | self.init_weights()
25 |
26 | def init_weights(self):
27 | self.pose_cam_layers[-1].weight.data *= 0.001
28 | self.pose_cam_layers[-1].bias.data *= 0.001
29 |
30 | self.pose_cam_layers[-1].weight.data[3] = 0
31 | self.pose_cam_layers[-1].bias.data[3] = 7
32 |
33 |
34 | def forward(self, img):
35 | features = self.encoder(img)[-1]
36 |
37 | features = F.adaptive_avg_pool2d(features, (1, 1)).squeeze(-1).squeeze(-1)
38 |
39 | outputs = {}
40 |
41 | pose_cam = self.pose_cam_layers(features).reshape(img.size(0), -1)
42 | outputs['pose_params'] = pose_cam[...,:3]
43 | # import pdb;pdb.set_trace()
44 | outputs['cam'] = pose_cam[...,3:]
45 |
46 | return outputs
47 |
48 |
49 | class ShapeEncoder(nn.Module):
50 | def __init__(self, n_shape=300) -> None:
51 | super().__init__()
52 |
53 | self.encoder, feature_dim = create_backbone('tf_mobilenetv3_large_minimal_100')
54 |
55 | self.shape_layers = nn.Sequential(
56 | nn.Linear(feature_dim, n_shape)
57 | )
58 |
59 | self.init_weights()
60 |
61 |
62 | def init_weights(self):
63 | self.shape_layers[-1].weight.data *= 0
64 | self.shape_layers[-1].bias.data *= 0
65 |
66 |
67 | def forward(self, img):
68 | features = self.encoder(img)[-1]
69 |
70 | features = F.adaptive_avg_pool2d(features, (1, 1)).squeeze(-1).squeeze(-1)
71 |
72 | parameters = self.shape_layers(features).reshape(img.size(0), -1)
73 |
74 | return {'shape_params': parameters}
75 |
76 |
77 | class ExpressionEncoder(nn.Module):
78 | def __init__(self, n_exp=50) -> None:
79 | super().__init__()
80 |
81 | self.encoder, feature_dim = create_backbone('tf_mobilenetv3_large_minimal_100')
82 |
83 | self.expression_layers = nn.Sequential(
84 | nn.Linear(feature_dim, n_exp+2+3) # num expressions + jaw + eyelid
85 | )
86 |
87 | self.n_exp = n_exp
88 | self.init_weights()
89 |
90 |
91 | def init_weights(self):
92 | self.expression_layers[-1].weight.data *= 0.1
93 | self.expression_layers[-1].bias.data *= 0.1
94 |
95 |
96 | def forward(self, img):
97 | features = self.encoder(img)[-1]
98 |
99 | features = F.adaptive_avg_pool2d(features, (1, 1)).squeeze(-1).squeeze(-1)
100 |
101 |
102 | parameters = self.expression_layers(features).reshape(img.size(0), -1)
103 |
104 | outputs = {}
105 |
106 | outputs['expression_params'] = parameters[...,:self.n_exp]
107 | outputs['eyelid_params'] = torch.clamp(parameters[...,self.n_exp:self.n_exp+2], 0, 1)
108 | outputs['jaw_params'] = torch.cat([F.relu(parameters[...,self.n_exp+2].unsqueeze(-1)),
109 | torch.clamp(parameters[...,self.n_exp+3:self.n_exp+5], -.2, .2)], dim=-1)
110 |
111 | return outputs
112 |
113 |
114 | class SmirkEncoder(nn.Module):
115 | def __init__(self, n_exp=50, n_shape=300) -> None:
116 | super().__init__()
117 |
118 | self.pose_encoder = PoseEncoder()
119 |
120 | self.shape_encoder = ShapeEncoder(n_shape=n_shape)
121 |
122 | self.expression_encoder = ExpressionEncoder(n_exp=n_exp)
123 |
124 | def forward(self, img):
125 | pose_outputs = self.pose_encoder(img)
126 | shape_outputs = self.shape_encoder(img)
127 | expression_outputs = self.expression_encoder(img)
128 |
129 | outputs = {}
130 | outputs.update(pose_outputs)
131 | outputs.update(shape_outputs)
132 | outputs.update(expression_outputs)
133 |
134 | return outputs
135 |
--------------------------------------------------------------------------------
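A usage sketch for SmirkEncoder. The 224x224 crop size is an assumption (SMIRK-style face crops), and constructing the encoders downloads the timm MobileNetV3 backbones on first use.

    import torch
    from skyreels_a1.src.smirk_encoder import SmirkEncoder

    encoder = SmirkEncoder(n_exp=50, n_shape=300)
    encoder.eval()

    # Dummy batch of face crops; 224x224 RGB is an assumed input size.
    img = torch.rand(2, 3, 224, 224)

    with torch.no_grad():
        out = encoder(img)

    for k, v in out.items():
        print(k, tuple(v.shape))
    # pose_params (2, 3), cam (2, 3), shape_params (2, 300),
    # expression_params (2, 50), eyelid_params (2, 2), jaw_params (2, 3)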
/skyreels_a1/src/utils/mediapipe_utils.py:
--------------------------------------------------------------------------------
1 | import mediapipe as mp
2 | from mediapipe.tasks import python
3 | from mediapipe.tasks.python import vision
4 | import cv2
5 | import numpy as np
6 | import os
7 |
8 | import torch
9 | import torch.nn.functional as F
10 |
11 |
12 |
13 |
14 |
15 | # borrowed from https://github.com/daniilidis-group/neural_renderer/blob/master/neural_renderer/vertices_to_faces.py
16 | def face_vertices(vertices, faces):
17 | """
18 | :param vertices: [batch size, number of vertices, 3]
19 | :param faces: [batch size, number of faces, 3]
20 | :return: [batch size, number of faces, 3, 3]
21 | """
22 | assert (vertices.ndimension() == 3)
23 | assert (faces.ndimension() == 3)
24 | assert (vertices.shape[0] == faces.shape[0])
25 | assert (vertices.shape[2] == 3)
26 | assert (faces.shape[2] == 3)
27 |
28 | bs, nv = vertices.shape[:2]
29 | bs, nf = faces.shape[:2]
30 | device = vertices.device
31 | faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None]
32 | vertices = vertices.reshape((bs * nv, 3))
33 | # pytorch only supports long and byte tensors for indexing
34 | return vertices[faces.long()]
35 |
36 | def vertex_normals(vertices, faces):
37 | """
38 | :param vertices: [batch size, number of vertices, 3]
39 | :param faces: [batch size, number of faces, 3]
40 | :return: [batch size, number of vertices, 3]
41 | """
42 | assert (vertices.ndimension() == 3)
43 | assert (faces.ndimension() == 3)
44 | assert (vertices.shape[0] == faces.shape[0])
45 | assert (vertices.shape[2] == 3)
46 | assert (faces.shape[2] == 3)
47 | bs, nv = vertices.shape[:2]
48 | bs, nf = faces.shape[:2]
49 | device = vertices.device
50 | normals = torch.zeros(bs * nv, 3).to(device)
51 |
52 | faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None] # expanded faces
53 | vertices_faces = vertices.reshape((bs * nv, 3))[faces.long()]
54 |
55 | faces = faces.reshape(-1, 3)
56 | vertices_faces = vertices_faces.reshape(-1, 3, 3)
57 |
58 | normals.index_add_(0, faces[:, 1].long(),
59 | torch.cross(vertices_faces[:, 2] - vertices_faces[:, 1], vertices_faces[:, 0] - vertices_faces[:, 1]))
60 | normals.index_add_(0, faces[:, 2].long(),
61 | torch.cross(vertices_faces[:, 0] - vertices_faces[:, 2], vertices_faces[:, 1] - vertices_faces[:, 2]))
62 | normals.index_add_(0, faces[:, 0].long(),
63 | torch.cross(vertices_faces[:, 1] - vertices_faces[:, 0], vertices_faces[:, 2] - vertices_faces[:, 0]))
64 |
65 | normals = F.normalize(normals, eps=1e-6, dim=1)
66 | normals = normals.reshape((bs, nv, 3))
67 | # pytorch only supports long and byte tensors for indexing
68 | return normals
69 |
70 | def batch_orth_proj(X, camera):
71 | ''' orthographic projection
72 | X: 3d vertices, [bz, n_point, 3]
73 | camera: scale and translation, [bz, 3], [scale, tx, ty]
74 | '''
75 | #print('--------')
76 | #print(camera[0, 1:].abs())
77 | #print(X[0].abs().mean(0))
78 |
79 | camera = camera.clone().view(-1, 1, 3)
80 | X_trans = X[:, :, :2] + camera[:, :, 1:]
81 | #print(X_trans[0].abs().mean(0))
82 | X_trans = torch.cat([X_trans, X[:,:,2:]], 2)
83 | Xn = (camera[:, :, 0:1] * X_trans)
84 | return Xn
85 |
86 | class MP_2_FLAME():
87 | """
88 | Convert Mediapipe 52 blendshape scores to FLAME's coefficients
89 | """
90 | def __init__(self, mappings_path):
91 | self.bs2exp = np.load(os.path.join(mappings_path, 'bs2exp.npy'))
92 | self.bs2pose = np.load(os.path.join(mappings_path, 'bs2pose.npy'))
93 | self.bs2eye = np.load(os.path.join(mappings_path, 'bs2eye.npy'))
94 |
95 | def convert(self, blendshape_scores : np.array):
96 | # blendshape_scores: [N, 52]
97 |
98 | # Calculate expression, pose, and eye_pose using the mappings
99 | exp = blendshape_scores @ self.bs2exp
100 | pose = blendshape_scores @ self.bs2pose
101 | pose[0, :3] = 0 # we do not support head rotation yet
102 | eye_pose = blendshape_scores @ self.bs2eye
103 |
104 | return exp, pose, eye_pose
105 |
106 | class MediaPipeUtils:
107 | def __init__(self, model_asset_path='pretrained_models/mediapipe/face_landmarker.task', mappings_path='pretrained_models/mediapipe/'):
108 | base_options = python.BaseOptions(model_asset_path=model_asset_path)
109 | options = vision.FaceLandmarkerOptions(base_options=base_options,
110 | output_face_blendshapes=True,
111 | output_facial_transformation_matrixes=True,
112 | num_faces=1,
113 | min_face_detection_confidence=0.1,
114 | min_face_presence_confidence=0.1)
115 | self.detector = vision.FaceLandmarker.create_from_options(options)
116 | self.mp2flame = MP_2_FLAME(mappings_path=mappings_path)
117 |
118 | def run_mediapipe(self, image):
119 | image_numpy = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
120 | image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image_numpy)
121 | detection_result = self.detector.detect(image)
122 |
123 | if len(detection_result.face_landmarks) == 0:
124 | print('No face detected')
125 | return None
126 |
127 | blend_scores = detection_result.face_blendshapes[0]
128 | blend_scores = np.array(list(map(lambda l: l.score, blend_scores)), dtype=np.float32).reshape(1, 52)
129 | exp, pose, eye_pose = self.mp2flame.convert(blendshape_scores=blend_scores)
130 |
131 | face_landmarks = detection_result.face_landmarks[0]
132 | face_landmarks_numpy = np.zeros((478, 3))
133 |
134 | for i, landmark in enumerate(face_landmarks):
135 | face_landmarks_numpy[i] = [landmark.x * image.width, landmark.y * image.height, landmark.z]
136 |
137 | return face_landmarks_numpy, exp, pose, eye_pose
138 |
--------------------------------------------------------------------------------
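A usage sketch for MediaPipeUtils, assuming the MediaPipe face landmarker task file and the bs2exp/bs2pose/bs2eye mappings have been downloaded into pretrained_models/mediapipe/ (the default paths in the constructor). ref_images/1.png is one of the repo's sample images; any single-face BGR image works.

    import cv2
    from skyreels_a1.src.utils.mediapipe_utils import MediaPipeUtils

    mp_utils = MediaPipeUtils()
    image = cv2.imread("ref_images/1.png")

    result = mp_utils.run_mediapipe(image)
    if result is not None:
        landmarks, exp, pose, eye_pose = result
        print(landmarks.shape)                        # (478, 3) pixel-space landmarks
        print(exp.shape, pose.shape, eye_pose.shape)  # FLAME expression / pose / eye pose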