├── .gitignore ├── LICENSE.txt ├── README.md ├── app.py ├── assets ├── demo.gif ├── driving_audio │ ├── 1.wav │ ├── 2.wav │ ├── 3.wav │ ├── 4.wav │ ├── 5.wav │ └── 6.wav ├── driving_video │ ├── .DS_Store │ ├── 1.mp4 │ ├── 2.mp4 │ ├── 3.mp4 │ ├── 4.mp4 │ ├── 5.mp4 │ ├── 6.mp4 │ ├── 7.mp4 │ └── 8.mp4 ├── gradio.png ├── logo.png └── ref_images │ ├── 1.png │ ├── 10.png │ ├── 11.png │ ├── 12.png │ ├── 13.png │ ├── 14.png │ ├── 15.png │ ├── 16.png │ ├── 17.png │ ├── 18.png │ ├── 19.png │ ├── 2.png │ ├── 20.png │ ├── 3.png │ ├── 4.png │ ├── 5.png │ ├── 6.png │ ├── 7.png │ └── 8.png ├── diffposetalk ├── common.py ├── diff_talking_head.py ├── diffposetalk.py ├── hubert.py ├── utils │ ├── __init__.py │ ├── common.py │ ├── media.py │ ├── renderer.py │ └── rotation_conversions.py └── wav2vec2.py ├── eval ├── arc_score.py ├── curricularface │ ├── __init__.py │ ├── common.py │ ├── model_irse.py │ └── model_resnet.py ├── expression_score.py └── pose_score.py ├── inference.py ├── inference_audio.py ├── inference_audio_long_video.py ├── inference_long_video.py ├── requirements.txt ├── scripts ├── __init__.py └── demo.py └── skyreels_a1 ├── __init__.py ├── ddim_solver.py ├── models ├── __init__.py └── transformer3d.py ├── pipeline_output.py ├── pre_process_lmk3d.py ├── skyreels_a1_i2v_long_pipeline.py ├── skyreels_a1_i2v_pipeline.py └── src ├── FLAME ├── FLAME.py └── lbs.py ├── __init__.py ├── frame_interpolation.py ├── lmk3d_test.py ├── media_pipe ├── draw_util.py ├── draw_util_2d.py ├── face_landmark.py ├── mp_models │ ├── blaze_face_short_range.tflite │ ├── face_landmarker_v2_with_blendshapes.task │ └── pose_landmarker_heavy.task ├── mp_utils.py └── readme ├── multi_fps.py ├── renderer.py ├── smirk_encoder.py └── utils └── mediapipe_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | pretrained_models/ 2 | __pycache__/ 3 | .gradio/ 4 | outputs/ 5 | outputs_audio/ 6 | 
-------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | --- 2 | language: 3 | - en 4 | - zh 5 | license: other 6 | tasks: 7 | - text-generation 8 | 9 | --- 10 | 11 | 12 | 13 | 14 | # 声明与协议/Terms and Conditions 15 | 16 | ## 声明 17 | 18 | 我们在此声明,不要利用Skywork模型进行任何危害国家社会安全或违法的活动。另外,我们也要求使用者不要将 Skywork 模型用于未经适当安全审查和备案的互联网服务。我们希望所有的使用者都能遵守这个原则,确保科技的发展能在规范和合法的环境下进行。 19 | 20 | 我们已经尽我们所能,来确保模型训练过程中使用的数据的合规性。然而,尽管我们已经做出了巨大的努力,但由于模型和数据的复杂性,仍有可能存在一些无法预见的问题。因此,如果由于使用skywork开源模型而导致的任何问题,包括但不限于数据安全问题、公共舆论风险,或模型被误导、滥用、传播或不当利用所带来的任何风险和问题,我们将不承担任何责任。 21 | 22 | We hereby declare that the Skywork model should not be used for any activities that pose a threat to national or societal security or engage in unlawful actions. Additionally, we request users not to deploy the Skywork model for internet services without appropriate security reviews and records. We hope that all users will adhere to this principle to ensure that technological advancements occur in a regulated and lawful environment. 23 | 24 | We have done our utmost to ensure the compliance of the data used during the model's training process. However, despite our extensive efforts, due to the complexity of the model and data, there may still be unpredictable risks and issues. Therefore, if any problems arise as a result of using the Skywork open-source model, including but not limited to data security issues, public opinion risks, or any risks and problems arising from the model being misled, abused, disseminated, or improperly utilized, we will not assume any responsibility. 
25 | 26 | ## 协议 27 | 28 | 社区使用Skywork模型需要遵循[《Skywork 模型社区许可协议》](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20模型社区许可协议.pdf)。Skywork模型支持商业用途,如果您计划将Skywork模型或其衍生品用于商业目的,无需再次申请, 但请您仔细阅读[《Skywork 模型社区许可协议》](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20模型社区许可协议.pdf)并严格遵守相关条款。 29 | 30 | 31 | The community usage of Skywork model requires [Skywork Community License](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20Community%20License.pdf). The Skywork model supports commercial use. If you plan to use the Skywork model or its derivatives for commercial purposes, you must abide by terms and conditions within [Skywork Community License](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20Community%20License.pdf). 32 | 33 | 34 | 35 | [《Skywork 模型社区许可协议》》]:https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20模型社区许可协议.pdf 36 | 37 | 38 | [skywork-opensource@kunlun-inc.com]: mailto:skywork-opensource@kunlun-inc.com 39 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | Skyreels Logo 3 |

4 | 5 | 6 |

SkyReels-A1: Expressive Portrait Animation in Video Diffusion Transformers

7 | 8 |
9 | Di Qiu  10 | Zhengcong Fei  11 | Rui Wang  12 | Jialin Bai  13 | Changqian Yu  14 |
15 | 16 |
17 | Mingyuan Fan  18 | Guibin Chen  19 | Xiang Wen  20 |
21 | 22 |
23 | Skywork AI, Kunlun Inc. 24 |
25 | 26 |
27 | 28 |
29 | 30 | 31 | 32 | 33 | 34 |
35 |
36 |
37 | 38 | 39 |

40 | showcase 41 |
42 | 🔥 For more results, visit our homepage 🔥 43 |

44 | 45 |

46 | 👋 Join our Discord 47 |

48 | 49 | 50 | This repo, named **SkyReels-A1**, contains the official PyTorch implementation of our paper [SkyReels-A1: Expressive Portrait Animation in Video Diffusion Transformers](https://arxiv.org/abs/2502.10841). 51 | 52 | 53 | ## 🔥🔥🔥 News!! 54 | * Apr 3, 2025: 🔥 We release [SkyReels-A2](https://github.com/SkyworkAI/SkyReels-A2), an open-source controllable video generation framework capable of assembling arbitrary visual elements. 55 | * Mar 4, 2025: 🔥 We release the audio-driven portrait image animation pipeline. Try it out on the [Huggingface Spaces Demo](https://huggingface.co/spaces/Skywork/skyreels-a1-talking-head)! 56 | * Feb 18, 2025: 👋 We release the inference code and model weights of SkyReels-A1. [Download](https://huggingface.co/Skywork/SkyReels-A1) 57 | * Feb 18, 2025: 🎉 We have made our technical report available as open source. [Read](https://skyworkai.github.io/skyreels-a1.github.io/report.pdf) 58 | * Feb 18, 2025: 🔥 Our online demo of LipSync is now available on SkyReels! Try it out on [LipSync](https://www.skyreels.ai/home/tools/lip-sync?refer=navbar). 59 | * Feb 18, 2025: 🔥 We have open-sourced the I2V video generation model [SkyReels-V1](https://github.com/SkyworkAI/SkyReels-V1). This is the first and most advanced open-source human-centric video foundation model. 60 | 61 | ## 📑 TODO List 62 | - [x] Checkpoints 63 | - [x] Inference Code 64 | - [x] Web Demo (Gradio) 65 | - [x] Audio-driven Portrait Image Animation Pipeline 66 | - [x] Inference Code for Long Videos 67 | - [ ] User-Level GPU Inference on RTX4090 68 | - [ ] ComfyUI 69 | 70 | 71 | ## Getting Started 🏁 72 | 73 | ### 1.
Clone the code and prepare the environment 🛠️ 74 | First, clone the repository: 75 | ```bash 76 | git clone https://github.com/SkyworkAI/SkyReels-A1.git 77 | cd SkyReels-A1 78 | 79 | # create env using conda 80 | conda create -n skyreels-a1 python=3.10 81 | conda activate skyreels-a1 82 | ``` 83 | Then, install the remaining dependencies: 84 | ```bash 85 | pip install -r requirements.txt 86 | ``` 87 | 88 | 89 | ### 2. Download pretrained weights 📥 90 | You can download the pretrained weights from HuggingFace: 91 | ```bash 92 | # !pip install -U "huggingface_hub[cli]" 93 | huggingface-cli download Skywork/SkyReels-A1 --local-dir local_path --exclude "*.git*" "README.md" "docs" 94 | ``` 95 | 96 | The FLAME, mediapipe, and smirk models are located in the SkyReels-A1/extra_models folder. 97 | 98 | The expected layout of the `pretrained_models` directory is: 99 | ```text 100 | pretrained_models 101 | ├── FLAME 102 | ├── SkyReels-A1-5B 103 | │ ├── pose_guider 104 | │ ├── scheduler 105 | │ ├── tokenizer 106 | │ ├── siglip-so400m-patch14-384 107 | │ ├── transformer 108 | │ ├── vae 109 | │ └── text_encoder 110 | ├── mediapipe 111 | └── smirk 112 | 113 | ``` 114 | 115 | #### Download DiffPoseTalk assets and pretrained weights (for audio-driven animation) 116 | 117 | - We use [diffposetalk](https://github.com/DiffPoseTalk/DiffPoseTalk/tree/main) to generate FLAME coefficients from audio, which serve as the motion signals. 118 | 119 | - Download the diffposetalk code and follow its README to download the weights and related data. 120 | 121 | - Then place them in the specified directory.
122 | 123 | ```bash 124 | cp -r ${diffposetalk_root}/style pretrained_models/diffposetalk 125 | cp ${diffposetalk_root}/experiments/DPT/head-SA-hubert-WM/checkpoints/iter_0110000.pt pretrained_models/diffposetalk 126 | cp ${diffposetalk_root}/datasets/HDTF_TFHP/lmdb/stats_train.npz pretrained_models/diffposetalk 127 | ``` 128 | 129 | - Or you can download style files from [link](https://drive.google.com/file/d/1XT426b-jt7RUkRTYsjGvG-wS4Jed2U1T/view?usp=sharing) and stats_train.npz from [link](https://drive.google.com/file/d/1_I5XRzkMP7xULCSGVuaN8q1Upplth9xR/view?usp=sharing). 130 | 131 | ```text 132 | pretrained_models 133 | ├── FLAME 134 | ├── SkyReels-A1-5B 135 | ├── mediapipe 136 | ├── diffposetalk 137 | │ ├── style 138 | │ ├── iter_0110000.pt 139 | │ ├── stats_train.npz 140 | └── smirk 141 | 142 | ``` 143 | 144 | #### Download Frame interpolation Model pretrained weights (For Long Video Inference and Dynamic Resolution) 145 | 146 | - We use [FILM](https://github.com/dajes/frame-interpolation-pytorch) to generate transition frames, making the video transitions smoother (Set `use_interpolation` to True). 147 | 148 | - Download [film_net_fp16.pt](https://github.com/dajes/frame-interpolation-pytorch/releases), and place it in the specified directory. 149 | 150 | ```text 151 | pretrained_models 152 | ├── FLAME 153 | ├── SkyReels-A1-5B 154 | ├── mediapipe 155 | ├── diffposetalk 156 | ├── film_net 157 | │ ├── film_net_fp16.pt 158 | └── smirk 159 | ``` 160 | 161 | 162 | ### 3. Inference 🚀 163 | You can simply run the inference scripts as: 164 | ```bash 165 | python inference.py 166 | 167 | # inference audio to video 168 | python inference_audio.py 169 | ``` 170 | 171 | If the script runs successfully, you will get an output mp4 file. This file includes the following results: driving video, input image or video, and generated result. 
172 | 173 | #### Long Video Inference 174 | 175 | Now, you can run the long video inference scripts to obtain portrait animation of any length: 176 | ```bash 177 | python inference_long_video.py 178 | 179 | # inference audio to video 180 | python inference_audio_long_video.py 181 | ``` 182 | 183 | #### Dynamic Resolution 184 | 185 | All inference scripts now support dynamic resolution: simply set `target_fps` to the desired frame rate. Recommended values are 12fps (native), 24fps, 48fps, and 60fps; other settings such as 25fps or 30fps may cause unstable frame rates. 186 | 187 | 188 | ## Gradio Interface 🤗 189 | 190 | We provide a [Gradio](https://huggingface.co/docs/hub/spaces-sdks-gradio) interface for a better experience; just run: 191 | 192 | ```bash 193 | python app.py 194 | ``` 195 | 196 | The graphical interface is shown below: 197 | 198 | ![gradio](assets/gradio.png) 199 | 200 | 201 | ## Metric Evaluation 👓 202 | 203 | We also provide scripts for automatically computing the metrics reported in the paper, including SimFace, FID, and the L1 distances for expression and motion. 204 | 205 | All of the evaluation code is in the `eval` folder. After setting the video result path, run the following commands in sequence: 206 | 207 | ```bash 208 | python arc_score.py 209 | python expression_score.py 210 | python pose_score.py 211 | ``` 212 | 213 | 214 | ## Acknowledgements 💐 215 | We would like to thank the contributors of the [CogVideoX](https://github.com/THUDM/CogVideo), [finetrainers](https://github.com/a-r-r-o-w/finetrainers) and [DiffPoseTalk](https://github.com/DiffPoseTalk/DiffPoseTalk) repositories for their open research and contributions.
216 | 217 | ## Citation 💖 218 | If you find SkyReels-A1 useful for your research, welcome to 🌟 this repo and cite our work using the following BibTeX: 219 | ```bibtex 220 | @article{qiu2025skyreels, 221 | title={Skyreels-a1: Expressive portrait animation in video diffusion transformers}, 222 | author={Qiu, Di and Fei, Zhengcong and Wang, Rui and Bai, Jialin and Yu, Changqian and Fan, Mingyuan and Chen, Guibin and Wen, Xiang}, 223 | journal={arXiv preprint arXiv:2502.10841}, 224 | year={2025} 225 | } 226 | ``` 227 | 228 | ## Star History 229 | 230 | [![Star History Chart](https://api.star-history.com/svg?repos=SkyworkAI/SkyReels-A1&type=Date)](https://www.star-history.com/#SkyworkAI/SkyReels-A1&Date) 231 | 232 | 233 | 234 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | from scripts.demo import init_model, generate_video, crop_and_resize 3 | import os 4 | import os.path as osp 5 | import stat 6 | from datetime import datetime 7 | import torch 8 | import numpy as np 9 | from diffusers.utils import export_to_video, load_image 10 | 11 | os.environ['GRADIO_TEMP_DIR'] = 'tmp' 12 | 13 | example_portrait_dir = "assets/ref_images" 14 | example_video_dir = "assets/driving_video" 15 | 16 | 17 | pipe, face_helper, processor, lmk_extractor, vis = init_model() 18 | # Gradio interface using Interface 19 | with gr.Blocks() as demo: 20 | gr.Markdown(""" 21 |
22 | SkyReels-A1: Expressive Portrait Animation in Video Diffusion Transformers 23 |
24 |
25 | 🤗 SkyReels-A1-5B Model Hub | 26 | 🌐 Github | 27 | 📜 arxiv 28 |
29 | """) 30 | 31 | with gr.Row(): # 创建一个水平排列的行 32 | with gr.Accordion(open=True, label="Portrait Image"): 33 | image_input = gr.Image(type="filepath") 34 | gr.Examples( 35 | examples=[ 36 | [osp.join(example_portrait_dir, "1.png")], 37 | [osp.join(example_portrait_dir, "2.png")], 38 | [osp.join(example_portrait_dir, "3.png")], 39 | [osp.join(example_portrait_dir, "4.png")], 40 | [osp.join(example_portrait_dir, "5.png")], 41 | [osp.join(example_portrait_dir, "6.png")], 42 | [osp.join(example_portrait_dir, "7.png")], 43 | [osp.join(example_portrait_dir, "8.png")], 44 | ], 45 | inputs=[image_input], 46 | cache_examples=False, 47 | ) 48 | with gr.Accordion(open=True, label="Driving Video"): 49 | control_video_input = gr.Video() 50 | gr.Examples( 51 | examples=[ 52 | [osp.join(example_video_dir, "1.mp4")], 53 | [osp.join(example_video_dir, "2.mp4")], 54 | [osp.join(example_video_dir, "3.mp4")], 55 | [osp.join(example_video_dir, "4.mp4")], 56 | [osp.join(example_video_dir, "5.mp4")], 57 | [osp.join(example_video_dir, "6.mp4")], 58 | [osp.join(example_video_dir, "7.mp4")], 59 | [osp.join(example_video_dir, "8.mp4")], 60 | ], 61 | inputs=[control_video_input], 62 | cache_examples=False, 63 | ) 64 | 65 | def face_check(image_path): 66 | image = load_image(image=image_path) 67 | image = crop_and_resize(image, 480, 720) 68 | 69 | with torch.no_grad(): 70 | face_helper.clean_all() 71 | face_helper.read_image(np.array(image)[:, :, ::-1]) 72 | face_helper.get_face_landmarks_5(only_center_face=True) 73 | face_helper.align_warp_face() 74 | if len(face_helper.cropped_faces) == 0: 75 | return False 76 | face = face_helper.det_faces 77 | face_w = int(face[2] - face[0]) 78 | if face_w < 50: 79 | return False 80 | return True 81 | 82 | 83 | def gradio_generate_video(control_video_path, image_path, progress=gr.Progress(track_tqdm=True)): 84 | try: 85 | save_dir = "./outputs/" 86 | if not os.path.exists(save_dir): 87 | os.makedirs(save_dir, exist_ok=True) 88 | current_time = 
datetime.now().strftime("%Y%m%d_%H%M%S") 89 | save_path = os.path.join(save_dir, f"generated_video_{current_time}.mp4") 90 | print(control_video_path, image_path) 91 | 92 | face = face_check(image_path) 93 | if face == False: 94 | return "Face too small or no face.", None, None 95 | 96 | generate_video( 97 | pipe, 98 | face_helper, 99 | processor, 100 | lmk_extractor, 101 | vis, 102 | control_video_path=control_video_path, 103 | image_path=image_path, 104 | save_path=save_path, 105 | guidance_scale=3, 106 | seed=43, 107 | num_inference_steps=20, 108 | sample_size=[480, 720], 109 | max_frame_num=49, 110 | ) 111 | 112 | print("finished.") 113 | print(save_path) 114 | if not os.path.exists(save_path): 115 | print("Error: Video file not found") 116 | return "Error: Video file not found", None 117 | 118 | video_update = gr.update(visible=True, value=save_path) 119 | return "Video generated successfully.", save_path, video_update 120 | except Exception as e: 121 | return f"Error occurred: {str(e)}", None, None 122 | 123 | 124 | generate_button = gr.Button("Generate Video") 125 | output_text = gr.Textbox(label="Output") 126 | output_video = gr.Video(label="Output Video") 127 | with gr.Row(): 128 | download_video_button = gr.File(label="📥 Download Video", visible=False) 129 | 130 | generate_button.click( 131 | gradio_generate_video, 132 | inputs=[ 133 | control_video_input, 134 | image_input 135 | ], 136 | outputs=[output_text, output_video, download_video_button], # 更新输出以包含视频 137 | show_progress=True, 138 | ) 139 | 140 | 141 | if __name__ == "__main__": 142 | # demo.queue(concurrency_count=8) 143 | demo.launch(share=True, enable_queue=True) 144 | -------------------------------------------------------------------------------- /assets/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/demo.gif 
-------------------------------------------------------------------------------- /assets/driving_audio/1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/driving_audio/1.wav -------------------------------------------------------------------------------- /assets/driving_audio/2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/driving_audio/2.wav -------------------------------------------------------------------------------- /assets/driving_audio/3.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/driving_audio/3.wav -------------------------------------------------------------------------------- /assets/driving_audio/4.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/driving_audio/4.wav -------------------------------------------------------------------------------- /assets/driving_audio/5.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/driving_audio/5.wav -------------------------------------------------------------------------------- /assets/driving_audio/6.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/driving_audio/6.wav -------------------------------------------------------------------------------- 
/assets/driving_video/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/driving_video/.DS_Store -------------------------------------------------------------------------------- /assets/driving_video/1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/driving_video/1.mp4 -------------------------------------------------------------------------------- /assets/driving_video/2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/driving_video/2.mp4 -------------------------------------------------------------------------------- /assets/driving_video/3.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/driving_video/3.mp4 -------------------------------------------------------------------------------- /assets/driving_video/4.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/driving_video/4.mp4 -------------------------------------------------------------------------------- /assets/driving_video/5.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/driving_video/5.mp4 -------------------------------------------------------------------------------- /assets/driving_video/6.mp4: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/driving_video/6.mp4 -------------------------------------------------------------------------------- /assets/driving_video/7.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/driving_video/7.mp4 -------------------------------------------------------------------------------- /assets/driving_video/8.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/driving_video/8.mp4 -------------------------------------------------------------------------------- /assets/gradio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/gradio.png -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/logo.png -------------------------------------------------------------------------------- /assets/ref_images/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/1.png -------------------------------------------------------------------------------- /assets/ref_images/10.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/10.png -------------------------------------------------------------------------------- /assets/ref_images/11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/11.png -------------------------------------------------------------------------------- /assets/ref_images/12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/12.png -------------------------------------------------------------------------------- /assets/ref_images/13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/13.png -------------------------------------------------------------------------------- /assets/ref_images/14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/14.png -------------------------------------------------------------------------------- /assets/ref_images/15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/15.png -------------------------------------------------------------------------------- /assets/ref_images/16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/16.png 
-------------------------------------------------------------------------------- /assets/ref_images/17.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/17.png -------------------------------------------------------------------------------- /assets/ref_images/18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/18.png -------------------------------------------------------------------------------- /assets/ref_images/19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/19.png -------------------------------------------------------------------------------- /assets/ref_images/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/2.png -------------------------------------------------------------------------------- /assets/ref_images/20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/20.png -------------------------------------------------------------------------------- /assets/ref_images/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/3.png -------------------------------------------------------------------------------- /assets/ref_images/4.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/4.png -------------------------------------------------------------------------------- /assets/ref_images/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/5.png -------------------------------------------------------------------------------- /assets/ref_images/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/6.png -------------------------------------------------------------------------------- /assets/ref_images/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/7.png -------------------------------------------------------------------------------- /assets/ref_images/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/assets/ref_images/8.png -------------------------------------------------------------------------------- /diffposetalk/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class PositionalEncoding(nn.Module): # adds rows of a fixed sinusoidal table (buffer `pe`, shape (1, max_len, d_model)) to the input, then applies dropout 9 | def __init__(self, d_model, dropout=0.1, max_len=600): 10 | super().__init__() 11 | self.dropout = nn.Dropout(p=dropout) 12 | # vanilla sinusoidal encoding 13 | pe = torch.zeros(max_len, d_model) 14 | position = 
torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 15 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 16 | pe[:, 0::2] = torch.sin(position * div_term) 17 | pe[:, 1::2] = torch.cos(position * div_term) 18 | pe = pe.unsqueeze(0) 19 | self.register_buffer('pe', pe) 20 | 21 | def forward(self, x): # assumes x is (batch, T, d_model) - TODO confirm with callers 22 | x = x + self.pe[:, x.shape[1], :] # NOTE(review): integer index selects the single row at position x.shape[1] (not a :x.shape[1] slice), so every position receives the same encoding and x.shape[1] >= max_len raises IndexError; pretrained checkpoints may depend on this behaviour - confirm before changing 23 | return self.dropout(x) 24 | 25 | 26 | def enc_dec_mask(T, S, frame_width=2, expansion=0, device='cuda'): # (T, S) bool mask: row i is False on columns [max(0, (i - expansion) * frame_width), (i + expansion + 1) * frame_width) and True elsewhere 27 | mask = torch.ones(T, S) 28 | for i in range(T): 29 | mask[i, max(0, (i - expansion) * frame_width):(i + expansion + 1) * frame_width] = 0 30 | return (mask == 1).to(device=device) # presumably True = position may not be attended (PyTorch bool attn_mask convention) - confirm at call site 31 | 32 | 33 | def pad_audio(audio, audio_unit=320, pad_threshold=80): # symmetrically pads 2D (batch, samples) audio by side_len samples per side; returns it unchanged when side_len <= 0 34 | batch_size, audio_len = audio.shape 35 | n_units = audio_len // audio_unit 36 | side_len = math.ceil((audio_unit * n_units + pad_threshold - audio_len) / 2) # equals ceil((pad_threshold - audio_len % audio_unit) / 2); negative when the remainder exceeds pad_threshold 37 | if side_len >= 0: 38 | reflect_len = side_len // 2 39 | replicate_len = side_len % 2 40 | if reflect_len > 0: 41 | audio = F.pad(audio, (reflect_len, reflect_len), mode='reflect') 42 | audio = F.pad(audio, (reflect_len, reflect_len), mode='reflect') # second reflect pass: per-side total becomes 2 * reflect_len + replicate_len == side_len 43 | if replicate_len > 0: 44 | audio = F.pad(audio, (1, 1), mode='replicate') 45 | 46 | return audio 47 | 
-------------------------------------------------------------------------------- /diffposetalk/diffposetalk.py: -------------------------------------------------------------------------------- 1 | import math 2 | import tempfile 3 | import warnings 4 | from pathlib import Path 5 | 6 | import cv2 7 | import librosa 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | from tqdm import tqdm 12 | from pydantic import BaseModel 13 | 14 | from .diff_talking_head import DiffTalkingHead 15 | from .utils import NullableArgs, coef_dict_to_vertices, get_coef_dict 16 | from .utils.media import combine_video_and_audio, convert_video, reencode_audio 17 | 18 | warnings.filterwarnings('ignore', 
message='PySoundFile failed. Trying audioread instead.') 19 | 20 | class DiffPoseTalkConfig(BaseModel): 21 | no_context_audio_feat: bool = False 22 | model_path: str = "pretrained_models/diffposetalk/iter_0110000.pt" # DPT/head-SA-hubert-WM 23 | coef_stats: str = "pretrained_models/diffposetalk/stats_train.npz" 24 | style_path: str = "pretrained_models/diffposetalk/style/L4H4-T0.1-BS32/iter_0034000/normal.npy" 25 | dynamic_threshold_ratio: float = 0.99 26 | dynamic_threshold_min: float = 1.0 27 | dynamic_threshold_max: float = 4.0 28 | scale_audio: float = 1.15 29 | scale_style: float = 3.0 30 | 31 | class DiffPoseTalk: 32 | def __init__(self, config: DiffPoseTalkConfig = DiffPoseTalkConfig(), device="cuda"): 33 | self.cfg = config 34 | self.device = device 35 | 36 | self.no_context_audio_feat = self.cfg.no_context_audio_feat 37 | model_data = torch.load(self.cfg.model_path, map_location=self.device) 38 | 39 | self.model_args = NullableArgs(model_data['args']) 40 | self.model = DiffTalkingHead(self.model_args, self.device) 41 | model_data['model'].pop('denoising_net.TE.pe') 42 | self.model.load_state_dict(model_data['model'], strict=False) 43 | self.model.to(self.device) 44 | self.model.eval() 45 | 46 | self.use_indicator = self.model_args.use_indicator 47 | self.rot_repr = self.model_args.rot_repr 48 | self.predict_head_pose = not self.model_args.no_head_pose 49 | if self.model.use_style: 50 | style_dir = Path(self.model_args.style_enc_ckpt) 51 | style_dir = Path(*style_dir.with_suffix('').parts[-3::2]) 52 | self.style_dir = style_dir 53 | 54 | # sequence 55 | self.n_motions = self.model_args.n_motions 56 | self.n_prev_motions = self.model_args.n_prev_motions 57 | self.fps = self.model_args.fps 58 | self.audio_unit = 16000. 
/ self.fps # num of samples per frame 59 | self.n_audio_samples = round(self.audio_unit * self.n_motions) 60 | self.pad_mode = self.model_args.pad_mode 61 | 62 | self.coef_stats = dict(np.load(self.cfg.coef_stats)) 63 | self.coef_stats = {k: torch.from_numpy(v).to(self.device) for k, v in self.coef_stats.items()} 64 | 65 | if self.cfg.dynamic_threshold_ratio > 0: 66 | self.dynamic_threshold = (self.cfg.dynamic_threshold_ratio, self.cfg.dynamic_threshold_min, 67 | self.cfg.dynamic_threshold_max) 68 | else: 69 | self.dynamic_threshold = None 70 | 71 | 72 | def infer_from_file(self, audio_path, shape_coef): 73 | n_repetitions = 1 74 | cfg_mode = None 75 | cfg_cond = self.model.guiding_conditions 76 | cfg_scale = [] 77 | for cond in cfg_cond: 78 | if cond == 'audio': 79 | cfg_scale.append(self.cfg.scale_audio) 80 | elif cond == 'style': 81 | cfg_scale.append(self.cfg.scale_style) 82 | 83 | coef_dict = self.infer_coeffs(audio_path, shape_coef, self.cfg.style_path, n_repetitions, 84 | cfg_mode, cfg_cond, cfg_scale, include_shape=True) 85 | return coef_dict 86 | 87 | @torch.no_grad() 88 | def infer_coeffs(self, audio, shape_coef, style_feat=None, n_repetitions=1, 89 | cfg_mode=None, cfg_cond=None, cfg_scale=1.15, include_shape=False): 90 | # Returns dict[str, (n_repetitions, L, *)] 91 | # Step 1: Preprocessing 92 | # Preprocess audio 93 | if isinstance(audio, (str, Path)): 94 | audio, _ = librosa.load(audio, sr=16000, mono=True) 95 | if isinstance(audio, np.ndarray): 96 | audio = torch.from_numpy(audio).to(self.device) 97 | assert audio.ndim == 1, 'Audio must be 1D tensor.' 
98 | audio_mean, audio_std = torch.mean(audio), torch.std(audio) 99 | audio = (audio - audio_mean) / (audio_std + 1e-5) 100 | 101 | # Preprocess shape coefficient 102 | if isinstance(shape_coef, (str, Path)): 103 | shape_coef = np.load(shape_coef) 104 | if not isinstance(shape_coef, np.ndarray): 105 | shape_coef = shape_coef['shape'] 106 | if isinstance(shape_coef, np.ndarray): 107 | shape_coef = torch.from_numpy(shape_coef).float().to(self.device) 108 | assert shape_coef.ndim <= 2, 'Shape coefficient must be 1D or 2D tensor.' 109 | if shape_coef.ndim > 1: 110 | # use the first frame as the shape coefficient 111 | shape_coef = shape_coef[0] 112 | original_shape_coef = shape_coef.clone() 113 | if self.coef_stats is not None: 114 | shape_coef = (shape_coef - self.coef_stats['shape_mean']) / self.coef_stats['shape_std'] 115 | shape_coef = shape_coef.unsqueeze(0).expand(n_repetitions, -1) 116 | 117 | # Preprocess style feature if given 118 | if style_feat is not None: 119 | assert self.model.use_style 120 | if isinstance(style_feat, (str, Path)): 121 | style_feat = Path(style_feat) 122 | if not style_feat.exists() and not style_feat.is_absolute(): 123 | style_feat = style_feat.parent / self.style_dir / style_feat.name 124 | style_feat = np.load(style_feat) 125 | if not isinstance(style_feat, np.ndarray): 126 | style_feat = style_feat['style'] 127 | if isinstance(style_feat, np.ndarray): 128 | style_feat = torch.from_numpy(style_feat).float().to(self.device) 129 | assert style_feat.ndim == 1, 'Style feature must be 1D tensor.' 
130 | style_feat = style_feat.unsqueeze(0).expand(n_repetitions, -1) 131 | 132 | # Step 2: Predict motion coef 133 | # divide into synthesize units and do synthesize 134 | clip_len = int(len(audio) / 16000 * self.fps) 135 | stride = self.n_motions 136 | if clip_len <= self.n_motions: 137 | n_subdivision = 1 138 | else: 139 | n_subdivision = math.ceil(clip_len / stride) 140 | 141 | # Prepare audio input 142 | n_padding_audio_samples = self.n_audio_samples * n_subdivision - len(audio) 143 | n_padding_frames = math.ceil(n_padding_audio_samples / self.audio_unit) 144 | if n_padding_audio_samples > 0: 145 | if self.pad_mode == 'zero': 146 | padding_value = 0 147 | elif self.pad_mode == 'replicate': 148 | padding_value = audio[-1] 149 | else: 150 | raise ValueError(f'Unknown pad mode: {self.pad_mode}') 151 | audio = F.pad(audio, (0, n_padding_audio_samples), value=padding_value) 152 | 153 | if not self.no_context_audio_feat: 154 | audio_feat = self.model.extract_audio_feature(audio.unsqueeze(0), self.n_motions * n_subdivision) 155 | 156 | # Generate `self.n_motions` new frames at one time, and use the last `self.n_prev_motions` frames 157 | # from the previous generation as the initial motion condition 158 | coef_list = [] 159 | for i in range(0, n_subdivision): 160 | start_idx = i * stride 161 | end_idx = start_idx + self.n_motions 162 | indicator = torch.ones((n_repetitions, self.n_motions)).to(self.device) if self.use_indicator else None 163 | if indicator is not None and i == n_subdivision - 1 and n_padding_frames > 0: 164 | indicator[:, -n_padding_frames:] = 0 165 | if not self.no_context_audio_feat: 166 | audio_in = audio_feat[:, start_idx:end_idx].expand(n_repetitions, -1, -1) 167 | else: 168 | audio_in = audio[round(start_idx * self.audio_unit):round(end_idx * self.audio_unit)].unsqueeze(0) 169 | 170 | # generate motion coefficients 171 | if i == 0: 172 | # -> (N, L, d_motion=n_code_per_frame * code_dim) 173 | motion_feat, noise, prev_audio_feat = 
self.model.sample(audio_in, shape_coef, style_feat, 174 | indicator=indicator, cfg_mode=cfg_mode, 175 | cfg_cond=cfg_cond, cfg_scale=cfg_scale, 176 | dynamic_threshold=self.dynamic_threshold) 177 | else: 178 | motion_feat, noise, prev_audio_feat = self.model.sample(audio_in, shape_coef, style_feat, 179 | prev_motion_feat, prev_audio_feat, noise, 180 | indicator=indicator, cfg_mode=cfg_mode, 181 | cfg_cond=cfg_cond, cfg_scale=cfg_scale, 182 | dynamic_threshold=self.dynamic_threshold) 183 | prev_motion_feat = motion_feat[:, -self.n_prev_motions:].clone() 184 | prev_audio_feat = prev_audio_feat[:, -self.n_prev_motions:] 185 | 186 | motion_coef = motion_feat 187 | if i == n_subdivision - 1 and n_padding_frames > 0: 188 | motion_coef = motion_coef[:, :-n_padding_frames] # delete padded frames 189 | coef_list.append(motion_coef) 190 | 191 | motion_coef = torch.cat(coef_list, dim=1) 192 | 193 | # Step 3: restore to coef dict 194 | coef_dict = get_coef_dict(motion_coef, None, self.coef_stats, self.predict_head_pose, self.rot_repr) 195 | if include_shape: 196 | coef_dict['shape'] = original_shape_coef[None, None].expand(n_repetitions, motion_coef.shape[1], -1) 197 | return self.coef_to_a1_format(coef_dict) 198 | 199 | def coef_to_a1_format(self, coef_dict): 200 | n_frames = coef_dict['exp'].shape[1] 201 | new_coef_dict = [] 202 | for i in range(n_frames): 203 | 204 | new_coef_dict.append({ 205 | "expression_params": coef_dict["exp"][0, i:i+1], 206 | "jaw_params": coef_dict["pose"][0, i:i+1, 3:], 207 | "eye_pose_params": torch.zeros(1, 6).type_as(coef_dict["pose"]), 208 | "pose_params": coef_dict["pose"][0, i:i+1, :3], 209 | "eyelid_params": None 210 | }) 211 | return new_coef_dict 212 | 213 | 214 | 215 | 216 | 217 | @staticmethod 218 | def _pad_coef(coef, n_frames, elem_ndim=1): 219 | if coef.ndim == elem_ndim: 220 | coef = coef[None] 221 | elem_shape = coef.shape[1:] 222 | if coef.shape[0] >= n_frames: 223 | new_coef = coef[:n_frames] 224 | else: 225 | # repeat the last 
coef frame 226 | new_coef = torch.cat([coef, coef[[-1]].expand(n_frames - coef.shape[0], *elem_shape)], dim=0) 227 | return new_coef # (n_frames, *elem_shape) 228 | 229 | -------------------------------------------------------------------------------- /diffposetalk/hubert.py: -------------------------------------------------------------------------------- 1 | from transformers import HubertModel 2 | from transformers.modeling_outputs import BaseModelOutput 3 | 4 | from .wav2vec2 import linear_interpolation 5 | 6 | _CONFIG_FOR_DOC = 'HubertConfig' 7 | 8 | 9 | class HubertModel(HubertModel): 10 | def __init__(self, config): 11 | super().__init__(config) 12 | 13 | def forward(self, input_values, output_fps=25, attention_mask=None, output_attentions=None, 14 | output_hidden_states=None, return_dict=None, frame_num=None): 15 | self.config.output_attentions = True 16 | 17 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 18 | output_hidden_states = ( 19 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states) 20 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 21 | 22 | extract_features = self.feature_extractor(input_values) # (N, C, L) 23 | # Resample the audio feature @ 50 fps to `output_fps`. 
24 | if frame_num is not None: 25 | extract_features_len = round(frame_num * 50 / output_fps) 26 | extract_features = extract_features[:, :, :extract_features_len] 27 | extract_features = linear_interpolation(extract_features, 50, output_fps, output_len=frame_num) 28 | extract_features = extract_features.transpose(1, 2) # (N, L, C) 29 | 30 | if attention_mask is not None: 31 | # compute reduced attention_mask corresponding to feature vectors 32 | attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask) 33 | 34 | hidden_states = self.feature_projection(extract_features) 35 | hidden_states = self._mask_hidden_states(hidden_states) 36 | 37 | encoder_outputs = self.encoder( 38 | hidden_states, 39 | attention_mask=attention_mask, 40 | output_attentions=output_attentions, 41 | output_hidden_states=output_hidden_states, 42 | return_dict=return_dict, 43 | ) 44 | 45 | hidden_states = encoder_outputs[0] 46 | 47 | if not return_dict: 48 | return (hidden_states,) + encoder_outputs[1:] 49 | 50 | return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_outputs.hidden_states, 51 | attentions=encoder_outputs.attentions, ) 52 | -------------------------------------------------------------------------------- /diffposetalk/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .common import * 2 | -------------------------------------------------------------------------------- /diffposetalk/utils/media.py: -------------------------------------------------------------------------------- 1 | import shlex 2 | import subprocess 3 | from pathlib import Path 4 | 5 | 6 | def combine_video_and_audio(video_file, audio_file, output, quality=17, copy_audio=True): 7 | audio_codec = '-c:a copy' if copy_audio else '' 8 | cmd = f'ffmpeg -i {video_file} -i {audio_file} -c:v libx264 -crf {quality} -pix_fmt yuv420p ' \ 9 | f'{audio_codec} -fflags +shortest -y -hide_banner -loglevel 
error {output}' 10 | assert subprocess.run(shlex.split(cmd)).returncode == 0 11 | 12 | 13 | def combine_frames_and_audio(frame_files, audio_file, fps, output, quality=17): 14 | cmd = f'ffmpeg -framerate {fps} -i {frame_files} -i {audio_file} -c:v libx264 -crf {quality} -pix_fmt yuv420p ' \ 15 | f'-c:a copy -fflags +shortest -y -hide_banner -loglevel error {output}' 16 | assert subprocess.run(shlex.split(cmd)).returncode == 0 17 | 18 | 19 | def convert_video(video_file, output, quality=17): 20 | cmd = f'ffmpeg -i {video_file} -c:v libx264 -crf {quality} -pix_fmt yuv420p ' \ 21 | f'-fflags +shortest -y -hide_banner -loglevel error {output}' 22 | assert subprocess.run(shlex.split(cmd)).returncode == 0 23 | 24 | 25 | def reencode_audio(audio_file, output): 26 | cmd = f'ffmpeg -i {audio_file} -y -hide_banner -loglevel error {output}' 27 | assert subprocess.run(shlex.split(cmd)).returncode == 0 28 | 29 | 30 | def extract_frames(filename, output_dir, quality=1): 31 | output_dir = Path(output_dir) 32 | output_dir.mkdir(parents=True, exist_ok=True) 33 | cmd = f'ffmpeg -i {filename} -qmin 1 -qscale:v {quality} -y -start_number 0 -hide_banner -loglevel error ' \ 34 | f'{output_dir / "%06d.jpg"}' 35 | assert subprocess.run(shlex.split(cmd)).returncode == 0 36 | -------------------------------------------------------------------------------- /diffposetalk/utils/renderer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | import cv2 5 | import kiui.mesh 6 | import numpy as np 7 | 8 | # os.environ['PYOPENGL_PLATFORM'] = 'osmesa' # osmesa or egl 9 | os.environ['PYOPENGL_PLATFORM'] = 'egl' 10 | import pyrender 11 | import trimesh 12 | # from psbody.mesh import Mesh 13 | 14 | 15 | class MeshRenderer: 16 | def __init__(self, size, fov=16 / 180 * np.pi, camera_pose=None, light_pose=None, black_bg=False): 17 | # Camera 18 | self.frustum = {'near': 0.01, 'far': 3.0} 19 | self.camera = 
pyrender.PerspectiveCamera(yfov=fov, znear=self.frustum['near'], 20 | zfar=self.frustum['far'], aspectRatio=1.0) 21 | 22 | # Material 23 | self.primitive_material = pyrender.material.MetallicRoughnessMaterial( 24 | alphaMode='BLEND', 25 | baseColorFactor=[0.3, 0.3, 0.3, 1.0], 26 | metallicFactor=0.8, 27 | roughnessFactor=0.8 28 | ) 29 | 30 | # Lighting 31 | light_color = np.array([1., 1., 1.]) 32 | self.light = pyrender.DirectionalLight(color=light_color, intensity=2) 33 | self.light_angle = np.pi / 6.0 34 | 35 | # Scene 36 | self.scene = None 37 | self._init_scene(black_bg) 38 | 39 | # add camera and lighting 40 | self._init_camera(camera_pose) 41 | self._init_lighting(light_pose) 42 | 43 | # Renderer 44 | self.renderer = pyrender.OffscreenRenderer(*size, point_size=1.0) 45 | 46 | def _init_scene(self, black_bg=False): 47 | if black_bg: 48 | self.scene = pyrender.Scene(ambient_light=[.2, .2, .2], bg_color=[0, 0, 0]) 49 | else: 50 | self.scene = pyrender.Scene(ambient_light=[.2, .2, .2], bg_color=[255, 255, 255]) 51 | 52 | def _init_camera(self, camera_pose=None): 53 | if camera_pose is None: 54 | camera_pose = np.eye(4) 55 | camera_pose[:3, 3] = np.array([0, 0, 1]) 56 | self.camera_pose = camera_pose.copy() 57 | self.camera_node = self.scene.add(self.camera, pose=camera_pose) 58 | 59 | def _init_lighting(self, light_pose=None): 60 | if light_pose is None: 61 | light_pose = np.eye(4) 62 | light_pose[:3, 3] = np.array([0, 0, 1]) 63 | self.light_pose = light_pose.copy() 64 | 65 | light_poses = self._get_light_poses(self.light_angle, light_pose) 66 | self.light_nodes = [self.scene.add(self.light, pose=light_pose) for light_pose in light_poses] 67 | 68 | def set_camera_pose(self, camera_pose): 69 | self.camera_pose = camera_pose.copy() 70 | self.scene.set_pose(self.camera_node, pose=camera_pose) 71 | 72 | def set_lighting_pose(self, light_pose): 73 | self.light_pose = light_pose.copy() 74 | 75 | light_poses = self._get_light_poses(self.light_angle, light_pose) 76 | for 
light_node, light_pose in zip(self.light_nodes, light_poses): 77 | self.scene.set_pose(light_node, pose=light_pose) 78 | 79 | def render_mesh(self, v, f, t_center, rot=np.zeros(3), tex_img=None, tex_uv=None, 80 | camera_pose=None, light_pose=None): 81 | # Prepare mesh 82 | v[:] = cv2.Rodrigues(rot)[0].dot((v - t_center).T).T + t_center 83 | if tex_img is not None: 84 | tex = pyrender.Texture(source=tex_img, source_channels='RGB') 85 | tex_material = pyrender.material.MetallicRoughnessMaterial(baseColorTexture=tex) 86 | from kiui.mesh import Mesh 87 | import torch 88 | mesh = Mesh( 89 | v=torch.from_numpy(v), 90 | f=torch.from_numpy(f), 91 | vt=tex_uv['vt'], 92 | ft=tex_uv['ft'] 93 | ) 94 | with tempfile.NamedTemporaryFile(suffix='.obj') as f: 95 | mesh.write_obj(f.name) 96 | tri_mesh = trimesh.load(f.name, process=False) 97 | return tri_mesh 98 | # tri_mesh = self._pyrender_mesh_workaround(mesh) 99 | render_mesh = pyrender.Mesh.from_trimesh(tri_mesh, material=tex_material) 100 | else: 101 | tri_mesh = trimesh.Trimesh(vertices=v, faces=f) 102 | render_mesh = pyrender.Mesh.from_trimesh(tri_mesh, material=self.primitive_material, smooth=True) 103 | mesh_node = self.scene.add(render_mesh, pose=np.eye(4)) 104 | 105 | # Change camera and lighting pose if necessary 106 | if camera_pose is not None: 107 | self.set_camera_pose(camera_pose) 108 | if light_pose is not None: 109 | self.set_lighting_pose(light_pose) 110 | 111 | # Render 112 | flags = pyrender.RenderFlags.SKIP_CULL_FACES 113 | color, depth = self.renderer.render(self.scene, flags=flags) 114 | 115 | # Remove mesh 116 | self.scene.remove_node(mesh_node) 117 | 118 | return color, depth 119 | 120 | @staticmethod 121 | def _get_light_poses(light_angle, light_pose): 122 | light_poses = [] 123 | init_pos = light_pose[:3, 3].copy() 124 | 125 | light_poses.append(light_pose.copy()) 126 | 127 | light_pose[:3, 3] = cv2.Rodrigues(np.array([light_angle, 0, 0]))[0].dot(init_pos) 128 | light_poses.append(light_pose.copy()) 129 
| 130 | light_pose[:3, 3] = cv2.Rodrigues(np.array([-light_angle, 0, 0]))[0].dot(init_pos) 131 | light_poses.append(light_pose.copy()) 132 | 133 | light_pose[:3, 3] = cv2.Rodrigues(np.array([0, -light_angle, 0]))[0].dot(init_pos) 134 | light_poses.append(light_pose.copy()) 135 | 136 | light_pose[:3, 3] = cv2.Rodrigues(np.array([0, light_angle, 0]))[0].dot(init_pos) 137 | light_poses.append(light_pose.copy()) 138 | 139 | return light_poses 140 | 141 | @staticmethod 142 | def _pyrender_mesh_workaround(mesh): 143 | # Workaround as pyrender requires number of vertices and uv coordinates to be the same 144 | with tempfile.NamedTemporaryFile(suffix='.obj') as f: 145 | mesh.write_obj(f.name) 146 | tri_mesh = trimesh.load(f.name, process=False) 147 | return tri_mesh 148 | -------------------------------------------------------------------------------- /diffposetalk/wav2vec2.py: -------------------------------------------------------------------------------- 1 | from packaging import version 2 | from typing import Optional, Tuple 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | import transformers 8 | from transformers import Wav2Vec2Model 9 | from transformers.modeling_outputs import BaseModelOutput 10 | 11 | _CONFIG_FOR_DOC = 'Wav2Vec2Config' 12 | 13 | 14 | # the implementation of Wav2Vec2Model is borrowed from 15 | # https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model 16 | # initialize our encoder with the pre-trained wav2vec 2.0 weights. 
17 | def _compute_mask_indices(shape: Tuple[int, int], mask_prob: float, mask_length: int, 18 | attention_mask: Optional[torch.Tensor] = None, min_masks: int = 0, ) -> np.ndarray: 19 | bsz, all_sz = shape 20 | mask = np.full((bsz, all_sz), False) 21 | 22 | all_num_mask = int(mask_prob * all_sz / float(mask_length) + np.random.rand()) 23 | all_num_mask = max(min_masks, all_num_mask) 24 | mask_idcs = [] 25 | padding_mask = attention_mask.ne(1) if attention_mask is not None else None 26 | for i in range(bsz): 27 | if padding_mask is not None: 28 | sz = all_sz - padding_mask[i].long().sum().item() 29 | num_mask = int(mask_prob * sz / float(mask_length) + np.random.rand()) 30 | num_mask = max(min_masks, num_mask) 31 | else: 32 | sz = all_sz 33 | num_mask = all_num_mask 34 | 35 | lengths = np.full(num_mask, mask_length) 36 | 37 | if sum(lengths) == 0: 38 | lengths[0] = min(mask_length, sz - 1) 39 | 40 | min_len = min(lengths) 41 | if sz - min_len <= num_mask: 42 | min_len = sz - num_mask - 1 43 | 44 | mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) 45 | mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])]) 46 | mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) 47 | 48 | min_len = min([len(m) for m in mask_idcs]) 49 | for i, mask_idc in enumerate(mask_idcs): 50 | if len(mask_idc) > min_len: 51 | mask_idc = np.random.choice(mask_idc, min_len, replace=False) 52 | mask[i, mask_idc] = True 53 | return mask 54 | 55 | 56 | # linear interpolation layer 57 | def linear_interpolation(features, input_fps, output_fps, output_len=None): 58 | # features: (N, C, L) 59 | seq_len = features.shape[2] / float(input_fps) 60 | if output_len is None: 61 | output_len = int(seq_len * output_fps) 62 | output_features = F.interpolate(features, size=output_len, align_corners=False, mode='linear') 63 | return output_features 64 | 65 | 66 | class Wav2Vec2Model(Wav2Vec2Model): 67 | def __init__(self, config): 68 | 
super().__init__(config) 69 | self.is_old_version = version.parse(transformers.__version__) < version.parse('4.7.0') 70 | 71 | def forward(self, input_values, output_fps=25, attention_mask=None, output_attentions=None, 72 | output_hidden_states=None, return_dict=None, frame_num=None): 73 | self.config.output_attentions = True 74 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 75 | output_hidden_states = ( 76 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states) 77 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 78 | 79 | hidden_states = self.feature_extractor(input_values) # (N, C, L) 80 | # Resample the audio feature @ 50 fps to `output_fps`. 81 | if frame_num is not None: 82 | hidden_states_len = round(frame_num * 50 / output_fps) 83 | hidden_states = hidden_states[:, :, :hidden_states_len] 84 | hidden_states = linear_interpolation(hidden_states, 50, output_fps, output_len=frame_num) 85 | hidden_states = hidden_states.transpose(1, 2) # (N, L, C) 86 | 87 | if attention_mask is not None: 88 | output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)) 89 | attention_mask = torch.zeros(hidden_states.shape[:2], dtype=hidden_states.dtype, 90 | device=hidden_states.device) 91 | attention_mask[(torch.arange(attention_mask.shape[0], device=hidden_states.device), output_lengths - 1)] = 1 92 | attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() 93 | 94 | if self.is_old_version: 95 | hidden_states = self.feature_projection(hidden_states) 96 | else: 97 | hidden_states = self.feature_projection(hidden_states)[0] 98 | 99 | if self.config.apply_spec_augment and self.training: 100 | batch_size, sequence_length, hidden_size = hidden_states.size() 101 | if self.config.mask_time_prob > 0: 102 | mask_time_indices = _compute_mask_indices((batch_size, sequence_length), self.config.mask_time_prob, 103 | 
self.config.mask_time_length, attention_mask=attention_mask, 104 | min_masks=2, ) 105 | hidden_states[torch.from_numpy(mask_time_indices)] = self.masked_spec_embed.to(hidden_states.dtype) 106 | if self.config.mask_feature_prob > 0: 107 | mask_feature_indices = _compute_mask_indices((batch_size, hidden_size), self.config.mask_feature_prob, 108 | self.config.mask_feature_length, ) 109 | mask_feature_indices = torch.from_numpy(mask_feature_indices).to(hidden_states.device) 110 | hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0 111 | encoder_outputs = self.encoder(hidden_states, attention_mask=attention_mask, 112 | output_attentions=output_attentions, output_hidden_states=output_hidden_states, 113 | return_dict=return_dict, ) 114 | hidden_states = encoder_outputs[0] 115 | if not return_dict: 116 | return (hidden_states,) + encoder_outputs[1:] 117 | 118 | return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_outputs.hidden_states, 119 | attentions=encoder_outputs.attentions, ) 120 | -------------------------------------------------------------------------------- /eval/arc_score.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from insightface.app import FaceAnalysis 4 | from insightface.utils import face_align 5 | from PIL import Image 6 | from torchvision import models, transforms 7 | from curricularface import get_model 8 | import cv2 9 | import numpy as np 10 | import numpy 11 | 12 | 13 | def matrix_sqrt(matrix): 14 | eigenvalues, eigenvectors = torch.linalg.eigh(matrix) 15 | sqrt_eigenvalues = torch.sqrt(torch.clamp(eigenvalues, min=0)) 16 | sqrt_matrix = (eigenvectors * sqrt_eigenvalues).mm(eigenvectors.T) 17 | return sqrt_matrix 18 | 19 | def sample_video_frames(video_path, num_frames=16): 20 | cap = cv2.VideoCapture(video_path) 21 | total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 22 | frame_indices = np.linspace(0, total_frames - 1, 
num_frames, dtype=int) 23 | 24 | frames = [] 25 | for idx in frame_indices: 26 | cap.set(cv2.CAP_PROP_POS_FRAMES, idx) 27 | ret, frame = cap.read() 28 | if ret: 29 | # print(frame.shape) 30 | #if frame.shape[1] > 1024: 31 | # frame = frame[:, 1440:, :] 32 | # print(frame.shape) 33 | frames.append(frame) 34 | cap.release() 35 | return frames 36 | 37 | 38 | def get_face_keypoints(face_model, image_bgr): 39 | face_info = face_model.get(image_bgr) 40 | if len(face_info) > 0: 41 | return sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1] 42 | return None 43 | 44 | def load_image(image): 45 | img = image.convert('RGB') 46 | img = transforms.Resize((299, 299))(img) # Resize to Inception input size 47 | img = transforms.ToTensor()(img) 48 | return img.unsqueeze(0) # Add batch dimension 49 | 50 | def calculate_fid(real_activations, fake_activations, device="cuda"): 51 | real_activations_tensor = torch.tensor(real_activations).to(device) 52 | fake_activations_tensor = torch.tensor(fake_activations).to(device) 53 | 54 | mu1 = real_activations_tensor.mean(dim=0) 55 | sigma1 = torch.cov(real_activations_tensor.T) 56 | mu2 = fake_activations_tensor.mean(dim=0) 57 | sigma2 = torch.cov(fake_activations_tensor.T) 58 | 59 | ssdiff = torch.sum((mu1 - mu2) ** 2) 60 | covmean = matrix_sqrt(sigma1.mm(sigma2)) 61 | if torch.is_complex(covmean): 62 | covmean = covmean.real 63 | fid = ssdiff + torch.trace(sigma1 + sigma2 - 2 * covmean) 64 | return fid.item() 65 | 66 | def batch_cosine_similarity(embedding_image, embedding_frames, device="cuda"): 67 | embedding_image = torch.tensor(embedding_image).to(device) 68 | embedding_frames = torch.tensor(embedding_frames).to(device) 69 | return torch.nn.functional.cosine_similarity(embedding_image, embedding_frames, dim=-1).cpu().numpy() 70 | 71 | 72 | def get_activations(images, model, batch_size=16): 73 | model.eval() 74 | activations = [] 75 | with torch.no_grad(): 76 | for i in range(0, 
len(images), batch_size): 77 | batch = images[i:i + batch_size] 78 | pred = model(batch) 79 | activations.append(pred) 80 | activations = torch.cat(activations, dim=0).cpu().numpy() 81 | if activations.shape[0] == 1: 82 | activations = np.repeat(activations, 2, axis=0) 83 | return activations 84 | 85 | def pad_np_bgr_image(np_image, scale=1.25): 86 | assert scale >= 1.0, "scale should be >= 1.0" 87 | pad_scale = scale - 1.0 88 | h, w = np_image.shape[:2] 89 | top = bottom = int(h * pad_scale) 90 | left = right = int(w * pad_scale) 91 | return cv2.copyMakeBorder(np_image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(128, 128, 128)), (left, top) 92 | 93 | 94 | def process_image(face_model, image_path): 95 | if isinstance(image_path, str): 96 | np_faceid_image = np.array(Image.open(image_path).convert("RGB")) 97 | elif isinstance(image_path, numpy.ndarray): 98 | np_faceid_image = image_path 99 | else: 100 | raise TypeError("image_path should be a string or PIL.Image.Image object") 101 | 102 | image_bgr = cv2.cvtColor(np_faceid_image, cv2.COLOR_RGB2BGR) 103 | 104 | face_info = get_face_keypoints(face_model, image_bgr) 105 | if face_info is None: 106 | padded_image, sub_coord = pad_np_bgr_image(image_bgr) 107 | face_info = get_face_keypoints(face_model, padded_image) 108 | if face_info is None: 109 | print("Warning: No face detected in the image. 
Continuing processing...") 110 | return None, None 111 | face_kps = face_info['kps'] 112 | face_kps -= np.array(sub_coord) 113 | else: 114 | face_kps = face_info['kps'] 115 | arcface_embedding = face_info['embedding'] 116 | # print(face_kps) 117 | norm_face = face_align.norm_crop(image_bgr, landmark=face_kps, image_size=224) 118 | align_face = cv2.cvtColor(norm_face, cv2.COLOR_BGR2RGB) 119 | 120 | return align_face, arcface_embedding 121 | 122 | @torch.no_grad() 123 | def inference(face_model, img, device): 124 | img = cv2.resize(img, (112, 112)) 125 | img = np.transpose(img, (2, 0, 1)) 126 | img = torch.from_numpy(img).unsqueeze(0).float().to(device) 127 | img.div_(255).sub_(0.5).div_(0.5) 128 | embedding = face_model(img).detach().cpu().numpy()[0] 129 | return embedding / np.linalg.norm(embedding) 130 | 131 | 132 | def process_video(video_path, face_arc_model, face_cur_model, fid_model, arcface_image_embedding, cur_image_embedding, real_activations, device): 133 | video_frames = sample_video_frames(video_path, num_frames=16) 134 | #print(video_frames) 135 | # Initialize lists to store the scores 136 | cur_scores = [] 137 | arc_scores = [] 138 | fid_face = [] 139 | 140 | for frame in video_frames: 141 | # Convert to RGB once at the beginning 142 | frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 143 | 144 | # Process the frame for ArcFace embeddings 145 | align_face_frame, arcface_frame_embedding = process_image(face_arc_model, frame_rgb) 146 | 147 | # Skip if alignment fails 148 | if align_face_frame is None: 149 | continue 150 | 151 | # Perform inference for current face model 152 | cur_embedding_frame = inference(face_cur_model, align_face_frame, device) 153 | 154 | # Compute cosine similarity for cur_score and arc_score in a compact manner 155 | cur_score = max(0.0, batch_cosine_similarity(cur_image_embedding, cur_embedding_frame, device=device).item()) 156 | arc_score = max(0.0, batch_cosine_similarity(arcface_image_embedding, arcface_frame_embedding, 
device=device).item()) 157 | 158 | # Process FID score 159 | align_face_frame_pil = Image.fromarray(align_face_frame) 160 | fake_image = load_image(align_face_frame_pil).to(device) 161 | fake_activations = get_activations(fake_image, fid_model) 162 | fid_score = calculate_fid(real_activations, fake_activations, device) 163 | 164 | # Collect scores 165 | fid_face.append(fid_score) 166 | cur_scores.append(cur_score) 167 | arc_scores.append(arc_score) 168 | 169 | # Aggregate results with default values for empty lists 170 | avg_cur_score = np.mean(cur_scores) if cur_scores else 0.0 171 | avg_arc_score = np.mean(arc_scores) if arc_scores else 0.0 172 | avg_fid_score = np.mean(fid_face) if fid_face else 0.0 173 | 174 | return avg_cur_score, avg_arc_score, avg_fid_score 175 | 176 | 177 | 178 | def main(): 179 | device = "cuda" 180 | # data_path = "data/SkyActor" 181 | # data_path = "data/LivePotraits" 182 | # data_path = "data/Actor-One" 183 | data_path = "data/FollowYourEmoji" 184 | img_path = "/maindata/data/shared/public/rui.wang/act_review/ref_images" 185 | pre_tag = False 186 | mp4_list = os.listdir(data_path) 187 | print(mp4_list) 188 | 189 | img_list = [] 190 | video_list = [] 191 | for mp4 in mp4_list: 192 | if "mp4" not in mp4: 193 | continue 194 | if pre_tag: 195 | png_path = mp4.split('.')[0].split('-')[0] + ".png" 196 | else: 197 | if "-" in mp4: 198 | png_path = mp4.split('.')[0].split('-')[1] + ".png" 199 | else: 200 | png_path = mp4.split('.')[0].split('_')[1] + ".png" 201 | img_list.append(os.path.join(img_path, png_path)) 202 | video_list.append(os.path.join(data_path, mp4)) 203 | print(img_list) 204 | print(video_list[0]) 205 | 206 | model_path = "eval" 207 | face_arc_path = os.path.join(model_path, "face_encoder") 208 | face_cur_path = os.path.join(face_arc_path, "glint360k_curricular_face_r101_backbone.bin") 209 | 210 | # Initialize FaceEncoder model for face detection and embedding extraction 211 | face_arc_model = FaceAnalysis(root=face_arc_path, 
providers=['CUDAExecutionProvider']) 212 | face_arc_model.prepare(ctx_id=0, det_size=(320, 320)) 213 | 214 | # Load face recognition model 215 | face_cur_model = get_model('IR_101')([112, 112]) 216 | face_cur_model.load_state_dict(torch.load(face_cur_path, map_location="cpu")) 217 | face_cur_model = face_cur_model.to(device) 218 | face_cur_model.eval() 219 | 220 | # Load InceptionV3 model for FID calculation 221 | fid_model = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT) 222 | fid_model.fc = torch.nn.Identity() # Remove final classification layer 223 | fid_model.eval() 224 | fid_model = fid_model.to(device) 225 | 226 | # Process the single video and image pair 227 | # Extract embeddings and features from the image 228 | cur_list, arc_list, fid_list = [], [], [] 229 | for i in range(len(img_list)): 230 | align_face_image, arcface_image_embedding = process_image(face_arc_model, img_list[i]) 231 | 232 | cur_image_embedding = inference(face_cur_model, align_face_image, device) 233 | align_face_image_pil = Image.fromarray(align_face_image) 234 | real_image = load_image(align_face_image_pil).to(device) 235 | real_activations = get_activations(real_image, fid_model) 236 | 237 | # Process the video and calculate scores 238 | cur_score, arc_score, fid_score = process_video( 239 | video_list[i], face_arc_model, face_cur_model, fid_model, 240 | arcface_image_embedding, cur_image_embedding, real_activations, device 241 | ) 242 | print(cur_score, arc_score, fid_score) 243 | cur_list.append(cur_score) 244 | arc_list.append(arc_score) 245 | fid_list.append(fid_score) 246 | # break 247 | print("cur", sum(cur_list)/ len(cur_list)) 248 | print("arc", sum(arc_list)/ len(arc_list)) 249 | print("fid", sum(fid_list)/ len(fid_list)) 250 | 251 | 252 | 253 | main() 254 | -------------------------------------------------------------------------------- /eval/curricularface/__init__.py: -------------------------------------------------------------------------------- 1 | # 
The implementation is adopted from TFace,made pubicly available under the Apache-2.0 license at 2 | # https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone 3 | from .model_irse import IR_18, IR_34, IR_50, IR_101, IR_152, IR_200, IR_SE_50, IR_SE_101, IR_SE_152, IR_SE_200 4 | from .model_resnet import ResNet_50, ResNet_101, ResNet_152 5 | 6 | 7 | _model_dict = { 8 | 'ResNet_50': ResNet_50, 9 | 'ResNet_101': ResNet_101, 10 | 'ResNet_152': ResNet_152, 11 | 'IR_18': IR_18, 12 | 'IR_34': IR_34, 13 | 'IR_50': IR_50, 14 | 'IR_101': IR_101, 15 | 'IR_152': IR_152, 16 | 'IR_200': IR_200, 17 | 'IR_SE_50': IR_SE_50, 18 | 'IR_SE_101': IR_SE_101, 19 | 'IR_SE_152': IR_SE_152, 20 | 'IR_SE_200': IR_SE_200 21 | } 22 | 23 | 24 | def get_model(key): 25 | """ Get different backbone network by key, 26 | support ResNet50, ResNet_101, ResNet_152 27 | IR_18, IR_34, IR_50, IR_101, IR_152, IR_200, 28 | IR_SE_50, IR_SE_101, IR_SE_152, IR_SE_200. 29 | """ 30 | if key in _model_dict.keys(): 31 | return _model_dict[key] 32 | else: 33 | raise KeyError('not support model {}'.format(key)) 34 | -------------------------------------------------------------------------------- /eval/curricularface/common.py: -------------------------------------------------------------------------------- 1 | # The implementation is adopted from TFace,made pubicly available under the Apache-2.0 license at 2 | # https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone/common.py 3 | import torch.nn as nn 4 | from torch.nn import Conv2d, Module, ReLU, Sigmoid 5 | 6 | 7 | def initialize_weights(modules): 8 | """ Weight initilize, conv2d and linear is initialized with kaiming_normal 9 | """ 10 | for m in modules: 11 | if isinstance(m, nn.Conv2d): 12 | nn.init.kaiming_normal_( 13 | m.weight, mode='fan_out', nonlinearity='relu') 14 | if m.bias is not None: 15 | m.bias.data.zero_() 16 | elif isinstance(m, nn.BatchNorm2d): 17 | m.weight.data.fill_(1) 18 | m.bias.data.zero_() 19 | elif 
isinstance(m, nn.Linear): 20 | nn.init.kaiming_normal_( 21 | m.weight, mode='fan_out', nonlinearity='relu') 22 | if m.bias is not None: 23 | m.bias.data.zero_() 24 | 25 | 26 | class Flatten(Module): 27 | """ Flat tensor 28 | """ 29 | 30 | def forward(self, input): 31 | return input.view(input.size(0), -1) 32 | 33 | 34 | class SEModule(Module): 35 | """ SE block 36 | """ 37 | 38 | def __init__(self, channels, reduction): 39 | super(SEModule, self).__init__() 40 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 41 | self.fc1 = Conv2d( 42 | channels, 43 | channels // reduction, 44 | kernel_size=1, 45 | padding=0, 46 | bias=False) 47 | 48 | nn.init.xavier_uniform_(self.fc1.weight.data) 49 | 50 | self.relu = ReLU(inplace=True) 51 | self.fc2 = Conv2d( 52 | channels // reduction, 53 | channels, 54 | kernel_size=1, 55 | padding=0, 56 | bias=False) 57 | 58 | self.sigmoid = Sigmoid() 59 | 60 | def forward(self, x): 61 | module_input = x 62 | x = self.avg_pool(x) 63 | x = self.fc1(x) 64 | x = self.relu(x) 65 | x = self.fc2(x) 66 | x = self.sigmoid(x) 67 | 68 | return module_input * x 69 | -------------------------------------------------------------------------------- /eval/curricularface/model_irse.py: -------------------------------------------------------------------------------- 1 | # The implementation is adopted from TFace,made pubicly available under the Apache-2.0 license at 2 | # https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone/model_irse.py 3 | from collections import namedtuple 4 | 5 | from torch.nn import BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear, MaxPool2d, Module, PReLU, Sequential 6 | 7 | from .common import Flatten, SEModule, initialize_weights 8 | 9 | 10 | class BasicBlockIR(Module): 11 | """ BasicBlock for IRNet 12 | """ 13 | 14 | def __init__(self, in_channel, depth, stride): 15 | super(BasicBlockIR, self).__init__() 16 | if in_channel == depth: 17 | self.shortcut_layer = MaxPool2d(1, stride) 18 | else: 19 | self.shortcut_layer = 
Sequential( 20 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 21 | BatchNorm2d(depth)) 22 | self.res_layer = Sequential( 23 | BatchNorm2d(in_channel), 24 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 25 | BatchNorm2d(depth), PReLU(depth), 26 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 27 | BatchNorm2d(depth)) 28 | 29 | def forward(self, x): 30 | shortcut = self.shortcut_layer(x) 31 | res = self.res_layer(x) 32 | 33 | return res + shortcut 34 | 35 | 36 | class BottleneckIR(Module): 37 | """ BasicBlock with bottleneck for IRNet 38 | """ 39 | 40 | def __init__(self, in_channel, depth, stride): 41 | super(BottleneckIR, self).__init__() 42 | reduction_channel = depth // 4 43 | if in_channel == depth: 44 | self.shortcut_layer = MaxPool2d(1, stride) 45 | else: 46 | self.shortcut_layer = Sequential( 47 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 48 | BatchNorm2d(depth)) 49 | self.res_layer = Sequential( 50 | BatchNorm2d(in_channel), 51 | Conv2d( 52 | in_channel, reduction_channel, (1, 1), (1, 1), 0, bias=False), 53 | BatchNorm2d(reduction_channel), PReLU(reduction_channel), 54 | Conv2d( 55 | reduction_channel, 56 | reduction_channel, (3, 3), (1, 1), 57 | 1, 58 | bias=False), BatchNorm2d(reduction_channel), 59 | PReLU(reduction_channel), 60 | Conv2d(reduction_channel, depth, (1, 1), stride, 0, bias=False), 61 | BatchNorm2d(depth)) 62 | 63 | def forward(self, x): 64 | shortcut = self.shortcut_layer(x) 65 | res = self.res_layer(x) 66 | 67 | return res + shortcut 68 | 69 | 70 | class BasicBlockIRSE(BasicBlockIR): 71 | 72 | def __init__(self, in_channel, depth, stride): 73 | super(BasicBlockIRSE, self).__init__(in_channel, depth, stride) 74 | self.res_layer.add_module('se_block', SEModule(depth, 16)) 75 | 76 | 77 | class BottleneckIRSE(BottleneckIR): 78 | 79 | def __init__(self, in_channel, depth, stride): 80 | super(BottleneckIRSE, self).__init__(in_channel, depth, stride) 81 | self.res_layer.add_module('se_block', 
SEModule(depth, 16)) 82 | 83 | 84 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): 85 | '''A named tuple describing a ResNet block.''' 86 | 87 | 88 | def get_block(in_channel, depth, num_units, stride=2): 89 | return [Bottleneck(in_channel, depth, stride)] + \ 90 | [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] 91 | 92 | 93 | def get_blocks(num_layers): 94 | if num_layers == 18: 95 | blocks = [ 96 | get_block(in_channel=64, depth=64, num_units=2), 97 | get_block(in_channel=64, depth=128, num_units=2), 98 | get_block(in_channel=128, depth=256, num_units=2), 99 | get_block(in_channel=256, depth=512, num_units=2) 100 | ] 101 | elif num_layers == 34: 102 | blocks = [ 103 | get_block(in_channel=64, depth=64, num_units=3), 104 | get_block(in_channel=64, depth=128, num_units=4), 105 | get_block(in_channel=128, depth=256, num_units=6), 106 | get_block(in_channel=256, depth=512, num_units=3) 107 | ] 108 | elif num_layers == 50: 109 | blocks = [ 110 | get_block(in_channel=64, depth=64, num_units=3), 111 | get_block(in_channel=64, depth=128, num_units=4), 112 | get_block(in_channel=128, depth=256, num_units=14), 113 | get_block(in_channel=256, depth=512, num_units=3) 114 | ] 115 | elif num_layers == 100: 116 | blocks = [ 117 | get_block(in_channel=64, depth=64, num_units=3), 118 | get_block(in_channel=64, depth=128, num_units=13), 119 | get_block(in_channel=128, depth=256, num_units=30), 120 | get_block(in_channel=256, depth=512, num_units=3) 121 | ] 122 | elif num_layers == 152: 123 | blocks = [ 124 | get_block(in_channel=64, depth=256, num_units=3), 125 | get_block(in_channel=256, depth=512, num_units=8), 126 | get_block(in_channel=512, depth=1024, num_units=36), 127 | get_block(in_channel=1024, depth=2048, num_units=3) 128 | ] 129 | elif num_layers == 200: 130 | blocks = [ 131 | get_block(in_channel=64, depth=256, num_units=3), 132 | get_block(in_channel=256, depth=512, num_units=24), 133 | get_block(in_channel=512, depth=1024, 
num_units=36), 134 | get_block(in_channel=1024, depth=2048, num_units=3) 135 | ] 136 | 137 | return blocks 138 | 139 | 140 | class Backbone(Module): 141 | 142 | def __init__(self, input_size, num_layers, mode='ir'): 143 | """ Args: 144 | input_size: input_size of backbone 145 | num_layers: num_layers of backbone 146 | mode: support ir or irse 147 | """ 148 | super(Backbone, self).__init__() 149 | assert input_size[0] in [112, 224], \ 150 | 'input_size should be [112, 112] or [224, 224]' 151 | assert num_layers in [18, 34, 50, 100, 152, 200], \ 152 | 'num_layers should be 18, 34, 50, 100 or 152' 153 | assert mode in ['ir', 'ir_se'], \ 154 | 'mode should be ir or ir_se' 155 | self.input_layer = Sequential( 156 | Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), 157 | PReLU(64)) 158 | blocks = get_blocks(num_layers) 159 | if num_layers <= 100: 160 | if mode == 'ir': 161 | unit_module = BasicBlockIR 162 | elif mode == 'ir_se': 163 | unit_module = BasicBlockIRSE 164 | output_channel = 512 165 | else: 166 | if mode == 'ir': 167 | unit_module = BottleneckIR 168 | elif mode == 'ir_se': 169 | unit_module = BottleneckIRSE 170 | output_channel = 2048 171 | 172 | if input_size[0] == 112: 173 | self.output_layer = Sequential( 174 | BatchNorm2d(output_channel), Dropout(0.4), Flatten(), 175 | Linear(output_channel * 7 * 7, 512), 176 | BatchNorm1d(512, affine=False)) 177 | else: 178 | self.output_layer = Sequential( 179 | BatchNorm2d(output_channel), Dropout(0.4), Flatten(), 180 | Linear(output_channel * 14 * 14, 512), 181 | BatchNorm1d(512, affine=False)) 182 | 183 | modules = [] 184 | mid_layer_indices = [] # [2, 15, 45, 48], total 49 layers for IR101 185 | for block in blocks: 186 | if len(mid_layer_indices) == 0: 187 | mid_layer_indices.append(len(block) - 1) 188 | else: 189 | mid_layer_indices.append(len(block) + mid_layer_indices[-1]) 190 | for bottleneck in block: 191 | modules.append( 192 | unit_module(bottleneck.in_channel, bottleneck.depth, 193 | 
bottleneck.stride)) 194 | self.body = Sequential(*modules) 195 | self.mid_layer_indices = mid_layer_indices[-4:] 196 | 197 | # self.dtype = next(self.parameters()).dtype 198 | initialize_weights(self.modules()) 199 | 200 | def device(self): 201 | return next(self.parameters()).device 202 | 203 | def dtype(self): 204 | return next(self.parameters()).dtype 205 | 206 | def forward(self, x, return_mid_feats=False): 207 | x = self.input_layer(x) 208 | if not return_mid_feats: 209 | x = self.body(x) 210 | x = self.output_layer(x) 211 | return x 212 | else: 213 | out_feats = [] 214 | for idx, module in enumerate(self.body): 215 | x = module(x) 216 | if idx in self.mid_layer_indices: 217 | out_feats.append(x) 218 | x = self.output_layer(x) 219 | return x, out_feats 220 | 221 | 222 | def IR_18(input_size): 223 | """ Constructs a ir-18 model. 224 | """ 225 | model = Backbone(input_size, 18, 'ir') 226 | 227 | return model 228 | 229 | 230 | def IR_34(input_size): 231 | """ Constructs a ir-34 model. 232 | """ 233 | model = Backbone(input_size, 34, 'ir') 234 | 235 | return model 236 | 237 | 238 | def IR_50(input_size): 239 | """ Constructs a ir-50 model. 240 | """ 241 | model = Backbone(input_size, 50, 'ir') 242 | 243 | return model 244 | 245 | 246 | def IR_101(input_size): 247 | """ Constructs a ir-101 model. 248 | """ 249 | model = Backbone(input_size, 100, 'ir') 250 | 251 | return model 252 | 253 | 254 | def IR_152(input_size): 255 | """ Constructs a ir-152 model. 256 | """ 257 | model = Backbone(input_size, 152, 'ir') 258 | 259 | return model 260 | 261 | 262 | def IR_200(input_size): 263 | """ Constructs a ir-200 model. 264 | """ 265 | model = Backbone(input_size, 200, 'ir') 266 | 267 | return model 268 | 269 | 270 | def IR_SE_50(input_size): 271 | """ Constructs a ir_se-50 model. 272 | """ 273 | model = Backbone(input_size, 50, 'ir_se') 274 | 275 | return model 276 | 277 | 278 | def IR_SE_101(input_size): 279 | """ Constructs a ir_se-101 model. 
280 | """ 281 | model = Backbone(input_size, 100, 'ir_se') 282 | 283 | return model 284 | 285 | 286 | def IR_SE_152(input_size): 287 | """ Constructs a ir_se-152 model. 288 | """ 289 | model = Backbone(input_size, 152, 'ir_se') 290 | 291 | return model 292 | 293 | 294 | def IR_SE_200(input_size): 295 | """ Constructs a ir_se-200 model. 296 | """ 297 | model = Backbone(input_size, 200, 'ir_se') 298 | 299 | return model 300 | -------------------------------------------------------------------------------- /eval/curricularface/model_resnet.py: -------------------------------------------------------------------------------- 1 | # The implementation is adopted from TFace,made pubicly available under the Apache-2.0 license at 2 | # https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone/model_resnet.py 3 | import torch.nn as nn 4 | from torch.nn import BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear, MaxPool2d, Module, ReLU, Sequential 5 | 6 | from .common import initialize_weights 7 | 8 | 9 | def conv3x3(in_planes, out_planes, stride=1): 10 | """ 3x3 convolution with padding 11 | """ 12 | return Conv2d( 13 | in_planes, 14 | out_planes, 15 | kernel_size=3, 16 | stride=stride, 17 | padding=1, 18 | bias=False) 19 | 20 | 21 | def conv1x1(in_planes, out_planes, stride=1): 22 | """ 1x1 convolution 23 | """ 24 | return Conv2d( 25 | in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 26 | 27 | 28 | class Bottleneck(Module): 29 | expansion = 4 30 | 31 | def __init__(self, inplanes, planes, stride=1, downsample=None): 32 | super(Bottleneck, self).__init__() 33 | self.conv1 = conv1x1(inplanes, planes) 34 | self.bn1 = BatchNorm2d(planes) 35 | self.conv2 = conv3x3(planes, planes, stride) 36 | self.bn2 = BatchNorm2d(planes) 37 | self.conv3 = conv1x1(planes, planes * self.expansion) 38 | self.bn3 = BatchNorm2d(planes * self.expansion) 39 | self.relu = ReLU(inplace=True) 40 | self.downsample = downsample 41 | self.stride = stride 42 | 43 | def 
forward(self, x): 44 | identity = x 45 | 46 | out = self.conv1(x) 47 | out = self.bn1(out) 48 | out = self.relu(out) 49 | 50 | out = self.conv2(out) 51 | out = self.bn2(out) 52 | out = self.relu(out) 53 | 54 | out = self.conv3(out) 55 | out = self.bn3(out) 56 | 57 | if self.downsample is not None: 58 | identity = self.downsample(x) 59 | 60 | out += identity 61 | out = self.relu(out) 62 | 63 | return out 64 | 65 | 66 | class ResNet(Module): 67 | """ ResNet backbone 68 | """ 69 | 70 | def __init__(self, input_size, block, layers, zero_init_residual=True): 71 | """ Args: 72 | input_size: input_size of backbone 73 | block: block function 74 | layers: layers in each block 75 | """ 76 | super(ResNet, self).__init__() 77 | assert input_size[0] in [112, 224], \ 78 | 'input_size should be [112, 112] or [224, 224]' 79 | self.inplanes = 64 80 | self.conv1 = Conv2d( 81 | 3, 64, kernel_size=7, stride=2, padding=3, bias=False) 82 | self.bn1 = BatchNorm2d(64) 83 | self.relu = ReLU(inplace=True) 84 | self.maxpool = MaxPool2d(kernel_size=3, stride=2, padding=1) 85 | self.layer1 = self._make_layer(block, 64, layers[0]) 86 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 87 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 88 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 89 | 90 | self.bn_o1 = BatchNorm2d(2048) 91 | self.dropout = Dropout() 92 | if input_size[0] == 112: 93 | self.fc = Linear(2048 * 4 * 4, 512) 94 | else: 95 | self.fc = Linear(2048 * 7 * 7, 512) 96 | self.bn_o2 = BatchNorm1d(512) 97 | 98 | initialize_weights(self.modules) 99 | if zero_init_residual: 100 | for m in self.modules(): 101 | if isinstance(m, Bottleneck): 102 | nn.init.constant_(m.bn3.weight, 0) 103 | 104 | def _make_layer(self, block, planes, blocks, stride=1): 105 | downsample = None 106 | if stride != 1 or self.inplanes != planes * block.expansion: 107 | downsample = Sequential( 108 | conv1x1(self.inplanes, planes * block.expansion, stride), 109 | 
BatchNorm2d(planes * block.expansion), 110 | ) 111 | 112 | layers = [] 113 | layers.append(block(self.inplanes, planes, stride, downsample)) 114 | self.inplanes = planes * block.expansion 115 | for _ in range(1, blocks): 116 | layers.append(block(self.inplanes, planes)) 117 | 118 | return Sequential(*layers) 119 | 120 | def forward(self, x): 121 | x = self.conv1(x) 122 | x = self.bn1(x) 123 | x = self.relu(x) 124 | x = self.maxpool(x) 125 | 126 | x = self.layer1(x) 127 | x = self.layer2(x) 128 | x = self.layer3(x) 129 | x = self.layer4(x) 130 | 131 | x = self.bn_o1(x) 132 | x = self.dropout(x) 133 | x = x.view(x.size(0), -1) 134 | x = self.fc(x) 135 | x = self.bn_o2(x) 136 | 137 | return x 138 | 139 | 140 | def ResNet_50(input_size, **kwargs): 141 | """ Constructs a ResNet-50 model. 142 | """ 143 | model = ResNet(input_size, Bottleneck, [3, 4, 6, 3], **kwargs) 144 | 145 | return model 146 | 147 | 148 | def ResNet_101(input_size, **kwargs): 149 | """ Constructs a ResNet-101 model. 150 | """ 151 | model = ResNet(input_size, Bottleneck, [3, 4, 23, 3], **kwargs) 152 | 153 | return model 154 | 155 | 156 | def ResNet_152(input_size, **kwargs): 157 | """ Constructs a ResNet-152 model. 
158 | """ 159 | model = ResNet(input_size, Bottleneck, [3, 8, 36, 3], **kwargs) 160 | 161 | return model 162 | -------------------------------------------------------------------------------- /eval/expression_score.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os 3 | import torch 4 | from insightface.app import FaceAnalysis 5 | from insightface.utils import face_align 6 | from PIL import Image 7 | from torchvision import models, transforms 8 | from curricularface import get_model 9 | import cv2 10 | import numpy as np 11 | import numpy 12 | 13 | def pad_np_bgr_image(np_image, scale=1.25): 14 | assert scale >= 1.0, "scale should be >= 1.0" 15 | pad_scale = scale - 1.0 16 | h, w = np_image.shape[:2] 17 | top = bottom = int(h * pad_scale) 18 | left = right = int(w * pad_scale) 19 | return cv2.copyMakeBorder(np_image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(128, 128, 128)), (left, top) 20 | 21 | 22 | def sample_video_frames(video_path,): 23 | cap = cv2.VideoCapture(video_path) 24 | total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 25 | frame_indices = np.linspace(0, total_frames - 1, total_frames, dtype=int) 26 | 27 | frames = [] 28 | for idx in frame_indices: 29 | cap.set(cv2.CAP_PROP_POS_FRAMES, idx) 30 | ret, frame = cap.read() 31 | if ret: 32 | # if frame.shape[1] > 1024: 33 | # frame = frame[:, 1440:, :] 34 | # print(frame.shape) 35 | frame = cv2.resize(frame, (720, 480)) 36 | # print(frame.shape) 37 | frames.append(frame) 38 | cap.release() 39 | return frames 40 | 41 | 42 | def get_face_keypoints(face_model, image_bgr): 43 | face_info = face_model.get(image_bgr) 44 | if len(face_info) > 0: 45 | return sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1] 46 | return None 47 | 48 | def process_image(face_model, image_path): 49 | if isinstance(image_path, str): 50 | np_faceid_image = np.array(Image.open(image_path).convert("RGB")) 51 | elif 
isinstance(image_path, numpy.ndarray): 52 | np_faceid_image = image_path 53 | else: 54 | raise TypeError("image_path should be a string or PIL.Image.Image object") 55 | 56 | image_bgr = cv2.cvtColor(np_faceid_image, cv2.COLOR_RGB2BGR) 57 | 58 | face_info = get_face_keypoints(face_model, image_bgr) 59 | if face_info is None: 60 | padded_image, sub_coord = pad_np_bgr_image(image_bgr) 61 | face_info = get_face_keypoints(face_model, padded_image) 62 | if face_info is None: 63 | print("Warning: No face detected in the image. Continuing processing...") 64 | return None 65 | face_kps = face_info['kps'] 66 | face_kps -= np.array(sub_coord) 67 | else: 68 | face_kps = face_info['kps'] 69 | return face_kps 70 | 71 | def process_video(video_path, face_arc_model): 72 | video_frames = sample_video_frames(video_path,) 73 | print(len(video_frames)) 74 | kps_list = [] 75 | for frame in video_frames: 76 | # Convert to RGB once at the beginning 77 | frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 78 | kps = process_image(face_arc_model, frame_rgb) 79 | if kps is None: 80 | return None 81 | # print(kps) 82 | kps_list.append(kps) 83 | return kps_list 84 | 85 | 86 | def calculate_l1_distance(list1, list2): 87 | """ 88 | 计算两个列表的 L1 距离 89 | :param list1: 第一个列表,形状为 (5, 2) 90 | :param list2: 第二个列表,形状为 (5, 2) 91 | :return: L1 距离 92 | """ 93 | # 将列表转换为 NumPy 数组 94 | list1 = np.array(list1) 95 | list2 = np.array(list2) 96 | 97 | # 计算每对点的 L1 距离 98 | l1_distances = np.abs(list1 - list2).sum(axis=1) 99 | 100 | # 返回所有点的 L1 距离之和 101 | return l1_distances.sum() 102 | 103 | 104 | def calculate_kps(list1, list2): 105 | distance_list = [] 106 | for kps1 in list1: 107 | min_dis = (480 + 720) * 5 + 1 108 | for kps2 in list2: 109 | min_dis = min(min_dis, calculate_l1_distance(kps1, kps2)) 110 | distance_list.append(min_dis/(480+720)/10) 111 | return sum(distance_list)/len(distance_list) 112 | 113 | 114 | def main(): 115 | device = "cuda" 116 | # data_path = "data/SkyActor" 117 | # data_path = 
"data/LivePotraits" 118 | # data_path = "data/Actor-One" 119 | data_path = "data/FollowYourEmoji" 120 | img_path = "/maindata/data/shared/public/rui.wang/act_review/driving_video" 121 | pre_tag = False 122 | mp4_list = os.listdir(data_path) 123 | print(mp4_list) 124 | 125 | img_list = [] 126 | video_list = [] 127 | for mp4 in mp4_list: 128 | if "mp4" not in mp4: 129 | continue 130 | if pre_tag: 131 | png_path = mp4.split('.')[0].split('--')[1] + ".mp4" 132 | else: 133 | if "-" in mp4: 134 | png_path = mp4.split('.')[0].split('-')[0] + ".mp4" 135 | else: 136 | png_path = mp4.split('.')[0].split('_')[0] + ".mp4" 137 | img_list.append(os.path.join(img_path, png_path)) 138 | video_list.append(os.path.join(data_path, mp4)) 139 | print(img_list) 140 | print(video_list[0]) 141 | 142 | model_path = "eval" 143 | face_arc_path = os.path.join(model_path, "face_encoder") 144 | face_cur_path = os.path.join(face_arc_path, "glint360k_curricular_face_r101_backbone.bin") 145 | 146 | # Initialize FaceEncoder model for face detection and embedding extraction 147 | face_arc_model = FaceAnalysis(root=face_arc_path, providers=['CUDAExecutionProvider']) 148 | face_arc_model.prepare(ctx_id=0, det_size=(320, 320)) 149 | 150 | expression_list = [] 151 | for i in range(len(img_list)): 152 | print("number: ", str(i), " total: ", len(img_list), data_path) 153 | kps_1 = process_video(video_list[i], face_arc_model) 154 | kps_2 = process_video(img_list[i], face_arc_model) 155 | if kps_1 is None or kps_2 is None: 156 | continue 157 | 158 | dis = calculate_kps(kps_1, kps_2) 159 | print(dis) 160 | expression_list.append(dis) 161 | # break 162 | 163 | print("kps", sum(expression_list)/ len(expression_list)) 164 | 165 | 166 | 167 | main() 168 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | from PIL import Image 5 | 
import glob 6 | import insightface 7 | import cv2 8 | import subprocess 9 | import argparse 10 | from decord import VideoReader 11 | from moviepy.editor import ImageSequenceClip, AudioFileClip, VideoFileClip 12 | from facexlib.parsing import init_parsing_model 13 | from facexlib.utils.face_restoration_helper import FaceRestoreHelper 14 | from insightface.app import FaceAnalysis 15 | 16 | from diffusers.models import AutoencoderKLCogVideoX 17 | from diffusers.utils import export_to_video, load_image 18 | from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel 19 | from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor 20 | 21 | from skyreels_a1.models.transformer3d import CogVideoXTransformer3DModel 22 | from skyreels_a1.skyreels_a1_i2v_pipeline import SkyReelsA1ImagePoseToVideoPipeline 23 | from skyreels_a1.pre_process_lmk3d import FaceAnimationProcessor 24 | from skyreels_a1.src.media_pipe.mp_utils import LMKExtractor 25 | from skyreels_a1.src.media_pipe.draw_util_2d import FaceMeshVisualizer2d 26 | from skyreels_a1.src.frame_interpolation import init_frame_interpolation_model, batch_images_interpolation_tool 27 | from skyreels_a1.src.multi_fps import multi_fps_tool 28 | 29 | def crop_and_resize(image, height, width): 30 | image = np.array(image) 31 | image_height, image_width, _ = image.shape 32 | if image_height / image_width < height / width: 33 | croped_width = int(image_height / height * width) 34 | left = (image_width - croped_width) // 2 35 | image = image[:, left: left+croped_width] 36 | image = Image.fromarray(image).resize((width, height)) 37 | else: 38 | pad = int((((width / height) * image_height) - image_width) / 2.) 
39 | padded_image = np.zeros((image_height, image_width + pad * 2, 3), dtype=np.uint8) 40 | padded_image[:, pad:pad+image_width] = image 41 | image = Image.fromarray(padded_image).resize((width, height)) 42 | return image 43 | 44 | def write_mp4(video_path, samples, fps=12, audio_bitrate="192k"): 45 | clip = ImageSequenceClip(samples, fps=fps) 46 | clip.write_videofile(video_path, audio_codec="aac", audio_bitrate=audio_bitrate, 47 | ffmpeg_params=["-crf", "18", "-preset", "slow"]) 48 | 49 | def parse_video(driving_video_path, max_frame_num): 50 | vr = VideoReader(driving_video_path) 51 | fps = vr.get_avg_fps() 52 | video_length = len(vr) 53 | 54 | duration = video_length / fps 55 | target_times = np.arange(0, duration, 1/12) 56 | frame_indices = (target_times * fps).astype(np.int32) 57 | 58 | frame_indices = frame_indices[frame_indices < video_length] 59 | control_frames = vr.get_batch(frame_indices).asnumpy()[:(max_frame_num-1)] 60 | 61 | out_frames = len(control_frames) - 1 62 | if len(control_frames) < max_frame_num - 1: 63 | video_lenght_add = max_frame_num - len(control_frames) - 1 64 | control_frames = np.concatenate(([control_frames[0]]*2, control_frames[1:len(control_frames)-1], [control_frames[-1]] * video_lenght_add), axis=0) 65 | else: 66 | control_frames = np.concatenate(([control_frames[0]]*2, control_frames[1:len(control_frames)-1]), axis=0) 67 | 68 | return control_frames 69 | 70 | def exec_cmd(cmd): 71 | return subprocess.run(cmd, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) 72 | 73 | def add_audio_to_video(silent_video_path: str, audio_video_path: str, output_video_path: str): 74 | cmd = [ 75 | 'ffmpeg', 76 | '-y', 77 | '-i', f'"{silent_video_path}"', 78 | '-i', f'"{audio_video_path}"', 79 | '-map', '0:v', 80 | '-map', '1:a', 81 | '-c:v', 'copy', 82 | '-shortest', 83 | f'"{output_video_path}"' 84 | ] 85 | 86 | try: 87 | exec_cmd(' '.join(cmd)) 88 | print(f"Video with audio generated successfully: {output_video_path}") 
89 | except subprocess.CalledProcessError as e: 90 | print(f"Error occurred: {e}") 91 | 92 | 93 | if __name__ == "__main__": 94 | parser = argparse.ArgumentParser(description="Process video and image for face animation.") 95 | parser.add_argument('--image_path', type=str, default="assets/ref_images/1.png", help='Path to the source image.') 96 | parser.add_argument('--driving_video_path', type=str, default="assets/driving_video/1.mp4", help='Path to the driving video.') 97 | parser.add_argument('--output_path', type=str, default="outputs", help='Path to save the output video.') 98 | args = parser.parse_args() 99 | 100 | guidance_scale = 3.0 101 | seed = 43 102 | num_inference_steps = 10 103 | sample_size = [480, 720] 104 | max_frame_num = 49 105 | target_fps = 12 # recommend fps: 12(Native), 24, 36, 48, 60, other fps like 25, 30 may cause unstable rates 106 | weight_dtype = torch.bfloat16 107 | save_path = args.output_path 108 | generator = torch.Generator(device="cuda").manual_seed(seed) 109 | model_name = "pretrained_models/SkyReels-A1-5B/" 110 | siglip_name = "pretrained_models/SkyReels-A1-5B/siglip-so400m-patch14-384" 111 | 112 | lmk_extractor = LMKExtractor() 113 | processor = FaceAnimationProcessor(checkpoint='pretrained_models/smirk/SMIRK_em1.pt') 114 | vis = FaceMeshVisualizer2d(forehead_edge=False, draw_head=False, draw_iris=False,) 115 | face_helper = FaceRestoreHelper(upscale_factor=1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', device="cuda",) 116 | 117 | # siglip visual encoder 118 | siglip = SiglipVisionModel.from_pretrained(siglip_name) 119 | siglip_normalize = SiglipImageProcessor.from_pretrained(siglip_name) 120 | 121 | # frame interpolation model 122 | if target_fps != 12: 123 | frame_inter_model = init_frame_interpolation_model('pretrained_models/film_net/film_net_fp16.pt', device="cuda") 124 | 125 | # skyreels a1 model 126 | transformer = CogVideoXTransformer3DModel.from_pretrained( 127 | model_name, 128 | 
subfolder="transformer" 129 | ).to(weight_dtype) 130 | 131 | vae = AutoencoderKLCogVideoX.from_pretrained( 132 | model_name, 133 | subfolder="vae" 134 | ).to(weight_dtype) 135 | 136 | lmk_encoder = AutoencoderKLCogVideoX.from_pretrained( 137 | model_name, 138 | subfolder="pose_guider", 139 | ).to(weight_dtype) 140 | 141 | pipe = SkyReelsA1ImagePoseToVideoPipeline.from_pretrained( 142 | model_name, 143 | transformer = transformer, 144 | vae = vae, 145 | lmk_encoder = lmk_encoder, 146 | image_encoder = siglip, 147 | feature_extractor = siglip_normalize, 148 | torch_dtype=torch.bfloat16 149 | ) 150 | 151 | pipe.to("cuda") 152 | pipe.enable_model_cpu_offload() 153 | pipe.vae.enable_tiling() 154 | 155 | control_frames = parse_video(args.driving_video_path, max_frame_num) 156 | 157 | # driving video crop face 158 | driving_video_crop = [] 159 | for control_frame in control_frames: 160 | frame, _, _ = processor.face_crop(control_frame) 161 | driving_video_crop.append(frame) 162 | 163 | image = load_image(image=args.image_path) 164 | image = processor.crop_and_resize(image, sample_size[0], sample_size[1]) 165 | 166 | # ref image crop face 167 | ref_image, x1, y1 = processor.face_crop(np.array(image)) 168 | face_h, face_w, _, = ref_image.shape 169 | source_image = ref_image 170 | driving_video = driving_video_crop 171 | out_frames = processor.preprocess_lmk3d(source_image, driving_video) 172 | 173 | rescale_motions = np.zeros_like(image)[np.newaxis, :].repeat(48, axis=0) 174 | for ii in range(rescale_motions.shape[0]): 175 | rescale_motions[ii][y1:y1+face_h, x1:x1+face_w] = out_frames[ii] 176 | ref_image = cv2.resize(ref_image, (512, 512)) 177 | ref_lmk = lmk_extractor(ref_image[:, :, ::-1]) 178 | 179 | ref_img = vis.draw_landmarks_v3((512, 512), (face_w, face_h), ref_lmk['lmks'].astype(np.float32), normed=True) 180 | 181 | first_motion = np.zeros_like(np.array(image)) 182 | first_motion[y1:y1+face_h, x1:x1+face_w] = ref_img 183 | first_motion = first_motion[np.newaxis, :] 
184 | 185 | motions = np.concatenate([first_motion, rescale_motions]) 186 | input_video = motions[:max_frame_num] 187 | 188 | face_helper.clean_all() 189 | face_helper.read_image(np.array(image)[:, :, ::-1]) 190 | face_helper.get_face_landmarks_5(only_center_face=True) 191 | face_helper.align_warp_face() 192 | align_face = face_helper.cropped_faces[0] 193 | image_face = align_face[:, :, ::-1] 194 | 195 | input_video = input_video[:max_frame_num] 196 | motions = np.array(input_video) 197 | 198 | # [F, H, W, C] 199 | input_video = torch.from_numpy(np.array(input_video)).permute([3, 0, 1, 2]).unsqueeze(0) 200 | input_video = input_video / 255 201 | 202 | out_samples = [] 203 | 204 | with torch.no_grad(): 205 | sample = pipe( 206 | image=image, 207 | image_face=image_face, 208 | control_video = input_video, 209 | prompt = "", 210 | negative_prompt = "", 211 | height = sample_size[0], 212 | width = sample_size[1], 213 | num_frames = 49, 214 | generator = generator, 215 | guidance_scale = guidance_scale, 216 | num_inference_steps = num_inference_steps, 217 | ) 218 | out_samples.extend(sample.frames[0]) 219 | out_samples = out_samples[2:] 220 | 221 | save_path_name = os.path.basename(args.image_path).split(".")[0] + "-" + os.path.basename(args.driving_video_path).split(".")[0]+ ".mp4" 222 | 223 | if not os.path.exists(save_path): 224 | os.makedirs(save_path, exist_ok=True) 225 | video_path = os.path.join(save_path, save_path_name.split(".")[0] + "_output.mp4") 226 | 227 | if target_fps != 12: 228 | out_samples = multi_fps_tool(out_samples, frame_inter_model, target_fps) 229 | 230 | export_to_video(out_samples, video_path, fps=target_fps) 231 | add_audio_to_video(video_path, args.driving_video_path, video_path.split(".")[0] + "_audio.mp4") 232 | 233 | if target_fps == 12: 234 | target_h, target_w = sample_size[0], sample_size[1] 235 | final_images = [] 236 | final_images2 =[] 237 | rescale_motions = rescale_motions[1:] 238 | control_frames = control_frames[1:] 239 | for q 
in range(len(out_samples)): 240 | frame1 = image 241 | frame2 = crop_and_resize(Image.fromarray(np.array(control_frames[q])).convert("RGB"), target_h, target_w) 242 | frame3 = Image.fromarray(np.array(out_samples[q])).convert("RGB") 243 | 244 | result = Image.new('RGB', (target_w * 3, target_h)) 245 | result.paste(frame1, (0, 0)) 246 | result.paste(frame2, (target_w, 0)) 247 | result.paste(frame3, (target_w * 2, 0)) 248 | final_images.append(np.array(result)) 249 | 250 | video_out_path = os.path.join(save_path, save_path_name.split(".")[0]+"_merge.mp4") 251 | write_mp4(video_out_path, final_images, fps=12) 252 | 253 | add_audio_to_video(video_out_path, args.driving_video_path, video_out_path.split(".")[0] + f"_audio.mp4") 254 | -------------------------------------------------------------------------------- /inference_audio.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | from PIL import Image 5 | import glob 6 | import insightface 7 | import cv2 8 | import subprocess 9 | import argparse 10 | from decord import VideoReader 11 | from moviepy.editor import ImageSequenceClip, AudioFileClip, VideoFileClip 12 | from facexlib.parsing import init_parsing_model 13 | from facexlib.utils.face_restoration_helper import FaceRestoreHelper 14 | from insightface.app import FaceAnalysis 15 | 16 | from diffusers.models import AutoencoderKLCogVideoX 17 | from diffusers.utils import export_to_video, load_image 18 | from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel 19 | from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor 20 | 21 | from skyreels_a1.models.transformer3d import CogVideoXTransformer3DModel 22 | from skyreels_a1.skyreels_a1_i2v_pipeline import SkyReelsA1ImagePoseToVideoPipeline 23 | from skyreels_a1.pre_process_lmk3d import FaceAnimationProcessor 24 | from skyreels_a1.src.media_pipe.mp_utils import 
LMKExtractor 25 | from skyreels_a1.src.media_pipe.draw_util_2d import FaceMeshVisualizer2d 26 | from skyreels_a1.src.frame_interpolation import init_frame_interpolation_model, batch_images_interpolation_tool 27 | from skyreels_a1.src.multi_fps import multi_fps_tool 28 | 29 | import moviepy.editor as mp 30 | from diffposetalk.diffposetalk import DiffPoseTalk 31 | 32 | 33 | def crop_and_resize(image, height, width): 34 | image = np.array(image) 35 | image_height, image_width, _ = image.shape 36 | if image_height / image_width < height / width: 37 | croped_width = int(image_height / height * width) 38 | left = (image_width - croped_width) // 2 39 | image = image[:, left: left+croped_width] 40 | image = Image.fromarray(image).resize((width, height)) 41 | else: 42 | pad = int((((width / height) * image_height) - image_width) / 2.) 43 | padded_image = np.zeros((image_height, image_width + pad * 2, 3), dtype=np.uint8) 44 | padded_image[:, pad:pad+image_width] = image 45 | image = Image.fromarray(padded_image).resize((width, height)) 46 | return image 47 | 48 | def write_mp4(video_path, samples, fps=12, audio_bitrate="192k"): 49 | clip = ImageSequenceClip(samples, fps=fps) 50 | clip.write_videofile(video_path, audio_codec="aac", audio_bitrate=audio_bitrate, 51 | ffmpeg_params=["-crf", "18", "-preset", "slow"]) 52 | 53 | 54 | def parse_video(driving_frames, max_frame_num, fps=25): 55 | 56 | video_length = len(driving_frames) 57 | 58 | duration = video_length / fps 59 | target_times = np.arange(0, duration, 1/12) 60 | frame_indices = (target_times * fps).astype(np.int32) 61 | 62 | frame_indices = frame_indices[frame_indices < video_length] 63 | new_driving_frames = [] 64 | for idx in frame_indices: 65 | new_driving_frames.append(driving_frames[idx]) 66 | if len(new_driving_frames) >= max_frame_num - 1: 67 | break 68 | 69 | video_lenght_add = max_frame_num - len(new_driving_frames) - 1 70 | new_driving_frames = [new_driving_frames[0]]*2 + 
new_driving_frames[1:len(new_driving_frames)-1] + [new_driving_frames[-1]] * video_lenght_add 71 | return new_driving_frames 72 | 73 | def exec_cmd(cmd): 74 | return subprocess.run(cmd, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) 75 | 76 | def add_audio_to_video(silent_video_path, audio_video_path, output_video_path): 77 | cmd = [ 78 | 'ffmpeg', 79 | '-y', 80 | '-i', f'"{silent_video_path}"', 81 | '-i', f'"{audio_video_path}"', 82 | '-map', '0:v', 83 | '-map', '1:a', 84 | '-c:v', 'copy', 85 | '-shortest', 86 | f'"{output_video_path}"' 87 | ] 88 | 89 | try: 90 | exec_cmd(' '.join(cmd)) 91 | print(f"Video with audio generated successfully: {output_video_path}") 92 | except subprocess.CalledProcessError as e: 93 | print(f"Error occurred: {e}") 94 | 95 | if __name__ == "__main__": 96 | parser = argparse.ArgumentParser(description="Process video and image for face animation.") 97 | parser.add_argument('--image_path', type=str, default="assets/ref_images/1.png", help='Path to the source image.') 98 | parser.add_argument('--driving_audio_path', type=str, default="assets/driving_audio/1.wav", help='Path to the driving video.') 99 | parser.add_argument('--output_path', type=str, default="outputs_audio", help='Path to save the output video.') 100 | args = parser.parse_args() 101 | 102 | guidance_scale = 3.0 103 | seed = 43 104 | num_inference_steps = 10 105 | sample_size = [480, 720] 106 | max_frame_num = 49 107 | target_fps = 12 # recommend fps: 12(Native), 24, 36, 48, 60, other fps like 25, 30 may cause unstable rates 108 | weight_dtype = torch.bfloat16 109 | save_path = args.output_path 110 | generator = torch.Generator(device="cuda").manual_seed(seed) 111 | model_name = "pretrained_models/SkyReels-A1-5B/" 112 | siglip_name = "pretrained_models/SkyReels-A1-5B/siglip-so400m-patch14-384" 113 | 114 | lmk_extractor = LMKExtractor() 115 | processor = FaceAnimationProcessor(checkpoint='pretrained_models/smirk/SMIRK_em1.pt') 116 | vis = 
FaceMeshVisualizer2d(forehead_edge=False, draw_head=False, draw_iris=False,) 117 | face_helper = FaceRestoreHelper(upscale_factor=1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', device="cuda",) 118 | 119 | # siglip visual encoder 120 | siglip = SiglipVisionModel.from_pretrained(siglip_name) 121 | siglip_normalize = SiglipImageProcessor.from_pretrained(siglip_name) 122 | 123 | # frame interpolation model 124 | if target_fps != 12: 125 | frame_inter_model = init_frame_interpolation_model('pretrained_models/film_net/film_net_fp16.pt', device="cuda") 126 | 127 | # diffposetalk 128 | diffposetalk = DiffPoseTalk() 129 | 130 | # skyreels a1 model 131 | transformer = CogVideoXTransformer3DModel.from_pretrained( 132 | model_name, 133 | subfolder="transformer" 134 | ).to(weight_dtype) 135 | 136 | vae = AutoencoderKLCogVideoX.from_pretrained( 137 | model_name, 138 | subfolder="vae" 139 | ).to(weight_dtype) 140 | 141 | lmk_encoder = AutoencoderKLCogVideoX.from_pretrained( 142 | model_name, 143 | subfolder="pose_guider", 144 | ).to(weight_dtype) 145 | 146 | pipe = SkyReelsA1ImagePoseToVideoPipeline.from_pretrained( 147 | model_name, 148 | transformer = transformer, 149 | vae = vae, 150 | lmk_encoder = lmk_encoder, 151 | image_encoder = siglip, 152 | feature_extractor = siglip_normalize, 153 | torch_dtype=torch.bfloat16 154 | ) 155 | 156 | pipe.to("cuda") 157 | pipe.enable_model_cpu_offload() 158 | pipe.vae.enable_tiling() 159 | 160 | image = load_image(image=args.image_path) 161 | image = processor.crop_and_resize(image, sample_size[0], sample_size[1]) 162 | 163 | # ref image crop face 164 | ref_image, x1, y1 = processor.face_crop(np.array(image)) 165 | face_h, face_w, _, = ref_image.shape 166 | source_image = ref_image 167 | 168 | source_outputs, source_tform, image_original = processor.process_source_image(source_image) 169 | driving_outputs = diffposetalk.infer_from_file(args.driving_audio_path, 
source_outputs["shape_params"].view(-1)[:100].detach().cpu().numpy()) 170 | out_frames = processor.preprocess_lmk3d_from_coef(source_outputs, source_tform, image_original.shape, driving_outputs) 171 | out_frames = parse_video(out_frames, max_frame_num) 172 | 173 | rescale_motions = np.zeros_like(image)[np.newaxis, :].repeat(48, axis=0) 174 | for ii in range(rescale_motions.shape[0]): 175 | rescale_motions[ii][y1:y1+face_h, x1:x1+face_w] = out_frames[ii] 176 | ref_image = cv2.resize(ref_image, (512, 512)) 177 | ref_lmk = lmk_extractor(ref_image[:, :, ::-1]) 178 | 179 | ref_img = vis.draw_landmarks_v3((512, 512), (face_w, face_h), ref_lmk['lmks'].astype(np.float32), normed=True) 180 | 181 | first_motion = np.zeros_like(np.array(image)) 182 | first_motion[y1:y1+face_h, x1:x1+face_w] = ref_img 183 | first_motion = first_motion[np.newaxis, :] 184 | 185 | motions = np.concatenate([first_motion, rescale_motions]) 186 | input_video = motions[:max_frame_num] 187 | 188 | face_helper.clean_all() 189 | face_helper.read_image(np.array(image)[:, :, ::-1]) 190 | face_helper.get_face_landmarks_5(only_center_face=True) 191 | face_helper.align_warp_face() 192 | align_face = face_helper.cropped_faces[0] 193 | image_face = align_face[:, :, ::-1] 194 | 195 | input_video = input_video[:max_frame_num] 196 | motions = np.array(input_video) 197 | 198 | # [F, H, W, C] 199 | input_video = torch.from_numpy(np.array(input_video)).permute([3, 0, 1, 2]).unsqueeze(0) 200 | input_video = input_video / 255 201 | 202 | out_samples = [] 203 | 204 | with torch.no_grad(): 205 | sample = pipe( 206 | image=image, 207 | image_face=image_face, 208 | control_video = input_video, 209 | prompt = "", 210 | negative_prompt = "", 211 | height = sample_size[0], 212 | width = sample_size[1], 213 | num_frames = 49, 214 | generator = generator, 215 | guidance_scale = guidance_scale, 216 | num_inference_steps = num_inference_steps, 217 | ) 218 | out_samples.extend(sample.frames[0]) 219 | out_samples = out_samples[2:] 
220 | 221 | save_path_name = os.path.basename(args.image_path).split(".")[0] + "-" + os.path.basename(args.driving_audio_path).split(".")[0]+ ".mp4" 222 | 223 | if not os.path.exists(save_path): 224 | os.makedirs(save_path, exist_ok=True) 225 | video_path = os.path.join(save_path, save_path_name.split(".")[0] + "_output.mp4") 226 | 227 | if target_fps != 12: 228 | out_samples = multi_fps_tool(out_samples, frame_inter_model, target_fps) 229 | 230 | export_to_video(out_samples, video_path, fps=target_fps) 231 | add_audio_to_video(video_path, args.driving_audio_path, video_path.split(".")[0] + "_audio.mp4") 232 | 233 | if target_fps == 12: 234 | target_h, target_w = sample_size[0], sample_size[1] 235 | final_images = [] 236 | final_images2 =[] 237 | rescale_motions = rescale_motions[1:] 238 | control_frames = out_frames[1:] 239 | for q in range(len(out_samples)): 240 | frame1 = image 241 | frame2 = Image.fromarray(np.array(out_samples[q])).convert("RGB") 242 | 243 | result = Image.new('RGB', (target_w * 2, target_h)) 244 | result.paste(frame1, (0, 0)) 245 | result.paste(frame2, (target_w, 0)) 246 | final_images.append(np.array(result)) 247 | 248 | video_out_path = os.path.join(save_path, save_path_name.split(".")[0]+"_merge.mp4") 249 | write_mp4(video_out_path, final_images, fps=12) 250 | 251 | add_audio_to_video(video_out_path, args.driving_audio_path, video_out_path.split(".")[0] + "_audio.mp4") 252 | 253 | -------------------------------------------------------------------------------- /inference_audio_long_video.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | from PIL import Image 5 | import glob 6 | import insightface 7 | import cv2 8 | import subprocess 9 | import argparse 10 | import math 11 | import time 12 | from decord import VideoReader 13 | from moviepy.editor import ImageSequenceClip, AudioFileClip, VideoFileClip 14 | from facexlib.parsing import 
init_parsing_model 15 | from facexlib.utils.face_restoration_helper import FaceRestoreHelper 16 | from insightface.app import FaceAnalysis 17 | import moviepy.editor as mp 18 | 19 | from diffusers.models import AutoencoderKLCogVideoX 20 | from diffusers.utils import export_to_video, load_image 21 | from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel 22 | from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor 23 | 24 | from skyreels_a1.models.transformer3d import CogVideoXTransformer3DModel 25 | from skyreels_a1.skyreels_a1_i2v_long_pipeline import SkyReelsA1ImagePoseToVideoPipeline 26 | from skyreels_a1.pre_process_lmk3d import FaceAnimationProcessor 27 | from skyreels_a1.src.media_pipe.mp_utils import LMKExtractor 28 | from skyreels_a1.src.media_pipe.draw_util_2d import FaceMeshVisualizer2d 29 | from skyreels_a1.src.frame_interpolation import init_frame_interpolation_model, batch_images_interpolation_tool 30 | from skyreels_a1.src.multi_fps import multi_fps_tool 31 | 32 | from diffusers.video_processor import VideoProcessor 33 | from diffposetalk.diffposetalk import DiffPoseTalk 34 | 35 | 36 | def crop_and_resize(image, height, width): 37 | image = np.array(image) 38 | image_height, image_width, _ = image.shape 39 | if image_height / image_width < height / width: 40 | croped_width = int(image_height / height * width) 41 | left = (image_width - croped_width) // 2 42 | image = image[:, left: left+croped_width] 43 | image = Image.fromarray(image).resize((width, height)) 44 | else: 45 | pad = int((((width / height) * image_height) - image_width) / 2.) 
46 | padded_image = np.zeros((image_height, image_width + pad * 2, 3), dtype=np.uint8) 47 | padded_image[:, pad:pad+image_width] = image 48 | image = Image.fromarray(padded_image).resize((width, height)) 49 | return image 50 | 51 | def write_mp4(video_path, samples, fps=12, audio_bitrate="192k"): 52 | clip = ImageSequenceClip(samples, fps=fps) 53 | clip.write_videofile(video_path, audio_codec="aac", audio_bitrate=audio_bitrate, 54 | ffmpeg_params=["-crf", "18", "-preset", "slow"]) 55 | 56 | def parse_video(driving_frames, fps=25): 57 | video_length = len(driving_frames) 58 | 59 | duration = video_length / fps 60 | target_times = np.arange(0, duration, 1/12) 61 | frame_indices = (target_times * fps).astype(np.int32) 62 | 63 | frame_indices = frame_indices[frame_indices < video_length] 64 | new_driving_frames = [] 65 | for idx in frame_indices: 66 | new_driving_frames.append(driving_frames[idx]) 67 | 68 | return new_driving_frames 69 | 70 | 71 | def smooth_video_transition(frames1, frames2, smooth_frame_num, frame_inter_model, inter_frames=2): 72 | 73 | frames1_np = np.array([np.array(frame) for frame in frames1]) 74 | frames2_np = np.array([np.array(frame) for frame in frames2]) 75 | frames1_tensor = torch.from_numpy(frames1_np).permute(3,0,1,2).unsqueeze(0) / 255.0 76 | frames2_tensor = torch.from_numpy(frames2_np).permute(3,0,1,2).unsqueeze(0) / 255.0 77 | video = torch.cat([frames1_tensor[:,:,-smooth_frame_num:], frames2_tensor[:,:,:smooth_frame_num]], dim=2) 78 | 79 | video = batch_images_interpolation_tool(video, frame_inter_model, inter_frames=inter_frames) 80 | 81 | index = [1, 4, 5, 8] if inter_frames == 2 else [2, 5, 7, 10] 82 | video = video[:, :, index] 83 | video = video.squeeze(0) 84 | video = video.permute(1, 2, 3, 0) 85 | video = (video * 255).clamp(0, 255).to(torch.uint8).cpu().numpy() 86 | mid_frames = [Image.fromarray(frame) for frame in video] 87 | 88 | out_frames = frames1[:-smooth_frame_num] + mid_frames + frames2[smooth_frame_num:] 89 | 90 | 
return out_frames 91 | 92 | 93 | def exec_cmd(cmd): 94 | return subprocess.run(cmd, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) 95 | 96 | def add_audio_to_video(silent_video_path, audio_video_path, output_video_path): 97 | cmd = [ 98 | 'ffmpeg', 99 | '-y', 100 | '-i', f'"{silent_video_path}"', 101 | '-i', f'"{audio_video_path}"', 102 | '-map', '0:v', 103 | '-map', '1:a', 104 | '-c:v', 'copy', 105 | '-shortest', 106 | f'"{output_video_path}"' 107 | ] 108 | 109 | try: 110 | exec_cmd(' '.join(cmd)) 111 | print(f"Video with audio generated successfully: {output_video_path}") 112 | except subprocess.CalledProcessError as e: 113 | print(f"Error occurred: {e}") 114 | 115 | 116 | if __name__ == "__main__": 117 | parser = argparse.ArgumentParser(description="Process video and image for face animation.") 118 | parser.add_argument('--image_path', type=str, default="assets/ref_images/19.png", help='Path to the source image.') 119 | parser.add_argument('--driving_audio_path', type=str, default="assets/driving_audio/2.wav", help='Path to the driving video.') 120 | parser.add_argument('--output_path', type=str, default="outputs_audio", help='Path to save the output video.') 121 | args = parser.parse_args() 122 | 123 | guidance_scale = 3.0 124 | seed = 43 125 | num_inference_steps = 10 126 | sample_size = [480, 720] 127 | max_frame_num = 10000 128 | frame_num_per_batch = 49 129 | overlap_frame_num = 8 130 | fusion_interval = [3, 8] 131 | use_interpolation = True 132 | target_fps = 12 # recommend fps: 12(Native), 24, 36, 48, 60, other fps like 25, 30 may cause unstable rates 133 | weight_dtype = torch.bfloat16 134 | save_path = args.output_path 135 | generator = torch.Generator(device="cuda").manual_seed(seed) 136 | model_name = "pretrained_models/SkyReels-A1-5B/" 137 | siglip_name = "pretrained_models/SkyReels-A1-5B/siglip-so400m-patch14-384" 138 | 139 | lmk_extractor = LMKExtractor() 140 | processor = 
FaceAnimationProcessor(checkpoint='pretrained_models/smirk/SMIRK_em1.pt') 141 | vis = FaceMeshVisualizer2d(forehead_edge=False, draw_head=False, draw_iris=False) 142 | face_helper = FaceRestoreHelper(upscale_factor=1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', device="cuda") 143 | 144 | # siglip visual encoder 145 | siglip = SiglipVisionModel.from_pretrained(siglip_name) 146 | siglip_normalize = SiglipImageProcessor.from_pretrained(siglip_name) 147 | 148 | # frame interpolation model 149 | if use_interpolation or target_fps != 12: 150 | frame_inter_model = init_frame_interpolation_model('pretrained_models/film_net/film_net_fp16.pt', device="cuda") 151 | 152 | # diffposetalk 153 | diffposetalk = DiffPoseTalk() 154 | 155 | # skyreels a1 model 156 | transformer = CogVideoXTransformer3DModel.from_pretrained( 157 | model_name, 158 | subfolder="transformer" 159 | ).to(weight_dtype) 160 | 161 | vae = AutoencoderKLCogVideoX.from_pretrained( 162 | model_name, 163 | subfolder="vae" 164 | ).to(weight_dtype) 165 | 166 | lmk_encoder = AutoencoderKLCogVideoX.from_pretrained( 167 | model_name, 168 | subfolder="pose_guider", 169 | ).to(weight_dtype) 170 | 171 | pipe = SkyReelsA1ImagePoseToVideoPipeline.from_pretrained( 172 | model_name, 173 | transformer=transformer, 174 | vae=vae, 175 | lmk_encoder=lmk_encoder, 176 | image_encoder=siglip, 177 | feature_extractor=siglip_normalize, 178 | torch_dtype=torch.bfloat16 179 | ) 180 | 181 | pipe.to("cuda") 182 | pipe.enable_model_cpu_offload() 183 | pipe.vae.enable_tiling() 184 | 185 | image = load_image(image=args.image_path) 186 | image = processor.crop_and_resize(image, sample_size[0], sample_size[1]) 187 | 188 | # ref image crop face 189 | ref_image, x1, y1 = processor.face_crop(np.array(image)) 190 | face_h, face_w, _ = ref_image.shape 191 | source_image = ref_image 192 | 193 | source_outputs, source_tform, image_original = processor.process_source_image(source_image) 194 | driving_outputs = 
diffposetalk.infer_from_file(args.driving_audio_path, source_outputs["shape_params"].view(-1)[:100].detach().cpu().numpy()) 195 | 196 | out_frames = processor.preprocess_lmk3d_from_coef(source_outputs, source_tform, image_original.shape, driving_outputs) 197 | out_frames = parse_video(out_frames) 198 | 199 | rescale_motions = np.zeros_like(image)[np.newaxis, :].repeat(len(out_frames), axis=0) 200 | for ii in range(rescale_motions.shape[0]): 201 | rescale_motions[ii][y1:y1+face_h, x1:x1+face_w] = out_frames[ii] 202 | ref_image = cv2.resize(ref_image, (512, 512)) 203 | ref_lmk = lmk_extractor(ref_image[:, :, ::-1]) 204 | 205 | ref_img = vis.draw_landmarks_v3((512, 512), (face_w, face_h), ref_lmk['lmks'].astype(np.float32), normed=True) 206 | 207 | first_motion = np.zeros_like(np.array(image)) 208 | first_motion[y1:y1+face_h, x1:x1+face_w] = ref_img 209 | first_motion = first_motion[np.newaxis, :] 210 | 211 | input_video = rescale_motions[:max_frame_num] 212 | 213 | video_length = len(input_video) 214 | print(f"orginal video length: {video_length}") 215 | 216 | face_helper.clean_all() 217 | face_helper.read_image(np.array(image)[:, :, ::-1]) 218 | face_helper.get_face_landmarks_5(only_center_face=True) 219 | face_helper.align_warp_face() 220 | align_face = face_helper.cropped_faces[0] 221 | image_face = align_face[:, :, ::-1] 222 | 223 | # [F, H, W, C] 224 | input_video = torch.from_numpy(np.array(input_video)).permute([3, 0, 1, 2]).unsqueeze(0) # [B, C, F, H, W] 225 | input_video = input_video / 255 226 | input_video_all = input_video 227 | 228 | first_motion = torch.from_numpy(np.array(first_motion)).permute([3, 0, 1, 2]).unsqueeze(0) # [B, C, 1, H, W] 229 | first_motion = first_motion / 255 230 | 231 | out_samples = [] 232 | padding_frame_num = None 233 | latents_cache = [] 234 | 235 | time_start = time.time() 236 | 237 | for i in range(0, video_length, frame_num_per_batch-1-overlap_frame_num): 238 | is_first_batch = (i == 0) 239 | is_last_batch = (i + 
frame_num_per_batch - 1 >= video_length) 240 | 241 | input_video = input_video_all[:, :, i:i+frame_num_per_batch-1] 242 | 243 | if input_video.shape[2] != frame_num_per_batch-1: 244 | padding_frame_num = frame_num_per_batch-1 - input_video.shape[2] 245 | print(f"padding_frame_num: {padding_frame_num}") 246 | input_video = torch.cat([input_video, torch.repeat_interleave(input_video[:, :, -1].unsqueeze(2), padding_frame_num, dim=2)], dim=2) 247 | 248 | input_video = torch.cat([first_motion, input_video[:, :, 0:1], input_video], dim=2) 249 | 250 | with torch.no_grad(): 251 | sample, latents_cache = pipe( 252 | image=image, 253 | image_face=image_face, 254 | control_video=input_video, 255 | prompt="", 256 | negative_prompt="", 257 | height=sample_size[0], 258 | width=sample_size[1], 259 | num_frames=frame_num_per_batch, 260 | generator=generator, 261 | guidance_scale=guidance_scale, 262 | num_inference_steps=num_inference_steps, 263 | is_last_batch=is_last_batch, 264 | overlap_frame_num=overlap_frame_num, 265 | fusion_interval=fusion_interval, 266 | latents_cache=latents_cache 267 | ) 268 | if use_interpolation: 269 | if is_first_batch: 270 | out_samples = sample.frames[0][1:] 271 | else: 272 | out_samples = smooth_video_transition(out_samples, sample.frames[0][1+overlap_frame_num:], 2, frame_inter_model) 273 | else: 274 | out_sample = sample.frames[0][1:] if is_first_batch else sample.frames[0][1+overlap_frame_num:] 275 | out_samples.extend(out_sample) 276 | print(f"out_samples len: {len(out_samples)}") 277 | 278 | if is_last_batch: 279 | break 280 | 281 | if padding_frame_num is not None: 282 | out_samples = out_samples[:-padding_frame_num] 283 | 284 | print(f"output video length: {len(out_samples)}") 285 | 286 | time_end = time.time() 287 | print(f"time cost: {time_end - time_start} seconds") 288 | 289 | save_path_name = os.path.basename(args.image_path).split(".")[0] + "-" + os.path.basename(args.driving_audio_path).split(".")[0]+ ".mp4" 290 | 291 | if not 
os.path.exists(save_path): 292 | os.makedirs(save_path, exist_ok=True) 293 | video_path = os.path.join(save_path, save_path_name.split(".")[0] + "_output.mp4") 294 | 295 | if target_fps != 12: 296 | out_samples = multi_fps_tool(out_samples, frame_inter_model, target_fps) 297 | 298 | export_to_video(out_samples, video_path, fps=target_fps) 299 | add_audio_to_video(video_path, args.driving_audio_path, video_path.split(".")[0] + "_audio.mp4") 300 | 301 | if target_fps == 12: 302 | target_h, target_w = sample_size[0], sample_size[1] 303 | final_images = [] 304 | final_images2 =[] 305 | rescale_motions = rescale_motions[1:] 306 | control_frames = out_frames[1:] 307 | for q in range(len(out_samples)): 308 | frame1 = image 309 | frame2 = Image.fromarray(np.array(out_samples[q])).convert("RGB") 310 | 311 | result = Image.new('RGB', (target_w * 2, target_h)) 312 | result.paste(frame1, (0, 0)) 313 | result.paste(frame2, (target_w, 0)) 314 | final_images.append(np.array(result)) 315 | 316 | video_out_path = os.path.join(save_path, save_path_name.split(".")[0]+"_merge.mp4") 317 | write_mp4(video_out_path, final_images, fps=12) 318 | 319 | add_audio_to_video(video_out_path, args.driving_audio_path, video_out_path.split(".")[0] + "_audio.mp4") 320 | -------------------------------------------------------------------------------- /inference_long_video.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | from PIL import Image 5 | import glob 6 | import insightface 7 | import cv2 8 | import subprocess 9 | import argparse 10 | import time 11 | from decord import VideoReader 12 | from moviepy.editor import ImageSequenceClip, AudioFileClip, VideoFileClip 13 | from facexlib.parsing import init_parsing_model 14 | from facexlib.utils.face_restoration_helper import FaceRestoreHelper 15 | from insightface.app import FaceAnalysis 16 | 17 | from diffusers.models import AutoencoderKLCogVideoX 18 | from 
diffusers.utils import export_to_video, load_image 19 | from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel 20 | from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor 21 | 22 | from skyreels_a1.models.transformer3d import CogVideoXTransformer3DModel 23 | from skyreels_a1.skyreels_a1_i2v_long_pipeline import SkyReelsA1ImagePoseToVideoPipeline 24 | from skyreels_a1.pre_process_lmk3d import FaceAnimationProcessor 25 | from skyreels_a1.src.media_pipe.mp_utils import LMKExtractor 26 | from skyreels_a1.src.media_pipe.draw_util_2d import FaceMeshVisualizer2d 27 | from skyreels_a1.src.frame_interpolation import init_frame_interpolation_model, batch_images_interpolation_tool 28 | from skyreels_a1.src.multi_fps import multi_fps_tool 29 | 30 | from diffusers.video_processor import VideoProcessor 31 | 32 | 33 | def crop_and_resize(image, height, width): 34 | image = np.array(image) 35 | image_height, image_width, _ = image.shape 36 | if image_height / image_width < height / width: 37 | croped_width = int(image_height / height * width) 38 | left = (image_width - croped_width) // 2 39 | image = image[:, left: left+croped_width] 40 | image = Image.fromarray(image).resize((width, height)) 41 | else: 42 | pad = int((((width / height) * image_height) - image_width) / 2.) 
43 | padded_image = np.zeros((image_height, image_width + pad * 2, 3), dtype=np.uint8) 44 | padded_image[:, pad:pad+image_width] = image 45 | image = Image.fromarray(padded_image).resize((width, height)) 46 | return image 47 | 48 | 49 | def write_mp4(video_path, samples, fps=12, audio_bitrate="192k"): 50 | clip = ImageSequenceClip(samples, fps=fps) 51 | clip.write_videofile(video_path, audio_codec="aac", audio_bitrate=audio_bitrate, 52 | ffmpeg_params=["-crf", "18", "-preset", "slow"]) 53 | 54 | 55 | def parse_video(driving_video_path, max_frame_num=10000): 56 | vr = VideoReader(driving_video_path) 57 | fps = vr.get_avg_fps() 58 | video_length = len(vr) 59 | 60 | duration = video_length / fps 61 | target_times = np.arange(0, duration, 1/12) 62 | frame_indices = (target_times * fps).astype(np.int32) 63 | 64 | frame_indices = frame_indices[frame_indices < video_length] 65 | control_frames = vr.get_batch(frame_indices).asnumpy()[:max_frame_num] 66 | 67 | return control_frames 68 | 69 | 70 | def exec_cmd(cmd): 71 | return subprocess.run(cmd, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) 72 | 73 | 74 | def add_audio_to_video(silent_video_path, audio_video_path, output_video_path): 75 | cmd = [ 76 | 'ffmpeg', 77 | '-y', 78 | '-i', f'"{silent_video_path}"', 79 | '-i', f'"{audio_video_path}"', 80 | '-map', '0:v', 81 | '-map', '1:a', 82 | '-c:v', 'copy', 83 | '-shortest', 84 | f'"{output_video_path}"' 85 | ] 86 | 87 | try: 88 | exec_cmd(' '.join(cmd)) 89 | print(f"Video with audio generated successfully: {output_video_path}") 90 | except subprocess.CalledProcessError as e: 91 | print(f"Error occurred: {e}") 92 | 93 | 94 | def smooth_video_transition(frames1, frames2, smooth_frame_num, frame_inter_model, inter_frames=2): 95 | 96 | frames1_np = np.array([np.array(frame) for frame in frames1]) 97 | frames2_np = np.array([np.array(frame) for frame in frames2]) 98 | frames1_tensor = torch.from_numpy(frames1_np).permute(3,0,1,2).unsqueeze(0) / 255.0 99 
| frames2_tensor = torch.from_numpy(frames2_np).permute(3,0,1,2).unsqueeze(0) / 255.0 100 | video = torch.cat([frames1_tensor[:,:,-smooth_frame_num:], frames2_tensor[:,:,:smooth_frame_num]], dim=2) 101 | 102 | video = batch_images_interpolation_tool(video, frame_inter_model, inter_frames=inter_frames) 103 | 104 | index = [1, 4, 5, 8] if inter_frames == 2 else [2, 5, 7, 10] 105 | video = video[:, :, index] 106 | video = video.squeeze(0) 107 | video = video.permute(1, 2, 3, 0) 108 | video = (video * 255).clamp(0, 255).to(torch.uint8).cpu().numpy() 109 | mid_frames = [Image.fromarray(frame) for frame in video] 110 | 111 | out_frames = frames1[:-smooth_frame_num] + mid_frames + frames2[smooth_frame_num:] 112 | 113 | return out_frames 114 | 115 | 116 | if __name__ == "__main__": 117 | parser = argparse.ArgumentParser(description="Process video and image for face animation.") 118 | parser.add_argument('--image_path', type=str, default="assets/ref_images/1.png", help='Path to the source image.') 119 | parser.add_argument('--driving_video_path', type=str, default="assets/driving_video/6.mp4", help='Path to the driving video.') 120 | parser.add_argument('--output_path', type=str, default="outputs", help='Path to save the output video.') 121 | args = parser.parse_args() 122 | 123 | guidance_scale = 3.0 124 | seed = 43 125 | num_inference_steps = 10 126 | sample_size = [480, 720] 127 | max_frame_num = 10000 128 | frame_num_per_batch = 49 129 | overlap_frame_num = 8 130 | fusion_interval = [3, 8] 131 | use_interpolation = True 132 | target_fps = 12 # recommend fps: 12(Native), 24, 36, 48, 60, other fps like 25, 30 may cause unstable rates 133 | weight_dtype = torch.bfloat16 134 | save_path = args.output_path 135 | generator = torch.Generator(device="cuda").manual_seed(seed) 136 | model_name = "pretrained_models/SkyReels-A1-5B/" 137 | siglip_name = "pretrained_models/SkyReels-A1-5B/siglip-so400m-patch14-384" 138 | 139 | lmk_extractor = LMKExtractor() 140 | processor = 
FaceAnimationProcessor(checkpoint='pretrained_models/smirk/SMIRK_em1.pt') 141 | vis = FaceMeshVisualizer2d(forehead_edge=False, draw_head=False, draw_iris=False) 142 | face_helper = FaceRestoreHelper(upscale_factor=1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', device="cuda") 143 | 144 | # siglip visual encoder 145 | siglip = SiglipVisionModel.from_pretrained(siglip_name) 146 | siglip_normalize = SiglipImageProcessor.from_pretrained(siglip_name) 147 | 148 | # frame interpolation model 149 | if use_interpolation or target_fps != 12: 150 | frame_inter_model = init_frame_interpolation_model('pretrained_models/film_net/film_net_fp16.pt', device="cuda") 151 | 152 | # skyreels a1 model 153 | transformer = CogVideoXTransformer3DModel.from_pretrained( 154 | model_name, 155 | subfolder="transformer" 156 | ).to(weight_dtype) 157 | 158 | vae = AutoencoderKLCogVideoX.from_pretrained( 159 | model_name, 160 | subfolder="vae" 161 | ).to(weight_dtype) 162 | 163 | lmk_encoder = AutoencoderKLCogVideoX.from_pretrained( 164 | model_name, 165 | subfolder="pose_guider", 166 | ).to(weight_dtype) 167 | 168 | pipe = SkyReelsA1ImagePoseToVideoPipeline.from_pretrained( 169 | model_name, 170 | transformer=transformer, 171 | vae=vae, 172 | lmk_encoder=lmk_encoder, 173 | image_encoder=siglip, 174 | feature_extractor=siglip_normalize, 175 | torch_dtype=torch.bfloat16 176 | ) 177 | 178 | pipe.to("cuda") 179 | pipe.enable_model_cpu_offload() 180 | pipe.vae.enable_tiling() 181 | 182 | control_frames = parse_video(args.driving_video_path, max_frame_num) 183 | 184 | # driving video crop face 185 | driving_video_crop = [] 186 | empty_index = [] 187 | from tqdm import tqdm 188 | for i, control_frame in enumerate(tqdm(control_frames, desc="Face crop")): 189 | frame, _, _ = processor.face_crop(control_frame) 190 | if frame is None: 191 | print(f'Warning: No face detected in the driving video frame {i}') 192 | empty_index.append(i) 193 | else: 194 | 
driving_video_crop.append(frame) 195 | 196 | control_frames = np.delete(control_frames, empty_index, axis=0) 197 | 198 | video_length = len(driving_video_crop) # orginal video length 199 | 200 | print(f"orginal video length: {video_length}") 201 | 202 | image = load_image(image=args.image_path) 203 | image = processor.crop_and_resize(image, sample_size[0], sample_size[1]) 204 | 205 | # ref image crop face 206 | ref_image, x1, y1 = processor.face_crop(np.array(image)) 207 | face_h, face_w, _ = ref_image.shape 208 | source_image = ref_image 209 | driving_video = driving_video_crop 210 | out_frames = processor.preprocess_lmk3d(source_image, driving_video) 211 | 212 | rescale_motions = np.zeros_like(image)[np.newaxis, :].repeat(len(out_frames), axis=0) 213 | for ii in range(rescale_motions.shape[0]): 214 | rescale_motions[ii][y1:y1+face_h, x1:x1+face_w] = out_frames[ii] 215 | ref_image = cv2.resize(ref_image, (512, 512)) 216 | ref_lmk = lmk_extractor(ref_image[:, :, ::-1]) 217 | 218 | ref_img = vis.draw_landmarks_v3((512, 512), (face_w, face_h), ref_lmk['lmks'].astype(np.float32), normed=True) 219 | 220 | first_motion = np.zeros_like(np.array(image)) 221 | first_motion[y1:y1+face_h, x1:x1+face_w] = ref_img 222 | first_motion = first_motion[np.newaxis, :] 223 | 224 | input_video = rescale_motions[:max_frame_num] 225 | 226 | face_helper.clean_all() 227 | face_helper.read_image(np.array(image)[:, :, ::-1]) 228 | face_helper.get_face_landmarks_5(only_center_face=True) 229 | face_helper.align_warp_face() 230 | align_face = face_helper.cropped_faces[0] 231 | image_face = align_face[:, :, ::-1] 232 | 233 | # [F, H, W, C] 234 | input_video = torch.from_numpy(np.array(input_video)).permute([3, 0, 1, 2]).unsqueeze(0) # [B, C, F, H, W] 235 | input_video = input_video / 255 236 | input_video_all = input_video 237 | 238 | first_motion = torch.from_numpy(np.array(first_motion)).permute([3, 0, 1, 2]).unsqueeze(0) # [B, C, 1, H, W] 239 | first_motion = first_motion / 255 240 | 241 | 
out_samples = [] 242 | padding_frame_num = None 243 | latents_cache = [] 244 | 245 | time_start = time.time() 246 | 247 | for i in range(0, video_length, frame_num_per_batch-1-overlap_frame_num): 248 | is_first_batch = (i == 0) 249 | is_last_batch = (i + frame_num_per_batch - 1 >= video_length) 250 | 251 | input_video = input_video_all[:, :, i:i+frame_num_per_batch-1] 252 | 253 | if input_video.shape[2] != frame_num_per_batch-1: 254 | padding_frame_num = frame_num_per_batch-1 - input_video.shape[2] 255 | print(f"padding_frame_num: {padding_frame_num}") 256 | input_video = torch.cat([input_video, torch.repeat_interleave(input_video[:, :, -1].unsqueeze(2), padding_frame_num, dim=2)], dim=2) 257 | 258 | input_video = torch.cat([first_motion, input_video], dim=2) 259 | 260 | with torch.no_grad(): 261 | sample, latents_cache = pipe( 262 | image=image, 263 | image_face=image_face, 264 | control_video=input_video, 265 | prompt="", 266 | negative_prompt="", 267 | height=sample_size[0], 268 | width=sample_size[1], 269 | num_frames=frame_num_per_batch, 270 | generator=generator, 271 | guidance_scale=guidance_scale, 272 | num_inference_steps=num_inference_steps, 273 | is_last_batch=is_last_batch, 274 | overlap_frame_num=overlap_frame_num, 275 | fusion_interval=fusion_interval, 276 | latents_cache=latents_cache 277 | ) 278 | if use_interpolation: 279 | if is_first_batch: 280 | out_samples = sample.frames[0][1:] 281 | else: 282 | out_samples = smooth_video_transition(out_samples, sample.frames[0][1+overlap_frame_num:], 2, frame_inter_model) 283 | else: 284 | out_sample = sample.frames[0][1:] if is_first_batch else sample.frames[0][1+overlap_frame_num:] 285 | out_samples.extend(out_sample) 286 | print(f"out_samples len: {len(out_samples)}") 287 | 288 | if is_last_batch: 289 | break 290 | 291 | if padding_frame_num is not None: 292 | out_samples = out_samples[:-padding_frame_num] 293 | 294 | print(f"output video length: {len(out_samples)}") 295 | 296 | time_end = time.time() 297 
| print(f"time cost: {time_end - time_start} seconds") 298 | 299 | save_path_name = os.path.basename(args.image_path).split(".")[0] + "-" + os.path.basename(args.driving_video_path).split(".")[0]+ ".mp4" 300 | 301 | if not os.path.exists(save_path): 302 | os.makedirs(save_path, exist_ok=True) 303 | video_path = os.path.join(save_path, save_path_name.split(".")[0] + "_output.mp4") 304 | 305 | if target_fps != 12: 306 | out_samples = multi_fps_tool(out_samples, frame_inter_model, target_fps) 307 | 308 | export_to_video(out_samples, video_path, fps=target_fps) 309 | add_audio_to_video(video_path, args.driving_video_path, video_path.split(".")[0] + "_audio.mp4") 310 | 311 | if target_fps == 12: 312 | target_h, target_w = sample_size[0], sample_size[1] 313 | final_images = [] 314 | for q in range(len(out_samples)): 315 | frame1 = image 316 | frame2 = crop_and_resize(Image.fromarray(np.array(control_frames[q])).convert("RGB"), target_h, target_w) 317 | frame3 = Image.fromarray(np.array(out_samples[q])).convert("RGB") 318 | 319 | result = Image.new('RGB', (target_w * 3, target_h)) 320 | result.paste(frame1, (0, 0)) 321 | result.paste(frame2, (target_w, 0)) 322 | result.paste(frame3, (target_w * 2, 0)) 323 | final_images.append(np.array(result)) 324 | 325 | video_out_path = os.path.join(save_path, save_path_name.split(".")[0]+"_merge.mp4") 326 | write_mp4(video_out_path, final_images, fps=12) 327 | 328 | add_audio_to_video(video_out_path, args.driving_video_path, video_out_path.split(".")[0] + f"_audio.mp4") 329 | 330 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | chumpy==0.70 2 | decord==0.6.0 3 | diffusers==0.32.2 4 | einops==0.8.1 5 | facexlib==0.3.0 6 | gradio==5.16.0 7 | insightface==0.7.3 8 | moviepy==1.0.3 9 | numpy==1.26.4 10 | opencv_contrib_python==4.10.0.84 11 | opencv_python==4.10.0.84 12 | opencv_python_headless==4.10.0.84 
13 | Pillow==11.1.0 14 | pytorch3d==0.7.8 15 | safetensors==0.5.2 16 | scikit-image==0.24.0 17 | timm==0.6.13 18 | torch==2.2.2+cu118 19 | tqdm==4.66.2 20 | transformers==4.37.2 21 | mediapipe==0.10.21 22 | librosa==0.10.2.post1 23 | onnxruntime-gpu==1.16.3 24 | accelerate==1.6.0 25 | 26 | # if torch or pytorch3d installation fails, try: 27 | # pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu118 28 | # pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable" 29 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/demo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | from PIL import Image 5 | import glob 6 | import insightface 7 | import cv2 8 | import subprocess 9 | import argparse 10 | from decord import VideoReader 11 | from moviepy.editor import ImageSequenceClip, AudioFileClip, VideoFileClip 12 | from facexlib.parsing import init_parsing_model 13 | from facexlib.utils.face_restoration_helper import FaceRestoreHelper 14 | from insightface.app import FaceAnalysis 15 | 16 | from diffusers.models import AutoencoderKLCogVideoX 17 | from diffusers.utils import export_to_video, load_image 18 | from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel 19 | from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor 20 | 21 | from skyreels_a1.models.transformer3d import CogVideoXTransformer3DModel 22 | from skyreels_a1.skyreels_a1_i2v_pipeline import SkyReelsA1ImagePoseToVideoPipeline 23 | from 
skyreels_a1.pre_process_lmk3d import FaceAnimationProcessor 24 | from skyreels_a1.src.media_pipe.mp_utils import LMKExtractor 25 | from skyreels_a1.src.media_pipe.draw_util_2d import FaceMeshVisualizer2d 26 | 27 | 28 | def crop_and_resize(image, height, width): 29 | image = np.array(image) 30 | image_height, image_width, _ = image.shape 31 | if image_height / image_width < height / width: 32 | croped_width = int(image_height / height * width) 33 | left = (image_width - croped_width) // 2 34 | image = image[:, left: left+croped_width] 35 | image = Image.fromarray(image).resize((width, height)) 36 | else: 37 | pad = int((((width / height) * image_height) - image_width) / 2.) 38 | padded_image = np.zeros((image_height, image_width + pad * 2, 3), dtype=np.uint8) 39 | padded_image[:, pad:pad+image_width] = image 40 | image = Image.fromarray(padded_image).resize((width, height)) 41 | return image 42 | 43 | def write_mp4(video_path, samples, fps=14, audio_bitrate="192k"): 44 | clip = ImageSequenceClip(samples, fps=fps) 45 | clip.write_videofile(video_path, audio_codec="aac", audio_bitrate=audio_bitrate, 46 | ffmpeg_params=["-crf", "18", "-preset", "slow"]) 47 | 48 | def init_model( 49 | model_name: str = "pretrained_models/SkyReels-A1-5B/", 50 | subfolder: str = "outputs/", 51 | siglip_path: str = "pretrained_models/SkyReels-A1-5B/siglip-so400m-patch14-384", 52 | weight_dtype=torch.bfloat16, 53 | ): 54 | 55 | lmk_extractor = LMKExtractor() 56 | vis = FaceMeshVisualizer2d(forehead_edge=False, draw_head=False, draw_iris=False,) 57 | processor = FaceAnimationProcessor(checkpoint='pretrained_models/smirk/SMIRK_em1.pt') 58 | 59 | face_helper = FaceRestoreHelper( 60 | upscale_factor=1, 61 | face_size=512, 62 | crop_ratio=(1, 1), 63 | det_model='retinaface_resnet50', 64 | save_ext='png', 65 | device="cuda", 66 | ) 67 | 68 | siglip = SiglipVisionModel.from_pretrained(siglip_path) 69 | siglip_normalize = SiglipImageProcessor.from_pretrained(siglip_path) 70 | 71 | transformer = 
CogVideoXTransformer3DModel.from_pretrained( 72 | model_name, 73 | subfolder="transformer", 74 | ).to(weight_dtype) 75 | 76 | vae = AutoencoderKLCogVideoX.from_pretrained( 77 | model_name, 78 | subfolder="vae" 79 | ).to(weight_dtype) 80 | 81 | lmk_encoder = AutoencoderKLCogVideoX.from_pretrained( 82 | model_name, 83 | subfolder="pose_guider" 84 | ).to(weight_dtype) 85 | 86 | pipe = SkyReelsA1ImagePoseToVideoPipeline.from_pretrained( 87 | model_name, 88 | transformer = transformer, 89 | vae = vae, 90 | lmk_encoder = lmk_encoder, 91 | image_encoder = siglip, 92 | feature_extractor = siglip_normalize, 93 | torch_dtype=weight_dtype) 94 | pipe.to("cuda") 95 | pipe.enable_model_cpu_offload() 96 | pipe.vae.enable_tiling() 97 | 98 | return pipe, face_helper, processor, lmk_extractor, vis 99 | 100 | 101 | 102 | def generate_video( 103 | pipe, 104 | face_helper, 105 | processor, 106 | lmk_extractor, 107 | vis, 108 | control_video_path: str = None, 109 | image_path: str = None, 110 | save_path: str = None, 111 | guidance_scale=3.0, 112 | seed=43, 113 | num_inference_steps=10, 114 | sample_size=[480, 720], 115 | max_frame_num=49, 116 | weight_dtype=torch.bfloat16, 117 | ): 118 | 119 | vr = VideoReader(control_video_path) 120 | fps = vr.get_avg_fps() 121 | video_length = len(vr) 122 | 123 | duration = video_length / fps 124 | target_times = np.arange(0, duration, 1/12) 125 | frame_indices = (target_times * fps).astype(np.int32) 126 | 127 | frame_indices = frame_indices[frame_indices < video_length] 128 | control_frames = vr.get_batch(frame_indices).asnumpy()[:(max_frame_num-1)] 129 | 130 | out_frames = len(control_frames) - 1 131 | if len(control_frames) < max_frame_num: 132 | video_lenght_add = max_frame_num - len(control_frames) 133 | control_frames = np.concatenate(([control_frames[0]]*2, control_frames[1:len(control_frames)-2], [control_frames[-1]] * video_lenght_add), axis=0) 134 | 135 | # driving video crop face 136 | driving_video_crop = [] 137 | for control_frame in 
control_frames: 138 | frame, _, _ = processor.face_crop(control_frame) 139 | driving_video_crop.append(frame) 140 | 141 | image = load_image(image=image_path) 142 | image = crop_and_resize(image, sample_size[0], sample_size[1]) 143 | 144 | with torch.no_grad(): 145 | face_helper.clean_all() 146 | face_helper.read_image(np.array(image)[:, :, ::-1]) 147 | face_helper.get_face_landmarks_5(only_center_face=True) 148 | face_helper.align_warp_face() 149 | if len(face_helper.cropped_faces) == 0: 150 | return 151 | align_face = face_helper.cropped_faces[0] 152 | image_face = align_face[:, :, ::-1] 153 | 154 | # ref image crop face 155 | ref_image, x1, y1 = processor.face_crop(np.array(image)) 156 | face_h, face_w, _, = ref_image.shape 157 | source_image = ref_image 158 | driving_video = driving_video_crop 159 | out_frames = processor.preprocess_lmk3d(source_image, driving_video) 160 | 161 | rescale_motions = np.zeros_like(image)[np.newaxis, :].repeat(48, axis=0) 162 | for ii in range(rescale_motions.shape[0]): 163 | rescale_motions[ii][y1:y1+face_h, x1:x1+face_w] = out_frames[ii] 164 | ref_image = cv2.resize(ref_image, (512, 512)) 165 | ref_lmk = lmk_extractor(ref_image[:, :, ::-1]) 166 | 167 | ref_img = vis.draw_landmarks_v3((512, 512), (face_w, face_h), ref_lmk['lmks'].astype(np.float32), normed=True) 168 | 169 | first_motion = np.zeros_like(np.array(image)) 170 | first_motion[y1:y1+face_h, x1:x1+face_w] = ref_img 171 | first_motion = first_motion[np.newaxis, :] 172 | 173 | motions = np.concatenate([first_motion, rescale_motions]) 174 | input_video = motions[:max_frame_num] 175 | 176 | input_video = input_video[:max_frame_num] 177 | motions = np.array(input_video) 178 | 179 | # [F, H, W, C] 180 | input_video = torch.from_numpy(np.array(input_video)).permute([3, 0, 1, 2]).unsqueeze(0) 181 | input_video = input_video / 255 182 | 183 | out_samples = [] 184 | 185 | generator = torch.Generator(device="cuda").manual_seed(seed) 186 | with torch.no_grad(): 187 | sample = pipe( 
188 | image=image, 189 | image_face=image_face, 190 | control_video = input_video, 191 | height = sample_size[0], 192 | width = sample_size[1], 193 | num_frames = 49, 194 | generator = generator, 195 | guidance_scale = guidance_scale, 196 | num_inference_steps = num_inference_steps, 197 | ) 198 | out_samples.extend(sample.frames[0][2:]) 199 | 200 | # export_to_video(out_samples, save_path, fps=12) 201 | control_frames = control_frames[1:] 202 | target_h, target_w = sample_size 203 | final_images = [] 204 | for i in range(len(out_samples)): 205 | frame1 = image 206 | frame2 = crop_and_resize(Image.fromarray(np.array(control_frames[i])).convert("RGB"), target_h, target_w) 207 | frame3 = Image.fromarray(np.array(out_samples[i])).convert("RGB") 208 | result = Image.new('RGB', (target_w * 3, target_h)) 209 | result.paste(frame1, (0, 0)) 210 | result.paste(frame2, (target_w, 0)) 211 | result.paste(frame3, (target_w * 2, 0)) 212 | final_images.append(np.array(result)) 213 | 214 | write_mp4(save_path, final_images, fps=12) 215 | 216 | 217 | 218 | if __name__ == "__main__": 219 | control_video_zip = glob.glob("assets/driving_video/*.mp4") 220 | image_path_zip = glob.glob("assets/ref_images/*.png") 221 | 222 | guidance_scale = 3.0 223 | seed = 43 224 | num_inference_steps = 10 225 | sample_size = [480, 720] 226 | max_frame_num = 49 227 | weight_dtype = torch.bfloat16 228 | 229 | save_path = "outputs" 230 | 231 | # init model 232 | pipe, face_helper, processor, lmk_extractor, vis = init_model() 233 | 234 | for i in range(len(control_video_zip)): 235 | for j in range(len(image_path_zip)): 236 | generate_video( 237 | pipe, 238 | face_helper, 239 | processor, 240 | lmk_extractor, 241 | vis, 242 | control_video_path=control_video_zip[i], 243 | image_path=image_path_zip[j], 244 | save_path=save_path, 245 | guidance_scale=guidance_scale, 246 | seed=seed, 247 | num_inference_steps=num_inference_steps, 248 | sample_size=sample_size, 249 | max_frame_num=max_frame_num, 250 | 
weight_dtype=weight_dtype, 251 | ) 252 | -------------------------------------------------------------------------------- /skyreels_a1/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/skyreels_a1/__init__.py -------------------------------------------------------------------------------- /skyreels_a1/ddim_solver.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def append_dims(x, target_dims): 5 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" 6 | dims_to_append = target_dims - x.ndim 7 | if dims_to_append < 0: 8 | raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") 9 | return x[(...,) + (None,) * dims_to_append] 10 | 11 | 12 | # From LCMScheduler.get_scalings_for_boundary_condition_discrete 13 | def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0): 14 | scaled_timestep = timestep_scaling * timestep 15 | c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2) 16 | c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5 17 | return c_skip, c_out 18 | 19 | 20 | def extract_into_tensor(a, t, x_shape): 21 | b, *_ = t.shape 22 | out = a.gather(-1, t) 23 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 24 | 25 | 26 | 27 | class DDIMSolver: 28 | def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50): 29 | # DDIM sampling parameters 30 | step_ratio = timesteps // ddim_timesteps 31 | 32 | self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1 33 | self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps] 34 | self.ddim_alpha_cumprods_prev = np.asarray( 35 | [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist() 36 | ) 37 | # convert to 
torch tensors 38 | self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long() 39 | self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods) 40 | self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev) 41 | 42 | def to(self, device): 43 | self.ddim_timesteps = self.ddim_timesteps.to(device) 44 | self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device) 45 | self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device) 46 | return self 47 | 48 | def ddim_step(self, pred_x0, pred_noise, timestep_index): 49 | alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape) 50 | dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise 51 | x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt 52 | return x_prev -------------------------------------------------------------------------------- /skyreels_a1/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/skyreels_a1/models/__init__.py -------------------------------------------------------------------------------- /skyreels_a1/pipeline_output.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | 5 | from diffusers.utils import BaseOutput 6 | 7 | 8 | @dataclass 9 | class CogVideoXPipelineOutput(BaseOutput): 10 | r""" 11 | Output class for CogVideo pipelines. 12 | 13 | Args: 14 | frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): 15 | List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing 16 | denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape 17 | `(batch_size, num_frames, channels, height, width)`. 
18 | """ 19 | 20 | frames: torch.Tensor -------------------------------------------------------------------------------- /skyreels_a1/src/FLAME/FLAME.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # Using this computer program means that you agree to the terms 6 | # in the LICENSE file included with this software distribution. 7 | # Any use not explicitly granted by the LICENSE is prohibited. 8 | # 9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 11 | # for Intelligent Systems. All rights reserved. 12 | # 13 | # For comments or questions, please email us at deca@tue.mpg.de 14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de 15 | 16 | import torch 17 | import torch.nn as nn 18 | import numpy as np 19 | np.bool = np.bool_ 20 | np.int = np.int_ 21 | np.float = np.float_ 22 | np.complex = np.complex_ 23 | np.object = np.object_ 24 | np.unicode = np.unicode_ 25 | np.str = np.str_ 26 | import pickle 27 | import torch.nn.functional as F 28 | 29 | from .lbs import lbs, batch_rodrigues, vertices2landmarks, rot_mat_to_euler 30 | 31 | def to_tensor(array, dtype=torch.float32): 32 | if 'torch.tensor' not in str(type(array)): 33 | return torch.tensor(array, dtype=dtype) 34 | def to_np(array, dtype=np.float32): 35 | if 'scipy.sparse' in str(type(array)): 36 | array = array.todense() 37 | return np.array(array, dtype=dtype) 38 | 39 | class Struct(object): 40 | def __init__(self, **kwargs): 41 | for key, val in kwargs.items(): 42 | setattr(self, key, val) 43 | 44 | class FLAME(nn.Module): 45 | """ 46 | borrowed from https://github.com/soubhiksanyal/FLAME_PyTorch/blob/master/FLAME.py 47 | Given flame parameters this class generates a 
differentiable FLAME function 48 | which outputs the a mesh and 2D/3D facial landmarks 49 | """ 50 | def __init__(self, flame_model_path='pretrained_models/FLAME/generic_model.pkl', 51 | flame_lmk_embedding_path='pretrained_models/FLAME/landmark_embedding.npy', n_shape=300, n_exp=50): 52 | super(FLAME, self).__init__() 53 | 54 | with open(flame_model_path, 'rb') as f: 55 | ss = pickle.load(f, encoding='latin1') 56 | flame_model = Struct(**ss) 57 | 58 | self.n_shape = n_shape 59 | self.n_exp = n_exp 60 | self.dtype = torch.float32 61 | self.register_buffer('faces_tensor', to_tensor(to_np(flame_model.f, dtype=np.int64), dtype=torch.long)) 62 | # The vertices of the template model 63 | print('Using generic FLAME model') 64 | self.register_buffer('v_template', to_tensor(to_np(flame_model.v_template), dtype=self.dtype)) 65 | 66 | # The shape components and expression 67 | shapedirs = to_tensor(to_np(flame_model.shapedirs), dtype=self.dtype) 68 | shapedirs = torch.cat([shapedirs[:,:,:n_shape], shapedirs[:,:,300:300+n_exp]], 2) 69 | self.register_buffer('shapedirs', shapedirs) 70 | # The pose components 71 | num_pose_basis = flame_model.posedirs.shape[-1] 72 | posedirs = np.reshape(flame_model.posedirs, [-1, num_pose_basis]).T 73 | self.register_buffer('posedirs', to_tensor(to_np(posedirs), dtype=self.dtype)) 74 | # 75 | self.register_buffer('J_regressor', to_tensor(to_np(flame_model.J_regressor), dtype=self.dtype)) 76 | parents = to_tensor(to_np(flame_model.kintree_table[0])).long(); parents[0] = -1 77 | self.register_buffer('parents', parents) 78 | self.register_buffer('lbs_weights', to_tensor(to_np(flame_model.weights), dtype=self.dtype)) 79 | 80 | 81 | self.register_buffer('l_eyelid', torch.from_numpy(np.load(f'pretrained_models/smirk/l_eyelid.npy')).to(self.dtype)[None]) 82 | self.register_buffer('r_eyelid', torch.from_numpy(np.load(f'pretrained_models/smirk/r_eyelid.npy')).to(self.dtype)[None]) 83 | # import pdb;pdb.set_trace() 84 | 85 | # Fixing Eyeball and neck 
rotation 86 | default_eyball_pose = torch.zeros([1, 6], dtype=self.dtype, requires_grad=False) 87 | self.register_parameter('eye_pose', nn.Parameter(default_eyball_pose, 88 | requires_grad=False)) 89 | default_neck_pose = torch.zeros([1, 3], dtype=self.dtype, requires_grad=False) 90 | self.register_parameter('neck_pose', nn.Parameter(default_neck_pose, 91 | requires_grad=False)) 92 | 93 | # Static and Dynamic Landmark embeddings for FLAME 94 | lmk_embeddings = np.load(flame_lmk_embedding_path, allow_pickle=True, encoding='latin1') 95 | lmk_embeddings = lmk_embeddings[()] 96 | self.register_buffer('lmk_faces_idx', torch.from_numpy(lmk_embeddings['static_lmk_faces_idx']).long()) 97 | self.register_buffer('lmk_bary_coords', torch.from_numpy(lmk_embeddings['static_lmk_bary_coords']).to(self.dtype)) 98 | self.register_buffer('dynamic_lmk_faces_idx', lmk_embeddings['dynamic_lmk_faces_idx'].long()) 99 | self.register_buffer('dynamic_lmk_bary_coords', lmk_embeddings['dynamic_lmk_bary_coords'].to(self.dtype)) 100 | self.register_buffer('full_lmk_faces_idx', torch.from_numpy(lmk_embeddings['full_lmk_faces_idx']).long()) 101 | self.register_buffer('full_lmk_bary_coords', torch.from_numpy(lmk_embeddings['full_lmk_bary_coords']).to(self.dtype)) 102 | 103 | neck_kin_chain = []; NECK_IDX=1 104 | curr_idx = torch.tensor(NECK_IDX, dtype=torch.long) 105 | while curr_idx != -1: 106 | neck_kin_chain.append(curr_idx) 107 | curr_idx = self.parents[curr_idx] 108 | self.register_buffer('neck_kin_chain', torch.stack(neck_kin_chain)) 109 | 110 | lmk_embeddings_mp = np.load("pretrained_models/smirk/mediapipe_landmark_embedding.npz") 111 | self.register_buffer('mp_lmk_faces_idx', torch.from_numpy(lmk_embeddings_mp['lmk_face_idx'].astype('int32')).long()) 112 | self.register_buffer('mp_lmk_bary_coords', torch.from_numpy(lmk_embeddings_mp['lmk_b_coords']).to(self.dtype)) 113 | 114 | def _find_dynamic_lmk_idx_and_bcoords(self, pose, dynamic_lmk_faces_idx, 115 | dynamic_lmk_b_coords, 116 | 
neck_kin_chain, dtype=torch.float32): 117 | """ 118 | Selects the face contour depending on the reletive position of the head 119 | Input: 120 | vertices: N X num_of_vertices X 3 121 | pose: N X full pose 122 | dynamic_lmk_faces_idx: The list of contour face indexes 123 | dynamic_lmk_b_coords: The list of contour barycentric weights 124 | neck_kin_chain: The tree to consider for the relative rotation 125 | dtype: Data type 126 | return: 127 | The contour face indexes and the corresponding barycentric weights 128 | """ 129 | 130 | batch_size = pose.shape[0] 131 | 132 | aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1, 133 | neck_kin_chain) 134 | rot_mats = batch_rodrigues( 135 | aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3) 136 | 137 | rel_rot_mat = torch.eye(3, device=pose.device, 138 | dtype=dtype).unsqueeze_(dim=0).expand(batch_size, -1, -1) 139 | for idx in range(len(neck_kin_chain)): 140 | rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat) 141 | 142 | y_rot_angle = torch.round( 143 | torch.clamp(rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi, 144 | max=39)).to(dtype=torch.long) 145 | 146 | neg_mask = y_rot_angle.lt(0).to(dtype=torch.long) 147 | mask = y_rot_angle.lt(-39).to(dtype=torch.long) 148 | neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle) 149 | y_rot_angle = (neg_mask * neg_vals + 150 | (1 - neg_mask) * y_rot_angle) 151 | 152 | dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx, 153 | 0, y_rot_angle) 154 | dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords, 155 | 0, y_rot_angle) 156 | return dyn_lmk_faces_idx, dyn_lmk_b_coords 157 | 158 | def _vertices2landmarks(self, vertices, faces, lmk_faces_idx, lmk_bary_coords): 159 | """ 160 | Calculates landmarks by barycentric interpolation 161 | Input: 162 | vertices: torch.tensor NxVx3, dtype = torch.float32 163 | The tensor of input vertices 164 | faces: torch.tensor (N*F)x3, dtype = torch.long 165 | The faces of the mesh 166 | lmk_faces_idx: torch.tensor N 
X L, dtype = torch.long 167 | The tensor with the indices of the faces used to calculate the 168 | landmarks. 169 | lmk_bary_coords: torch.tensor N X L X 3, dtype = torch.float32 170 | The tensor of barycentric coordinates that are used to interpolate 171 | the landmarks 172 | 173 | Returns: 174 | landmarks: torch.tensor NxLx3, dtype = torch.float32 175 | The coordinates of the landmarks for each mesh in the batch 176 | """ 177 | # Extract the indices of the vertices for each face 178 | # NxLx3 179 | batch_size, num_verts = vertices.shape[:dd2] 180 | lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view( 181 | 1, -1, 3).view(batch_size, lmk_faces_idx.shape[1], -1) 182 | 183 | lmk_faces += torch.arange(batch_size, dtype=torch.long).view(-1, 1, 1).to( 184 | device=vertices.device) * num_verts 185 | 186 | lmk_vertices = vertices.view(-1, 3)[lmk_faces] 187 | landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords]) 188 | return landmarks 189 | 190 | def seletec_3d68(self, vertices): 191 | landmarks3d = vertices2landmarks(vertices, self.faces_tensor, 192 | self.full_lmk_faces_idx.repeat(vertices.shape[0], 1), 193 | self.full_lmk_bary_coords.repeat(vertices.shape[0], 1, 1)) 194 | return landmarks3d 195 | 196 | def get_landmarks(self, vertices): 197 | """ 198 | Input: 199 | shape_params: N X number of shape parameters 200 | expression_params: N X number of expression parameters 201 | pose_params: N X number of pose parameters (6) 202 | return:d 203 | vertices: N X V X 3 204 | landmarks: N X number of landmarks X 3 205 | """ 206 | batch_size = vertices.shape[0] 207 | template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1) 208 | 209 | lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1) 210 | lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand(batch_size, -1, -1) 211 | 212 | dyn_lmk_faces_idx, dyn_lmk_bary_coords = self._find_dynamic_lmk_idx_and_bcoords( 213 | full_pose, 
self.dynamic_lmk_faces_idx, 214 | self.dynamic_lmk_bary_coords, 215 | self.neck_kin_chain, dtype=self.dtype) 216 | lmk_faces_idx = torch.cat([dyn_lmk_faces_idx, lmk_faces_idx], 1) 217 | lmk_bary_coords = torch.cat([dyn_lmk_bary_coords, lmk_bary_coords], 1) 218 | 219 | landmarks2d = vertices2landmarks(vertices, self.faces_tensor, 220 | lmk_faces_idx, 221 | lmk_bary_coords) 222 | bz = vertices.shape[0] 223 | landmarks3d = vertices2landmarks(vertices, self.faces_tensor, 224 | self.full_lmk_faces_idx.repeat(bz, 1), 225 | self.full_lmk_bary_coords.repeat(bz, 1, 1)) 226 | return vertices, landmarks2d, landmarks3d 227 | 228 | 229 | def forward(self, param_dictionary, zero_expression=False, zero_shape=False, zero_pose=False): 230 | shape_params = param_dictionary['shape_params'] 231 | expression_params = param_dictionary['expression_params'] 232 | pose_params = param_dictionary.get('pose_params', None) 233 | jaw_params = param_dictionary.get('jaw_params', None) 234 | eye_pose_params = param_dictionary.get('eye_pose_params', None) 235 | neck_pose_params = param_dictionary.get('neck_pose_params', None) 236 | eyelid_params = param_dictionary.get('eyelid_params', None) 237 | 238 | batch_size = shape_params.shape[0] 239 | 240 | # Adjust expression params size if needed 241 | if expression_params.shape[1] < self.n_exp: 242 | expression_params = torch.cat([expression_params, torch.zeros(expression_params.shape[0], self.n_exp - expression_params.shape[1]).to(shape_params.device)], dim=1) 243 | 244 | if shape_params.shape[1] < self.n_shape: 245 | shape_params = torch.cat([shape_params, torch.zeros(shape_params.shape[0], self.n_shape - shape_params.shape[1]).to(shape_params.device)], dim=1) 246 | 247 | # Zero out the expression and pose parameters if needed 248 | if zero_expression: 249 | expression_params = torch.zeros_like(expression_params).to(shape_params.device) 250 | jaw_params = torch.zeros_like(jaw_params).to(shape_params.device) 251 | 252 | if zero_shape: 253 | shape_params 
= torch.zeros_like(shape_params).to(shape_params.device) 254 | 255 | 256 | if zero_pose: 257 | pose_params = torch.zeros_like(pose_params).to(shape_params.device) 258 | pose_params[...,0] = 0.2 259 | pose_params[...,1] = -0.7 260 | 261 | if pose_params is None: 262 | pose_params = self.pose_params.expand(batch_size, -1) 263 | 264 | if eye_pose_params is None: 265 | eye_pose_params = self.eye_pose.expand(batch_size, -1) 266 | 267 | if neck_pose_params is None: 268 | neck_pose_params = self.neck_pose.expand(batch_size, -1) 269 | 270 | 271 | betas = torch.cat([shape_params, expression_params], dim=1) 272 | full_pose = torch.cat([pose_params, neck_pose_params, jaw_params, eye_pose_params], dim=1) 273 | # import pdb;pdb.set_trace() 274 | template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1) 275 | 276 | vertices, _ = lbs(betas, full_pose, template_vertices, 277 | self.shapedirs, self.posedirs, 278 | self.J_regressor, self.parents, 279 | self.lbs_weights, dtype=self.dtype) 280 | # import pdb;pdb.set_trace() 281 | if eyelid_params is not None: 282 | vertices = vertices + self.r_eyelid.expand(batch_size, -1, -1) * eyelid_params[:, 1:2, None] 283 | vertices = vertices + self.l_eyelid.expand(batch_size, -1, -1) * eyelid_params[:, 0:1, None] 284 | 285 | lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1) 286 | lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand(batch_size, -1, -1) 287 | 288 | dyn_lmk_faces_idx, dyn_lmk_bary_coords = self._find_dynamic_lmk_idx_and_bcoords( 289 | full_pose, self.dynamic_lmk_faces_idx, 290 | self.dynamic_lmk_bary_coords, 291 | self.neck_kin_chain, dtype=self.dtype) 292 | lmk_faces_idx = torch.cat([dyn_lmk_faces_idx, lmk_faces_idx], 1) 293 | lmk_bary_coords = torch.cat([dyn_lmk_bary_coords, lmk_bary_coords], 1) 294 | 295 | landmarks2d = vertices2landmarks(vertices, self.faces_tensor, 296 | lmk_faces_idx, 297 | lmk_bary_coords) 298 | bz = vertices.shape[0] 299 | landmarks3d = 
vertices2landmarks(vertices, self.faces_tensor, 300 | self.full_lmk_faces_idx.repeat(bz, 1), 301 | self.full_lmk_bary_coords.repeat(bz, 1, 1)) 302 | 303 | landmarksmp = vertices2landmarks(vertices, self.faces_tensor, 304 | self.mp_lmk_faces_idx.repeat(vertices.shape[0], 1), 305 | self.mp_lmk_bary_coords.repeat(vertices.shape[0], 1, 1)) 306 | 307 | return { 308 | 'vertices': vertices, 309 | 'landmarks_fan': landmarks2d, 310 | 'landmarks_fan_3d': landmarks3d, 311 | 'landmarks_mp': landmarksmp 312 | } 313 | 314 | 315 | -------------------------------------------------------------------------------- /skyreels_a1/src/FLAME/lbs.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems. All rights reserved. 
14 | # 15 | # Contact: ps-license@tuebingen.mpg.de 16 | 17 | from __future__ import absolute_import 18 | from __future__ import print_function 19 | from __future__ import division 20 | 21 | import numpy as np 22 | 23 | import torch 24 | import torch.nn.functional as F 25 | 26 | def rot_mat_to_euler(rot_mats): 27 | # Calculates rotation matrix to euler angles 28 | # Careful for extreme cases of eular angles like [0.0, pi, 0.0] 29 | 30 | sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] + 31 | rot_mats[:, 1, 0] * rot_mats[:, 1, 0]) 32 | return torch.atan2(-rot_mats[:, 2, 0], sy) 33 | 34 | def find_dynamic_lmk_idx_and_bcoords(vertices, pose, dynamic_lmk_faces_idx, 35 | dynamic_lmk_b_coords, 36 | neck_kin_chain, dtype=torch.float32): 37 | ''' Compute the faces, barycentric coordinates for the dynamic landmarks 38 | 39 | 40 | To do so, we first compute the rotation of the neck around the y-axis 41 | and then use a pre-computed look-up table to find the faces and the 42 | barycentric coordinates that will be used. 43 | 44 | Special thanks to Soubhik Sanyal (soubhik.sanyal@tuebingen.mpg.de) 45 | for providing the original TensorFlow implementation and for the LUT. 46 | 47 | Parameters 48 | ---------- 49 | vertices: torch.tensor BxVx3, dtype = torch.float32 50 | The tensor of input vertices 51 | pose: torch.tensor Bx(Jx3), dtype = torch.float32 52 | The current pose of the body model 53 | dynamic_lmk_faces_idx: torch.tensor L, dtype = torch.long 54 | The look-up table from neck rotation to faces 55 | dynamic_lmk_b_coords: torch.tensor Lx3, dtype = torch.float32 56 | The look-up table from neck rotation to barycentric coordinates 57 | neck_kin_chain: list 58 | A python list that contains the indices of the joints that form the 59 | kinematic chain of the neck. 
60 | dtype: torch.dtype, optional 61 | 62 | Returns 63 | ------- 64 | dyn_lmk_faces_idx: torch.tensor, dtype = torch.long 65 | A tensor of size BxL that contains the indices of the faces that 66 | will be used to compute the current dynamic landmarks. 67 | dyn_lmk_b_coords: torch.tensor, dtype = torch.float32 68 | A tensor of size BxL that contains the indices of the faces that 69 | will be used to compute the current dynamic landmarks. 70 | ''' 71 | 72 | batch_size = vertices.shape[0] 73 | 74 | aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1, 75 | neck_kin_chain) 76 | rot_mats = batch_rodrigues( 77 | aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3) 78 | 79 | rel_rot_mat = torch.eye(3, device=vertices.device, 80 | dtype=dtype).unsqueeze_(dim=0) 81 | for idx in range(len(neck_kin_chain)): 82 | rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat) 83 | 84 | y_rot_angle = torch.round( 85 | torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi, 86 | max=39)).to(dtype=torch.long) 87 | neg_mask = y_rot_angle.lt(0).to(dtype=torch.long) 88 | mask = y_rot_angle.lt(-39).to(dtype=torch.long) 89 | neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle) 90 | y_rot_angle = (neg_mask * neg_vals + 91 | (1 - neg_mask) * y_rot_angle) 92 | 93 | dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx, 94 | 0, y_rot_angle) 95 | dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords, 96 | 0, y_rot_angle) 97 | 98 | return dyn_lmk_faces_idx, dyn_lmk_b_coords 99 | 100 | 101 | def vertices2landmarks(vertices, faces, lmk_faces_idx, lmk_bary_coords): 102 | ''' Calculates landmarks by barycentric interpolation 103 | 104 | Parameters 105 | ---------- 106 | vertices: torch.tensor BxVx3, dtype = torch.float32 107 | The tensor of input vertices 108 | faces: torch.tensor Fx3, dtype = torch.long 109 | The faces of the mesh 110 | lmk_faces_idx: torch.tensor L, dtype = torch.long 111 | The tensor with the indices of the faces used to calculate the 112 | 
landmarks. 113 | lmk_bary_coords: torch.tensor Lx3, dtype = torch.float32 114 | The tensor of barycentric coordinates that are used to interpolate 115 | the landmarks 116 | 117 | Returns 118 | ------- 119 | landmarks: torch.tensor BxLx3, dtype = torch.float32 120 | The coordinates of the landmarks for each mesh in the batch 121 | ''' 122 | # Extract the indices of the vertices for each face 123 | # BxLx3 124 | batch_size, num_verts = vertices.shape[:2] 125 | device = vertices.device 126 | 127 | lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view( 128 | batch_size, -1, 3) 129 | 130 | lmk_faces += torch.arange( 131 | batch_size, dtype=torch.long, device=device).view(-1, 1, 1) * num_verts 132 | 133 | lmk_vertices = vertices.view(-1, 3)[lmk_faces].view( 134 | batch_size, -1, 3, 3) 135 | 136 | landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords]) 137 | return landmarks 138 | 139 | 140 | def lbs(betas, pose, v_template, shapedirs, posedirs, J_regressor, parents, 141 | lbs_weights, pose2rot=True, dtype=torch.float32): 142 | ''' Performs Linear Blend Skinning with the given shape and pose parameters 143 | 144 | Parameters 145 | ---------- 146 | betas : torch.tensor BxNB 147 | The tensor of shape parameters 148 | pose : torch.tensor Bx(J + 1) * 3 149 | The pose parameters in axis-angle format 150 | v_template torch.tensor BxVx3 151 | The template mesh that will be deformed 152 | shapedirs : torch.tensor 1xNB 153 | The tensor of PCA shape displacements 154 | posedirs : torch.tensor Px(V * 3) 155 | The pose PCA coefficients 156 | J_regressor : torch.tensor JxV 157 | The regressor array that is used to calculate the joints from 158 | the position of the vertices 159 | parents: torch.tensor J 160 | The array that describes the kinematic tree for the model 161 | lbs_weights: torch.tensor N x V x (J + 1) 162 | The linear blend skinning weights that represent how much the 163 | rotation matrix of each part affects each vertex 164 | pose2rot: 
bool, optional 165 | Flag on whether to convert the input pose tensor to rotation 166 | matrices. The default value is True. If False, then the pose tensor 167 | should already contain rotation matrices and have a size of 168 | Bx(J + 1)x9 169 | dtype: torch.dtype, optional 170 | 171 | Returns 172 | ------- 173 | verts: torch.tensor BxVx3 174 | The vertices of the mesh after applying the shape and pose 175 | displacements. 176 | joints: torch.tensor BxJx3 177 | The joints of the model 178 | ''' 179 | 180 | batch_size = max(betas.shape[0], pose.shape[0]) 181 | device = betas.device 182 | 183 | # Add shape contribution 184 | v_shaped = v_template + blend_shapes(betas, shapedirs) 185 | 186 | # Get the joints 187 | # NxJx3 array 188 | J = vertices2joints(J_regressor, v_shaped) 189 | 190 | # 3. Add pose blend shapes 191 | # N x J x 3 x 3 192 | ident = torch.eye(3, dtype=dtype, device=device) 193 | if pose2rot: 194 | rot_mats = batch_rodrigues( 195 | pose.view(-1, 3), dtype=dtype).view([batch_size, -1, 3, 3]) 196 | 197 | pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1]) 198 | # (N x P) x (P, V * 3) -> N x V x 3 199 | pose_offsets = torch.matmul(pose_feature, posedirs) \ 200 | .view(batch_size, -1, 3) 201 | else: 202 | pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident 203 | rot_mats = pose.view(batch_size, -1, 3, 3) 204 | 205 | pose_offsets = torch.matmul(pose_feature.view(batch_size, -1), 206 | posedirs).view(batch_size, -1, 3) 207 | 208 | v_posed = pose_offsets + v_shaped 209 | # 4. Get the global joint location 210 | J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype) 211 | 212 | # 5. 
Do skinning: 213 | # W is N x V x (J + 1) 214 | W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1]) 215 | # (N x V x (J + 1)) x (N x (J + 1) x 16) 216 | num_joints = J_regressor.shape[0] 217 | T = torch.matmul(W, A.view(batch_size, num_joints, 16)) \ 218 | .view(batch_size, -1, 4, 4) 219 | 220 | homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1], 221 | dtype=dtype, device=device) 222 | v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2) 223 | v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1)) 224 | 225 | verts = v_homo[:, :, :3, 0] 226 | 227 | return verts, J_transformed 228 | 229 | 230 | def vertices2joints(J_regressor, vertices): 231 | ''' Calculates the 3D joint locations from the vertices 232 | 233 | Parameters 234 | ---------- 235 | J_regressor : torch.tensor JxV 236 | The regressor array that is used to calculate the joints from the 237 | position of the vertices 238 | vertices : torch.tensor BxVx3 239 | The tensor of mesh vertices 240 | 241 | Returns 242 | ------- 243 | torch.tensor BxJx3 244 | The location of the joints 245 | ''' 246 | 247 | return torch.einsum('bik,ji->bjk', [vertices, J_regressor]) 248 | 249 | 250 | def blend_shapes(betas, shape_disps): 251 | ''' Calculates the per vertex displacement due to the blend shapes 252 | 253 | 254 | Parameters 255 | ---------- 256 | betas : torch.tensor Bx(num_betas) 257 | Blend shape coefficients 258 | shape_disps: torch.tensor Vx3x(num_betas) 259 | Blend shapes 260 | 261 | Returns 262 | ------- 263 | torch.tensor BxVx3 264 | The per-vertex displacement due to shape deformation 265 | ''' 266 | 267 | # Displacement[b, m, k] = sum_{l} betas[b, l] * shape_disps[m, k, l] 268 | # i.e. Multiply each shape displacement by its corresponding beta and 269 | # then sum them. 
270 | blend_shape = torch.einsum('bl,mkl->bmk', [betas, shape_disps]) 271 | return blend_shape 272 | 273 | 274 | def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32): 275 | ''' Calculates the rotation matrices for a batch of rotation vectors 276 | Parameters 277 | ---------- 278 | rot_vecs: torch.tensor Nx3 279 | array of N axis-angle vectors 280 | Returns 281 | ------- 282 | R: torch.tensor Nx3x3 283 | The rotation matrices for the given axis-angle parameters 284 | ''' 285 | 286 | batch_size = rot_vecs.shape[0] 287 | device = rot_vecs.device 288 | 289 | angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True) 290 | rot_dir = rot_vecs / angle 291 | 292 | cos = torch.unsqueeze(torch.cos(angle), dim=1) 293 | sin = torch.unsqueeze(torch.sin(angle), dim=1) 294 | 295 | # Bx1 arrays 296 | rx, ry, rz = torch.split(rot_dir, 1, dim=1) 297 | K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device) 298 | 299 | zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device) 300 | K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \ 301 | .view((batch_size, 3, 3)) 302 | 303 | ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0) 304 | rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K) 305 | return rot_mat 306 | 307 | 308 | def transform_mat(R, t): 309 | ''' Creates a batch of transformation matrices 310 | Args: 311 | - R: Bx3x3 array of a batch of rotation matrices 312 | - t: Bx3x1 array of a batch of translation vectors 313 | Returns: 314 | - T: Bx4x4 Transformation matrix 315 | ''' 316 | # No padding left or right, only add an extra row 317 | return torch.cat([F.pad(R, [0, 0, 0, 1]), 318 | F.pad(t, [0, 0, 0, 1], value=1)], dim=2) 319 | 320 | 321 | def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32): 322 | """ 323 | Applies a batch of rigid transformations to the joints 324 | 325 | Parameters 326 | ---------- 327 | rot_mats : torch.tensor BxNx3x3 328 | Tensor of rotation matrices 329 | joints : 
torch.tensor BxNx3 330 | Locations of joints 331 | parents : torch.tensor BxN 332 | The kinematic tree of each object 333 | dtype : torch.dtype, optional: 334 | The data type of the created tensors, the default is torch.float32 335 | 336 | Returns 337 | ------- 338 | posed_joints : torch.tensor BxNx3 339 | The locations of the joints after applying the pose rotations 340 | rel_transforms : torch.tensor BxNx4x4 341 | The relative (with respect to the root joint) rigid transformations 342 | for all the joints 343 | """ 344 | 345 | joints = torch.unsqueeze(joints, dim=-1) 346 | 347 | rel_joints = joints.clone() 348 | rel_joints[:, 1:] -= joints[:, parents[1:]] 349 | 350 | # transforms_mat = transform_mat( 351 | # rot_mats.view(-1, 3, 3), 352 | # rel_joints.view(-1, 3, 1)).view(-1, joints.shape[1], 4, 4) 353 | transforms_mat = transform_mat( 354 | rot_mats.view(-1, 3, 3), 355 | rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4) 356 | 357 | transform_chain = [transforms_mat[:, 0]] 358 | for i in range(1, parents.shape[0]): 359 | # Subtract the joint location at the rest pose 360 | # No need for rotation, since it's identity when at rest 361 | curr_res = torch.matmul(transform_chain[parents[i]], 362 | transforms_mat[:, i]) 363 | transform_chain.append(curr_res) 364 | 365 | transforms = torch.stack(transform_chain, dim=1) 366 | 367 | # The last column of the transformations contains the posed joints 368 | posed_joints = transforms[:, :, :3, 3] 369 | 370 | # The last column of the transformations contains the posed joints 371 | posed_joints = transforms[:, :, :3, 3] 372 | 373 | joints_homogen = F.pad(joints, [0, 0, 0, 1]) 374 | 375 | rel_transforms = transforms - F.pad( 376 | torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0]) 377 | 378 | return posed_joints, rel_transforms -------------------------------------------------------------------------------- /skyreels_a1/src/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/skyreels_a1/src/__init__.py -------------------------------------------------------------------------------- /skyreels_a1/src/frame_interpolation.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/dajes/frame-interpolation-pytorch 2 | import os 3 | import cv2 4 | import numpy as np 5 | import torch 6 | import bisect 7 | import shutil 8 | import pdb 9 | from tqdm import tqdm 10 | 11 | def init_frame_interpolation_model(checkpoint_name, device="cuda"): 12 | print(f"Initializing frame interpolation model from {checkpoint_name}") 13 | 14 | model = torch.jit.load(checkpoint_name, map_location='cpu') 15 | model.eval() 16 | model = model.half() 17 | model = model.to(device=device) 18 | return model 19 | 20 | 21 | def batch_images_interpolation_tool(input_tensor, model, inter_frames=1): 22 | 23 | video_tensor = [] 24 | frame_num = input_tensor.shape[2] # bs, channel, frame, height, width 25 | 26 | for idx in tqdm(range(frame_num-1)): 27 | image1 = input_tensor[:,:,idx] 28 | image2 = input_tensor[:,:,idx+1] 29 | 30 | results = [image1, image2] 31 | 32 | inter_frames = int(inter_frames) 33 | idxes = [0, inter_frames + 1] 34 | remains = list(range(1, inter_frames + 1)) 35 | 36 | splits = torch.linspace(0, 1, inter_frames + 2) 37 | 38 | for _ in range(len(remains)): 39 | starts = splits[idxes[:-1]] 40 | ends = splits[idxes[1:]] 41 | distances = ((splits[None, remains] - starts[:, None]) / (ends[:, None] - starts[:, None]) - .5).abs() 42 | matrix = torch.argmin(distances).item() 43 | start_i, step = np.unravel_index(matrix, distances.shape) 44 | end_i = start_i + 1 45 | 46 | x0 = results[start_i] 47 | x1 = results[end_i] 48 | 49 | x0 = x0.half() 50 | x1 = x1.half() 51 | x0 = x0.cuda() 52 | x1 = x1.cuda() 53 | 54 | dt = 
x0.new_full((1, 1), (splits[remains[step]] - splits[idxes[start_i]])) / (splits[idxes[end_i]] - splits[idxes[start_i]]) 55 | 56 | with torch.no_grad(): 57 | prediction = model(x0, x1, dt) 58 | insert_position = bisect.bisect_left(idxes, remains[step]) 59 | idxes.insert(insert_position, remains[step]) 60 | results.insert(insert_position, prediction.clamp(0, 1).cpu().float()) 61 | del remains[step] 62 | 63 | for sub_idx in range(len(results)-1): 64 | video_tensor.append(results[sub_idx].unsqueeze(2)) 65 | 66 | video_tensor.append(input_tensor[:,:,-1].unsqueeze(2)) 67 | video_tensor = torch.cat(video_tensor, dim=2) 68 | return video_tensor -------------------------------------------------------------------------------- /skyreels_a1/src/media_pipe/draw_util_2d.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import mediapipe as mp 3 | import numpy as np 4 | from mediapipe.framework.formats import landmark_pb2 5 | 6 | class FaceMeshVisualizer2d: 7 | # def __init__(self, 8 | # forehead_edge=False, 9 | # upface_only=False, 10 | # draw_eye=True, 11 | # draw_head=False, 12 | # draw_iris=True, 13 | # draw_eyebrow=True, 14 | # draw_mouse=True, 15 | # draw_nose=True, 16 | # draw_pupil=True 17 | # ): 18 | def __init__(self, 19 | forehead_edge=True, 20 | upface_only=True, 21 | draw_eye=True, 22 | draw_head=True, 23 | draw_iris=True, 24 | draw_eyebrow=True, 25 | draw_mouse=True, 26 | draw_nose=True, 27 | draw_pupil=True 28 | ): 29 | self.mp_drawing = mp.solutions.drawing_utils 30 | mp_face_mesh = mp.solutions.face_mesh 31 | self.mp_face_mesh = mp_face_mesh 32 | self.forehead_edge = forehead_edge 33 | 34 | DrawingSpec = mp.solutions.drawing_styles.DrawingSpec 35 | f_thick = 1 36 | f_rad = 1 37 | right_iris_draw = DrawingSpec(color=(10, 200, 250), thickness=f_thick, circle_radius=f_rad) 38 | right_eye_draw = DrawingSpec(color=(10, 200, 180), thickness=f_thick, circle_radius=f_rad) 39 | right_eyebrow_draw = 
DrawingSpec(color=(10, 220, 180), thickness=f_thick, circle_radius=f_rad) 40 | left_iris_draw = DrawingSpec(color=(250, 200, 10), thickness=f_thick, circle_radius=f_rad) 41 | left_eye_draw = DrawingSpec(color=(180, 200, 10), thickness=f_thick, circle_radius=f_rad) 42 | left_eyebrow_draw = DrawingSpec(color=(180, 220, 10), thickness=f_thick, circle_radius=f_rad) 43 | head_draw = DrawingSpec(color=(10, 200, 10), thickness=f_thick, circle_radius=f_rad) 44 | nose_draw = DrawingSpec(color=(200, 200, 200), thickness=f_thick, circle_radius=f_rad) 45 | 46 | mouth_draw_obl = DrawingSpec(color=(10, 180, 20), thickness=f_thick, circle_radius=f_rad) 47 | mouth_draw_obr = DrawingSpec(color=(20, 10, 180), thickness=f_thick, circle_radius=f_rad) 48 | 49 | mouth_draw_ibl = DrawingSpec(color=(100, 100, 30), thickness=f_thick, circle_radius=f_rad) 50 | mouth_draw_ibr = DrawingSpec(color=(100, 150, 50), thickness=f_thick, circle_radius=f_rad) 51 | 52 | mouth_draw_otl = DrawingSpec(color=(20, 80, 100), thickness=f_thick, circle_radius=f_rad) 53 | mouth_draw_otr = DrawingSpec(color=(80, 100, 20), thickness=f_thick, circle_radius=f_rad) 54 | 55 | mouth_draw_itl = DrawingSpec(color=(120, 100, 200), thickness=f_thick, circle_radius=f_rad) 56 | mouth_draw_itr = DrawingSpec(color=(150 ,120, 100), thickness=f_thick, circle_radius=f_rad) 57 | 58 | FACEMESH_LIPS_OUTER_BOTTOM_LEFT = [(61,146),(146,91),(91,181),(181,84),(84,17)] 59 | FACEMESH_LIPS_OUTER_BOTTOM_RIGHT = [(17,314),(314,405),(405,321),(321,375),(375,291)] 60 | 61 | FACEMESH_LIPS_INNER_BOTTOM_LEFT = [(78,95),(95,88),(88,178),(178,87),(87,14)] 62 | FACEMESH_LIPS_INNER_BOTTOM_RIGHT = [(14,317),(317,402),(402,318),(318,324),(324,308)] 63 | 64 | FACEMESH_LIPS_OUTER_TOP_LEFT = [(61,185),(185,40),(40,39),(39,37),(37,0)] 65 | FACEMESH_LIPS_OUTER_TOP_RIGHT = [(0,267),(267,269),(269,270),(270,409),(409,291)] 66 | 67 | FACEMESH_LIPS_INNER_TOP_LEFT = [(78,191),(191,80),(80,81),(81,82),(82,13)] 68 | FACEMESH_LIPS_INNER_TOP_RIGHT = 
[(13,312),(312,311),(311,310),(310,415),(415,308)] 69 | 70 | FACEMESH_CUSTOM_FACE_OVAL = [(176, 149), (150, 136), (356, 454), (58, 132), (152, 148), (361, 288), (251, 389), (132, 93), (389, 356), (400, 377), (136, 172), (377, 152), (323, 361), (172, 58), (454, 323), (365, 379), (379, 378), (148, 176), (93, 234), (397, 365), (149, 150), (288, 397), (234, 127), (378, 400), (127, 162), (162, 21)] 71 | 72 | # mp_face_mesh.FACEMESH_CONTOURS has all the items we care about. 73 | face_connection_spec = {} 74 | 75 | #from IPython import embed 76 | #embed() 77 | if self.forehead_edge: 78 | for edge in mp_face_mesh.FACEMESH_FACE_OVAL: 79 | face_connection_spec[edge] = head_draw 80 | else: 81 | if draw_head: 82 | FACEMESH_CUSTOM_FACE_OVAL_sorted = sorted(FACEMESH_CUSTOM_FACE_OVAL) 83 | if upface_only: 84 | for edge in [FACEMESH_CUSTOM_FACE_OVAL_sorted[edge_idx] for edge_idx in [1,2,9,12,13,16,22,25]]: 85 | face_connection_spec[edge] = head_draw 86 | else: 87 | for edge in FACEMESH_CUSTOM_FACE_OVAL_sorted: 88 | face_connection_spec[edge] = head_draw 89 | 90 | if draw_eye: 91 | for edge in mp_face_mesh.FACEMESH_LEFT_EYE: 92 | face_connection_spec[edge] = left_eye_draw 93 | for edge in mp_face_mesh.FACEMESH_RIGHT_EYE: 94 | face_connection_spec[edge] = right_eye_draw 95 | 96 | if draw_eyebrow: 97 | for edge in mp_face_mesh.FACEMESH_LEFT_EYEBROW: 98 | face_connection_spec[edge] = left_eyebrow_draw 99 | 100 | for edge in mp_face_mesh.FACEMESH_RIGHT_EYEBROW: 101 | face_connection_spec[edge] = right_eyebrow_draw 102 | 103 | if draw_iris: 104 | for edge in mp_face_mesh.FACEMESH_LEFT_IRIS: 105 | face_connection_spec[edge] = left_iris_draw 106 | for edge in mp_face_mesh.FACEMESH_RIGHT_IRIS: 107 | face_connection_spec[edge] = right_iris_draw 108 | 109 | #for edge in mp_face_mesh.FACEMESH_RIGHT_EYEBROW: 110 | # face_connection_spec[edge] = right_eyebrow_draw 111 | # for edge in mp_face_mesh.FACEMESH_RIGHT_IRIS: 112 | # face_connection_spec[edge] = right_iris_draw 113 | 114 | # for edge in 
mp_face_mesh.FACEMESH_LIPS: 115 | # face_connection_spec[edge] = mouth_draw 116 | 117 | if draw_mouse: 118 | for edge in FACEMESH_LIPS_OUTER_BOTTOM_LEFT: 119 | face_connection_spec[edge] = mouth_draw_obl 120 | for edge in FACEMESH_LIPS_OUTER_BOTTOM_RIGHT: 121 | face_connection_spec[edge] = mouth_draw_obr 122 | for edge in FACEMESH_LIPS_INNER_BOTTOM_LEFT: 123 | face_connection_spec[edge] = mouth_draw_ibl 124 | for edge in FACEMESH_LIPS_INNER_BOTTOM_RIGHT: 125 | face_connection_spec[edge] = mouth_draw_ibr 126 | for edge in FACEMESH_LIPS_OUTER_TOP_LEFT: 127 | face_connection_spec[edge] = mouth_draw_otl 128 | for edge in FACEMESH_LIPS_OUTER_TOP_RIGHT: 129 | face_connection_spec[edge] = mouth_draw_otr 130 | for edge in FACEMESH_LIPS_INNER_TOP_LEFT: 131 | face_connection_spec[edge] = mouth_draw_itl 132 | for edge in FACEMESH_LIPS_INNER_TOP_RIGHT: 133 | face_connection_spec[edge] = mouth_draw_itr 134 | 135 | self.face_connection_spec = face_connection_spec 136 | 137 | self.pupil_landmark_spec = {468: right_iris_draw, 473: left_iris_draw} 138 | self.nose_landmark_spec = {4: nose_draw} 139 | 140 | self.draw_pupil = draw_pupil 141 | self.draw_nose = draw_nose 142 | 143 | def draw_points(self, image, landmark_list, drawing_spec, halfwidth: int = 2): 144 | """We have a custom function to draw the pupils because the mp.draw_landmarks method requires a parameter for all 145 | landmarks. 
Until our PR is merged into mediapipe, we need this separate method.""" 146 | if len(image.shape) != 3: 147 | raise ValueError("Input image must be H,W,C.") 148 | image_rows, image_cols, image_channels = image.shape 149 | if image_channels != 3: # BGR channels 150 | raise ValueError('Input image must contain three channel bgr data.') 151 | for idx, landmark in enumerate(landmark_list.landmark): 152 | if idx not in drawing_spec: 153 | continue 154 | 155 | if ( 156 | (landmark.HasField('visibility') and landmark.visibility < 0.9) or 157 | (landmark.HasField('presence') and landmark.presence < 0.5) 158 | ): 159 | continue 160 | if landmark.x >= 1.0 or landmark.x < 0 or landmark.y >= 1.0 or landmark.y < 0: 161 | continue 162 | 163 | image_x = int(image_cols * landmark.x) 164 | image_y = int(image_rows * landmark.y) 165 | 166 | draw_color = drawing_spec[idx].color 167 | image[image_y - halfwidth : image_y + halfwidth, image_x - halfwidth : image_x + halfwidth, :] = draw_color 168 | 169 | 170 | def draw_landmarks(self, image_size, keypoints, normed=False): 171 | ini_size = [512, 512] 172 | image = np.zeros([ini_size[1], ini_size[0], 3], dtype=np.uint8) 173 | if keypoints is not None: 174 | new_landmarks = landmark_pb2.NormalizedLandmarkList() 175 | for i in range(keypoints.shape[0]): 176 | landmark = new_landmarks.landmark.add() 177 | if normed: 178 | landmark.x = keypoints[i, 0] 179 | landmark.y = keypoints[i, 1] 180 | else: 181 | landmark.x = keypoints[i, 0] / image_size[0] 182 | landmark.y = keypoints[i, 1] / image_size[1] 183 | landmark.z = 1.0 184 | 185 | self.mp_drawing.draw_landmarks( 186 | image=image, 187 | landmark_list=new_landmarks, 188 | connections=self.face_connection_spec.keys(), 189 | landmark_drawing_spec=None, 190 | connection_drawing_spec=self.face_connection_spec 191 | ) 192 | 193 | # if self.draw_pupil: 194 | # self.draw_points(image, new_landmarks, self.pupil_landmark_spec, 2) 195 | 196 | # if self.draw_nose: 197 | # self.draw_points(image, 
new_landmarks, self.nose_landmark_spec, 2) 198 | 199 | if self.draw_pupil: 200 | self.draw_points(image, new_landmarks, self.pupil_landmark_spec, 1) 201 | 202 | if self.draw_nose: 203 | self.draw_points(image, new_landmarks, self.nose_landmark_spec, 2) 204 | 205 | image = cv2.resize(image, (image_size[0], image_size[1])) 206 | 207 | return image 208 | 209 | def draw_landmarks_v2(self, image_size, keypoints, normed=False): 210 | # ini_size = [512, 512] 211 | ini_size = [image_size[0], image_size[1]] 212 | image = np.zeros([ini_size[1], ini_size[0], 3], dtype=np.uint8) 213 | if keypoints is not None: 214 | new_landmarks = landmark_pb2.NormalizedLandmarkList() 215 | for i in range(keypoints.shape[0]): 216 | landmark = new_landmarks.landmark.add() 217 | if normed: 218 | landmark.x = keypoints[i, 0] 219 | landmark.y = keypoints[i, 1] 220 | else: 221 | landmark.x = keypoints[i, 0] / image_size[0] 222 | landmark.y = keypoints[i, 1] / image_size[1] 223 | landmark.z = 1.0 224 | 225 | self.mp_drawing.draw_landmarks( 226 | image=image, 227 | landmark_list=new_landmarks, 228 | connections=self.face_connection_spec.keys(), 229 | landmark_drawing_spec=None, 230 | connection_drawing_spec=self.face_connection_spec 231 | ) 232 | 233 | # if self.draw_pupil: 234 | # self.draw_points(image, new_landmarks, self.pupil_landmark_spec, 2) 235 | 236 | # if self.draw_nose: 237 | # self.draw_points(image, new_landmarks, self.nose_landmark_spec, 2) 238 | 239 | if self.draw_pupil: 240 | self.draw_points(image, new_landmarks, self.pupil_landmark_spec, 1) 241 | 242 | if self.draw_nose: 243 | self.draw_points(image, new_landmarks, self.nose_landmark_spec, 2) 244 | 245 | image = cv2.resize(image, (image_size[0], image_size[1])) 246 | 247 | return image 248 | 249 | def draw_landmarks_v3(self, image_size, resize_size, keypoints, normed=False): 250 | # ini_size = [512, 512] 251 | ini_size = [resize_size[0], resize_size[1]] 252 | image = np.zeros([ini_size[1], ini_size[0], 3], dtype=np.uint8) 253 | if 
keypoints is not None: 254 | new_landmarks = landmark_pb2.NormalizedLandmarkList() 255 | for i in range(keypoints.shape[0]): 256 | landmark = new_landmarks.landmark.add() 257 | if normed: 258 | # landmark.x = keypoints[i, 0] * resize_size[0] / image_size[0] 259 | # landmark.y = keypoints[i, 1] * resize_size[1] / image_size[1] 260 | landmark.x = keypoints[i, 0] 261 | landmark.y = keypoints[i, 1] 262 | else: 263 | landmark.x = keypoints[i, 0] / image_size[0] 264 | landmark.y = keypoints[i, 1] / image_size[1] 265 | landmark.z = 1.0 266 | 267 | self.mp_drawing.draw_landmarks( 268 | image=image, 269 | landmark_list=new_landmarks, 270 | connections=self.face_connection_spec.keys(), 271 | landmark_drawing_spec=None, 272 | connection_drawing_spec=self.face_connection_spec 273 | ) 274 | 275 | # if self.draw_pupil: 276 | # self.draw_points(image, new_landmarks, self.pupil_landmark_spec, 2) 277 | 278 | # if self.draw_nose: 279 | # self.draw_points(image, new_landmarks, self.nose_landmark_spec, 2) 280 | 281 | if self.draw_pupil: 282 | self.draw_points(image, new_landmarks, self.pupil_landmark_spec, 1) 283 | 284 | if self.draw_nose: 285 | self.draw_points(image, new_landmarks, self.nose_landmark_spec, 2) 286 | 287 | image = cv2.resize(image, (resize_size[0], resize_size[1])) 288 | 289 | return image 290 | -------------------------------------------------------------------------------- /skyreels_a1/src/media_pipe/mp_models/blaze_face_short_range.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/skyreels_a1/src/media_pipe/mp_models/blaze_face_short_range.tflite -------------------------------------------------------------------------------- /skyreels_a1/src/media_pipe/mp_models/face_landmarker_v2_with_blendshapes.task: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/skyreels_a1/src/media_pipe/mp_models/face_landmarker_v2_with_blendshapes.task -------------------------------------------------------------------------------- /skyreels_a1/src/media_pipe/mp_models/pose_landmarker_heavy.task: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-A1/52ee51a8e30798273a67b5c5f5bc1d6db360f7d7/skyreels_a1/src/media_pipe/mp_models/pose_landmarker_heavy.task -------------------------------------------------------------------------------- /skyreels_a1/src/media_pipe/mp_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import time 5 | from tqdm import tqdm 6 | import multiprocessing 7 | import glob 8 | 9 | import mediapipe as mp 10 | from mediapipe import solutions 11 | from mediapipe.framework.formats import landmark_pb2 12 | from mediapipe.tasks import python 13 | from mediapipe.tasks.python import vision 14 | from . import face_landmark 15 | 16 | CUR_DIR = os.path.dirname(__file__) 17 | 18 | 19 | class LMKExtractor(): 20 | def __init__(self, FPS=25): 21 | # Create an FaceLandmarker object. 
22 | self.mode = mp.tasks.vision.FaceDetectorOptions.running_mode.IMAGE 23 | base_options = python.BaseOptions(model_asset_path=os.path.join(CUR_DIR, 'mp_models/face_landmarker_v2_with_blendshapes.task')) 24 | base_options.delegate = mp.tasks.BaseOptions.Delegate.CPU 25 | options = vision.FaceLandmarkerOptions(base_options=base_options, 26 | running_mode=self.mode, 27 | # min_face_detection_confidence=0.3, 28 | # min_face_presence_confidence=0.3, 29 | # min_tracking_confidence=0.3, 30 | output_face_blendshapes=True, 31 | output_facial_transformation_matrixes=True, 32 | num_faces=1) 33 | self.detector = face_landmark.FaceLandmarker.create_from_options(options) 34 | self.last_ts = 0 35 | self.frame_ms = int(1000 / FPS) 36 | 37 | det_base_options = python.BaseOptions(model_asset_path=os.path.join(CUR_DIR, 'mp_models/blaze_face_short_range.tflite')) 38 | det_options = vision.FaceDetectorOptions(base_options=det_base_options) 39 | self.det_detector = vision.FaceDetector.create_from_options(det_options) 40 | 41 | 42 | def __call__(self, img): 43 | frame = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 44 | image = mp.Image(image_format=mp.ImageFormat.SRGB, data=frame) 45 | t0 = time.time() 46 | if self.mode == mp.tasks.vision.FaceDetectorOptions.running_mode.VIDEO: 47 | det_result = self.det_detector.detect(image) 48 | if len(det_result.detections) != 1: 49 | return None 50 | self.last_ts += self.frame_ms 51 | try: 52 | detection_result, mesh3d = self.detector.detect_for_video(image, timestamp_ms=self.last_ts) 53 | except: 54 | return None 55 | elif self.mode == mp.tasks.vision.FaceDetectorOptions.running_mode.IMAGE: 56 | # det_result = self.det_detector.detect(image) 57 | 58 | # if len(det_result.detections) != 1: 59 | # return None 60 | try: 61 | detection_result, mesh3d = self.detector.detect(image) 62 | except: 63 | return None 64 | 65 | 66 | bs_list = detection_result.face_blendshapes 67 | if len(bs_list) == 1: 68 | bs = bs_list[0] 69 | bs_values = [] 70 | for index in 
range(len(bs)): 71 | bs_values.append(bs[index].score) 72 | bs_values = bs_values[1:] # remove neutral 73 | trans_mat = detection_result.facial_transformation_matrixes[0] 74 | face_landmarks_list = detection_result.face_landmarks 75 | face_landmarks = face_landmarks_list[0] 76 | lmks = [] 77 | for index in range(len(face_landmarks)): 78 | x = face_landmarks[index].x 79 | y = face_landmarks[index].y 80 | z = face_landmarks[index].z 81 | lmks.append([x, y, z]) 82 | lmks = np.array(lmks) 83 | 84 | lmks3d = np.array(mesh3d.vertex_buffer) 85 | lmks3d = lmks3d.reshape(-1, 5)[:, :3] 86 | mp_tris = np.array(mesh3d.index_buffer).reshape(-1, 3) + 1 87 | 88 | return { 89 | "lmks": lmks, 90 | 'lmks3d': lmks3d, 91 | "trans_mat": trans_mat, 92 | 'faces': mp_tris, 93 | "bs": bs_values 94 | } 95 | else: 96 | # print('multiple faces in the image: {}'.format(img_path)) 97 | return None 98 | -------------------------------------------------------------------------------- /skyreels_a1/src/media_pipe/readme: -------------------------------------------------------------------------------- 1 | The landmark file defines the barycentric embedding of 105 points of the Mediapipe mesh in the surface of FLAME. 2 | In consists of three arrays: lmk_face_idx, lmk_b_coords, and landmark_indices. 
3 | - lmk_face_idx contains for every landmark the index of the FLAME triangle which each landmark is embedded into 4 | - lmk_b_coords are the barycentric weights for each vertex of the triangles 5 | - landmark_indices are the indices of the vertices of the Mediapipe mesh 6 | -------------------------------------------------------------------------------- /skyreels_a1/src/multi_fps.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | import subprocess 6 | from .frame_interpolation import batch_images_interpolation_tool 7 | 8 | def multi_fps_tool(frames, frame_inter_model, target_fps, original_fps=12): 9 | frames_np = np.array([np.array(frame) for frame in frames]) 10 | 11 | interpolation_factor = target_fps / original_fps 12 | inter_frames = math.ceil(interpolation_factor) - 1 13 | frames_tensor = torch.from_numpy(frames_np).permute(3, 0, 1, 2).unsqueeze(0) / 255.0 14 | 15 | video = batch_images_interpolation_tool(frames_tensor, frame_inter_model, inter_frames=inter_frames) 16 | video = video.squeeze(0) 17 | video = video.permute(1, 2, 3, 0) 18 | video = (video * 255).clamp(0, 255).to(torch.uint8).cpu().numpy() 19 | out_frames = [Image.fromarray(frame) for frame in video] 20 | 21 | if not interpolation_factor.is_integer(): 22 | print(f"Warning: target fps {target_fps} is not mulitple of 12, which may cause unstable video rate.") 23 | out_frames = adjust_video_fps(out_frames, target_fps, int(target_fps//12+1)*12) 24 | 25 | return out_frames 26 | 27 | def adjust_video_fps(frames, target_fps, fps): 28 | video_length = len(frames) 29 | 30 | duration = video_length / fps 31 | target_times = np.arange(0, duration, 1/target_fps) 32 | frame_indices = (target_times * fps).astype(np.int32) 33 | 34 | frame_indices = frame_indices[frame_indices < video_length] 35 | new_frames = [] 36 | for idx in frame_indices: 37 | new_frames.append(frames[idx]) 38 | 39 | 
return new_frames -------------------------------------------------------------------------------- /skyreels_a1/src/smirk_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | import timm 5 | 6 | 7 | def create_backbone(backbone_name, pretrained=True): 8 | backbone = timm.create_model(backbone_name, 9 | pretrained=pretrained, 10 | features_only=True) 11 | feature_dim = backbone.feature_info[-1]['num_chs'] 12 | return backbone, feature_dim 13 | 14 | class PoseEncoder(nn.Module): 15 | def __init__(self) -> None: 16 | super().__init__() 17 | 18 | self.encoder, feature_dim = create_backbone('tf_mobilenetv3_small_minimal_100') 19 | 20 | self.pose_cam_layers = nn.Sequential( 21 | nn.Linear(feature_dim, 6) 22 | ) 23 | 24 | self.init_weights() 25 | 26 | def init_weights(self): 27 | self.pose_cam_layers[-1].weight.data *= 0.001 28 | self.pose_cam_layers[-1].bias.data *= 0.001 29 | 30 | self.pose_cam_layers[-1].weight.data[3] = 0 31 | self.pose_cam_layers[-1].bias.data[3] = 7 32 | 33 | 34 | def forward(self, img): 35 | features = self.encoder(img)[-1] 36 | 37 | features = F.adaptive_avg_pool2d(features, (1, 1)).squeeze(-1).squeeze(-1) 38 | 39 | outputs = {} 40 | 41 | pose_cam = self.pose_cam_layers(features).reshape(img.size(0), -1) 42 | outputs['pose_params'] = pose_cam[...,:3] 43 | # import pdb;pdb.set_trace() 44 | outputs['cam'] = pose_cam[...,3:] 45 | 46 | return outputs 47 | 48 | 49 | class ShapeEncoder(nn.Module): 50 | def __init__(self, n_shape=300) -> None: 51 | super().__init__() 52 | 53 | self.encoder, feature_dim = create_backbone('tf_mobilenetv3_large_minimal_100') 54 | 55 | self.shape_layers = nn.Sequential( 56 | nn.Linear(feature_dim, n_shape) 57 | ) 58 | 59 | self.init_weights() 60 | 61 | 62 | def init_weights(self): 63 | self.shape_layers[-1].weight.data *= 0 64 | self.shape_layers[-1].bias.data *= 0 65 | 66 | 67 | def forward(self, img): 
68 | features = self.encoder(img)[-1] 69 | 70 | features = F.adaptive_avg_pool2d(features, (1, 1)).squeeze(-1).squeeze(-1) 71 | 72 | parameters = self.shape_layers(features).reshape(img.size(0), -1) 73 | 74 | return {'shape_params': parameters} 75 | 76 | 77 | class ExpressionEncoder(nn.Module): 78 | def __init__(self, n_exp=50) -> None: 79 | super().__init__() 80 | 81 | self.encoder, feature_dim = create_backbone('tf_mobilenetv3_large_minimal_100') 82 | 83 | self.expression_layers = nn.Sequential( 84 | nn.Linear(feature_dim, n_exp+2+3) # num expressions + jaw + eyelid 85 | ) 86 | 87 | self.n_exp = n_exp 88 | self.init_weights() 89 | 90 | 91 | def init_weights(self): 92 | self.expression_layers[-1].weight.data *= 0.1 93 | self.expression_layers[-1].bias.data *= 0.1 94 | 95 | 96 | def forward(self, img): 97 | features = self.encoder(img)[-1] 98 | 99 | features = F.adaptive_avg_pool2d(features, (1, 1)).squeeze(-1).squeeze(-1) 100 | 101 | 102 | parameters = self.expression_layers(features).reshape(img.size(0), -1) 103 | 104 | outputs = {} 105 | 106 | outputs['expression_params'] = parameters[...,:self.n_exp] 107 | outputs['eyelid_params'] = torch.clamp(parameters[...,self.n_exp:self.n_exp+2], 0, 1) 108 | outputs['jaw_params'] = torch.cat([F.relu(parameters[...,self.n_exp+2].unsqueeze(-1)), 109 | torch.clamp(parameters[...,self.n_exp+3:self.n_exp+5], -.2, .2)], dim=-1) 110 | 111 | return outputs 112 | 113 | 114 | class SmirkEncoder(nn.Module): 115 | def __init__(self, n_exp=50, n_shape=300) -> None: 116 | super().__init__() 117 | 118 | self.pose_encoder = PoseEncoder() 119 | 120 | self.shape_encoder = ShapeEncoder(n_shape=n_shape) 121 | 122 | self.expression_encoder = ExpressionEncoder(n_exp=n_exp) 123 | 124 | def forward(self, img): 125 | pose_outputs = self.pose_encoder(img) 126 | shape_outputs = self.shape_encoder(img) 127 | expression_outputs = self.expression_encoder(img) 128 | 129 | outputs = {} 130 | outputs.update(pose_outputs) 131 | 
outputs.update(shape_outputs) 132 | outputs.update(expression_outputs) 133 | 134 | return outputs 135 | -------------------------------------------------------------------------------- /skyreels_a1/src/utils/mediapipe_utils.py: -------------------------------------------------------------------------------- 1 | import mediapipe as mp 2 | from mediapipe.tasks import python 3 | from mediapipe.tasks.python import vision 4 | import cv2 5 | import numpy as np 6 | import os 7 | 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn.functional as F 12 | import os 13 | import cv2 14 | 15 | # borrowed from https://github.com/daniilidis-group/neural_renderer/blob/master/neural_renderer/vertices_to_faces.py 16 | def face_vertices(vertices, faces): 17 | """ 18 | :param vertices: [batch size, number of vertices, 3] 19 | :param faces: [batch size, number of faces, 3] 20 | :return: [batch size, number of faces, 3, 3] 21 | """ 22 | assert (vertices.ndimension() == 3) 23 | assert (faces.ndimension() == 3) 24 | assert (vertices.shape[0] == faces.shape[0]) 25 | assert (vertices.shape[2] == 3) 26 | assert (faces.shape[2] == 3) 27 | 28 | bs, nv = vertices.shape[:2] 29 | bs, nf = faces.shape[:2] 30 | device = vertices.device 31 | faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None] 32 | vertices = vertices.reshape((bs * nv, 3)) 33 | # pytorch only supports long and byte tensors for indexing 34 | return vertices[faces.long()] 35 | 36 | def vertex_normals(vertices, faces): 37 | """ 38 | :param vertices: [batch size, number of vertices, 3] 39 | :param faces: [batch size, number of faces, 3] 40 | :return: [batch size, number of vertices, 3] 41 | """ 42 | assert (vertices.ndimension() == 3) 43 | assert (faces.ndimension() == 3) 44 | assert (vertices.shape[0] == faces.shape[0]) 45 | assert (vertices.shape[2] == 3) 46 | assert (faces.shape[2] == 3) 47 | bs, nv = vertices.shape[:2] 48 | bs, nf = faces.shape[:2] 49 | device = vertices.device 50 | 
normals = torch.zeros(bs * nv, 3).to(device) 51 | 52 | faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None] # expanded faces 53 | vertices_faces = vertices.reshape((bs * nv, 3))[faces.long()] 54 | 55 | faces = faces.reshape(-1, 3) 56 | vertices_faces = vertices_faces.reshape(-1, 3, 3) 57 | 58 | normals.index_add_(0, faces[:, 1].long(), 59 | torch.cross(vertices_faces[:, 2] - vertices_faces[:, 1], vertices_faces[:, 0] - vertices_faces[:, 1])) 60 | normals.index_add_(0, faces[:, 2].long(), 61 | torch.cross(vertices_faces[:, 0] - vertices_faces[:, 2], vertices_faces[:, 1] - vertices_faces[:, 2])) 62 | normals.index_add_(0, faces[:, 0].long(), 63 | torch.cross(vertices_faces[:, 1] - vertices_faces[:, 0], vertices_faces[:, 2] - vertices_faces[:, 0])) 64 | 65 | normals = F.normalize(normals, eps=1e-6, dim=1) 66 | normals = normals.reshape((bs, nv, 3)) 67 | # pytorch only supports long and byte tensors for indexing 68 | return normals 69 | 70 | def batch_orth_proj(X, camera): 71 | ''' orthgraphic projection 72 | X: 3d vertices, [bz, n_point, 3] 73 | camera: scale and translation, [bz, 3], [scale, tx, ty] 74 | ''' 75 | #print('--------') 76 | #print(camera[0, 1:].abs()) 77 | #print(X[0].abs().mean(0)) 78 | 79 | camera = camera.clone().view(-1, 1, 3) 80 | X_trans = X[:, :, :2] + camera[:, :, 1:] 81 | #print(X_trans[0].abs().mean(0)) 82 | X_trans = torch.cat([X_trans, X[:,:,2:]], 2) 83 | Xn = (camera[:, :, 0:1] * X_trans) 84 | return Xn 85 | 86 | class MP_2_FLAME(): 87 | """ 88 | Convert Mediapipe 52 blendshape scores to FLAME's coefficients 89 | """ 90 | def __init__(self, mappings_path): 91 | self.bs2exp = np.load(os.path.join(mappings_path, 'bs2exp.npy')) 92 | self.bs2pose = np.load(os.path.join(mappings_path, 'bs2pose.npy')) 93 | self.bs2eye = np.load(os.path.join(mappings_path, 'bs2eye.npy')) 94 | 95 | def convert(self, blendshape_scores : np.array): 96 | # blendshape_scores: [N, 52] 97 | 98 | # Calculate expression, pose, and eye_pose 
using the mappings 99 | exp = blendshape_scores @ self.bs2exp 100 | pose = blendshape_scores @ self.bs2pose 101 | pose[0, :3] = 0 # we do not support head rotation yet 102 | eye_pose = blendshape_scores @ self.bs2eye 103 | 104 | return exp, pose, eye_pose 105 | 106 | class MediaPipeUtils: 107 | def __init__(self, model_asset_path='pretrained_models/mediapipe/face_landmarker.task', mappings_path='pretrained_models/mediapipe/'): 108 | base_options = python.BaseOptions(model_asset_path=model_asset_path) 109 | options = vision.FaceLandmarkerOptions(base_options=base_options, 110 | output_face_blendshapes=True, 111 | output_facial_transformation_matrixes=True, 112 | num_faces=1, 113 | min_face_detection_confidence=0.1, 114 | min_face_presence_confidence=0.1) 115 | self.detector = vision.FaceLandmarker.create_from_options(options) 116 | self.mp2flame = MP_2_FLAME(mappings_path=mappings_path) 117 | 118 | def run_mediapipe(self, image): 119 | image_numpy = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 120 | image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image_numpy) 121 | detection_result = self.detector.detect(image) 122 | 123 | if len(detection_result.face_landmarks) == 0: 124 | print('No face detected') 125 | return None 126 | 127 | blend_scores = detection_result.face_blendshapes[0] 128 | blend_scores = np.array(list(map(lambda l: l.score, blend_scores)), dtype=np.float32).reshape(1, 52) 129 | exp, pose, eye_pose = self.mp2flame.convert(blendshape_scores=blend_scores) 130 | 131 | face_landmarks = detection_result.face_landmarks[0] 132 | face_landmarks_numpy = np.zeros((478, 3)) 133 | 134 | for i, landmark in enumerate(face_landmarks): 135 | face_landmarks_numpy[i] = [landmark.x * image.width, landmark.y * image.height, landmark.z] 136 | 137 | return face_landmarks_numpy, exp, pose, eye_pose 138 | --------------------------------------------------------------------------------