├── LICENSE ├── README.md ├── asset ├── Aragaki.mp4 ├── cxk.mp4 ├── jijin.mp4 ├── kara.mp4 ├── lyl.mp4 ├── num18.mp4 ├── pipeline.png ├── solo.mp4 ├── zhiji_logo.png └── zl.mp4 ├── configs ├── inference │ ├── audio │ │ └── lyl.wav │ ├── head_pose_temp │ │ ├── pose_ref_video.mp4 │ │ └── pose_temp.npy │ ├── inference_audio.yaml │ ├── inference_v1.yaml │ ├── inference_v2.yaml │ ├── pose_videos │ │ └── solo_pose.mp4 │ ├── ref_images │ │ ├── Aragaki.png │ │ ├── lyl.png │ │ └── solo.png │ └── video │ │ └── Aragaki_song.mp4 ├── prompts │ ├── animation.yaml │ ├── animation_audio.yaml │ ├── animation_facereenac.yaml │ └── test_cases.py └── train │ ├── stage1.yaml │ └── stage2.yaml ├── pretrained_model └── Put pre-trained weights here.txt ├── requirements.txt ├── scripts ├── app.py ├── audio2vid.py ├── generate_ref_pose.py ├── pose2vid.py ├── prepare_video.py ├── preprocess_dataset.py ├── vid2pose.py └── vid2vid.py ├── src ├── audio_models │ ├── mish.py │ ├── model.py │ ├── pose_model.py │ ├── torch_utils.py │ └── wav2vec2.py ├── dataset │ └── dataset_face.py ├── models │ ├── attention.py │ ├── motion_module.py │ ├── mutual_self_attention.py │ ├── pose_guider.py │ ├── resnet.py │ ├── transformer_2d.py │ ├── transformer_3d.py │ ├── unet_2d_blocks.py │ ├── unet_2d_condition.py │ ├── unet_3d.py │ └── unet_3d_blocks.py ├── pipelines │ ├── context.py │ ├── pipeline_pose2img.py │ ├── pipeline_pose2vid.py │ ├── pipeline_pose2vid_long.py │ └── utils.py └── utils │ ├── audio_util.py │ ├── draw_util.py │ ├── face_landmark.py │ ├── frame_interpolation.py │ ├── mp_models │ ├── blaze_face_short_range.tflite │ ├── face_landmarker_v2_with_blendshapes.task │ └── pose_landmarker_heavy.task │ ├── mp_utils.py │ ├── pose_util.py │ └── util.py ├── train_stage_1.py └── train_stage_2.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AniPortrait 2 | 3 | **AniPortrait: Audio-Driven Synthesis of Photorealistic Portrait Animations** 4 | 5 | Author: Huawei Wei, Zejun Yang, Zhisheng Wang 6 | 7 | Organization: Tencent Games Zhiji, Tencent 8 | 9 | ![zhiji_logo](asset/zhiji_logo.png) 10 | 11 | Here we propose AniPortrait, a novel framework for generating high-quality animation driven by 12 | audio and a reference portrait image. 
You can also provide a video to achieve face reenactment. 13 | 14 | 15 | 16 | 17 | 18 | ## Pipeline 19 | 20 | ![pipeline](asset/pipeline.png) 21 | 22 | ## Updates / TODO List 23 | 24 | - ✅ [2024/03/27] Now our paper is available on arXiv. 25 | 26 | - ✅ [2024/03/27] Update the code to generate pose_temp.npy for head pose control. 27 | 28 | - ✅ [2024/04/02] Update a new pose retargeting strategy for vid2vid. Now we support substantial pose differences between the ref_image and the source video. 29 | 30 | - ✅ [2024/04/03] We release our Gradio [demo](https://huggingface.co/spaces/ZJYang/AniPortrait_official) on HuggingFace Spaces (thanks to the HF team for their free GPU support)! 31 | 32 | - ✅ [2024/04/07] Update a frame interpolation module to accelerate the inference process. Now you can add `-acc` to the inference commands for faster video generation. 33 | 34 | - ✅ [2024/04/21] We have released the audio2pose model and [pre-trained weight](https://huggingface.co/ZJYang/AniPortrait/tree/main) for audio2video. Please update the code and download the weight file to try it out. 35 | 36 | ## Various Generated Videos 37 | 38 | ### Self driven 39 | 40 | 41 | 42 | 45 | 48 | 49 |
43 | 44 | 46 | 47 |
50 | 51 | ### Face reenactment 52 | 53 | 54 | 55 | 58 | 61 | 62 |
56 | 57 | 59 | 60 |
63 | 64 | Video Source: [鹿火CAVY from bilibili](https://www.bilibili.com/video/BV1H4421F7dE/?spm_id_from=333.337.search-card.all.click) 65 | 66 | ### Audio driven 67 | 68 | 69 | 70 | 73 | 76 | 77 | 78 | 79 | 82 | 85 | 86 |
71 | 72 | 74 | 75 |
80 | 81 | 83 | 84 |
87 | 88 | ## Installation 89 | 90 | ### Build environment 91 | 92 | We recommend Python >= 3.10 and CUDA 11.7. Then build the environment as follows: 93 | 94 | ```shell 95 | pip install -r requirements.txt 96 | ``` 97 | 98 | ### Download weights 99 | 100 | All the weights should be placed under the `./pretrained_weights` directory. You can download the weights manually as follows: 101 | 102 | 1. Download our trained [weights](https://huggingface.co/ZJYang/AniPortrait/tree/main), which include the following parts: `denoising_unet.pth`, `reference_unet.pth`, `pose_guider.pth`, `motion_module.pth`, `audio2mesh.pt`, `audio2pose.pt` and `film_net_fp16.pt`. You can also download from [wisemodel](https://wisemodel.cn/models/zjyang8510/AniPortrait). 103 | 104 | 2. Download the pretrained weights of the base models and other components: 105 | - [StableDiffusion V1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5) 106 | - [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse) 107 | - [image_encoder](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/tree/main/image_encoder) 108 | - [wav2vec2-base-960h](https://huggingface.co/facebook/wav2vec2-base-960h) 109 | 110 | Finally, these weights should be organized as follows: 111 | 112 | ```text 113 | ./pretrained_weights/ 114 | |-- image_encoder 115 | | |-- config.json 116 | | `-- pytorch_model.bin 117 | |-- sd-vae-ft-mse 118 | | |-- config.json 119 | | |-- diffusion_pytorch_model.bin 120 | | `-- diffusion_pytorch_model.safetensors 121 | |-- stable-diffusion-v1-5 122 | | |-- feature_extractor 123 | | | `-- preprocessor_config.json 124 | | |-- model_index.json 125 | | |-- unet 126 | | | |-- config.json 127 | | | `-- diffusion_pytorch_model.bin 128 | | `-- v1-inference.yaml 129 | |-- wav2vec2-base-960h 130 | | |-- config.json 131 | | |-- feature_extractor_config.json 132 | | |-- preprocessor_config.json 133 | | |-- pytorch_model.bin 134 | | |-- README.md 135 | | |-- special_tokens_map.json 136 | | |-- tokenizer_config.json 137 | | `-- vocab.json 138 | |-- audio2mesh.pt 139 | |-- audio2pose.pt 140 | |-- denoising_unet.pth 141 | |-- film_net_fp16.pt 142 | |-- motion_module.pth 143 | |-- pose_guider.pth 144 | `-- reference_unet.pth 145 | ``` 146 | 147 | Note: If you have already installed some of the pretrained models, such as `StableDiffusion V1.5`, you can specify their paths in the config file (e.g. `./configs/prompts/animation.yaml`). 148 | 149 | 150 | ## Gradio Web UI 151 | 152 | You can try out our web demo with the following command. We also provide an online demo on Hugging Face Spaces. 153 | 154 | 155 | ```shell 156 | python -m scripts.app 157 | ``` 158 | 159 | ## Inference 160 | 161 | Note that you can set `-L` to the desired number of generated frames in the command, for example, `-L 300`. 162 | 163 | **Acceleration method**: If it takes a long time to generate a video, you can download [film_net_fp16.pt](https://huggingface.co/ZJYang/AniPortrait/tree/main) and put it under the `./pretrained_weights` directory. Then add `-acc` to the command. 164 | 165 | Here are the CLI commands for running the inference scripts: 166 | 167 | ### Self driven 168 | 169 | ```shell 170 | python -m scripts.pose2vid --config ./configs/prompts/animation.yaml -W 512 -H 512 -acc 171 | ``` 172 | 173 | You can refer to the format of `animation.yaml` to add your own reference images or pose videos.
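For reference, a minimal `test_cases` entry in `animation.yaml` pairing a reference image with a pose video looks like the sketch below (the two file paths are placeholders for your own files):

```yaml
test_cases:
  # reference portrait image -> list of pose videos used to drive it
  "./configs/inference/ref_images/my_portrait.png":
    - "./configs/inference/pose_videos/my_pose.mp4"
```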
To convert the raw video into a pose video (keypoint sequence), you can run the following command: 174 | 175 | ```shell 176 | python -m scripts.vid2pose --video_path pose_video_path.mp4 177 | ``` 178 | 179 | ### Face reenactment 180 | 181 | ```shell 182 | python -m scripts.vid2vid --config ./configs/prompts/animation_facereenac.yaml -W 512 -H 512 -acc 183 | ``` 184 | 185 | Add the source face videos and reference images in `animation_facereenac.yaml`. 186 | 187 | ### Audio driven 188 | 189 | ```shell 190 | python -m scripts.audio2vid --config ./configs/prompts/animation_audio.yaml -W 512 -H 512 -acc 191 | ``` 192 | 193 | Add the audio files and reference images in `animation_audio.yaml`. 194 | 195 | Deleting `pose_temp` in `./configs/prompts/animation_audio.yaml` enables the audio2pose model. 196 | 197 | You can also use this command to generate a pose_temp.npy for head pose control: 198 | 199 | ```shell 200 | python -m scripts.generate_ref_pose --ref_video ./configs/inference/head_pose_temp/pose_ref_video.mp4 --save_path ./configs/inference/head_pose_temp/pose.npy 201 | ``` 202 | 203 | ## Training 204 | 205 | ### Data preparation 206 | Download [VFHQ](https://liangbinxie.github.io/projects/vfhq/) and [CelebV-HQ](https://github.com/CelebV-HQ/CelebV-HQ). 207 | 208 | Extract keypoints from the raw videos and write the training JSON file (here is an example of processing VFHQ): 209 | 210 | ```shell 211 | python -m scripts.preprocess_dataset --input_dir VFHQ_PATH --output_dir SAVE_PATH --training_json JSON_PATH 212 | ``` 213 | 214 | Update these lines in the training config file: 215 | 216 | ```yaml 217 | data: 218 | json_path: JSON_PATH 219 | ``` 220 | 221 | ### Stage1 222 | 223 | Run command: 224 | 225 | ```shell 226 | accelerate launch train_stage_1.py --config ./configs/train/stage1.yaml 227 | ``` 228 | 229 | ### Stage2 230 | 231 | Put the pretrained motion module weights `mm_sd_v15_v2.ckpt` ([download link](https://huggingface.co/guoyww/animatediff/blob/main/mm_sd_v15_v2.ckpt)) under `./pretrained_weights`. 232 | 233 | Specify the stage1 training weights in the config file `stage2.yaml`, for example: 234 | 235 | ```yaml 236 | stage1_ckpt_dir: './exp_output/stage1' 237 | stage1_ckpt_step: 30000 238 | ``` 239 | 240 | Run command: 241 | 242 | ```shell 243 | accelerate launch train_stage_2.py --config ./configs/train/stage2.yaml 244 | ``` 245 | 246 | ## Acknowledgements 247 | 248 | We first thank the authors of [EMO](https://github.com/HumanAIGC/EMO); some of the images and audio clips in our demos are from EMO. Additionally, we would like to thank the contributors to the [Moore-AnimateAnyone](https://github.com/MooreThreads/Moore-AnimateAnyone), [magic-animate](https://github.com/magic-research/magic-animate), [animatediff](https://github.com/guoyww/AnimateDiff) and [Open-AnimateAnyone](https://github.com/guoqincode/Open-AnimateAnyone) repositories, for their open research and exploration.
249 | 250 | ## Citation 251 | 252 | ``` 253 | @misc{wei2024aniportrait, 254 | title={AniPortrait: Audio-Driven Synthesis of Photorealistic Portrait Animations}, 255 | author={Huawei Wei and Zejun Yang and Zhisheng Wang}, 256 | year={2024}, 257 | eprint={2403.17694}, 258 | archivePrefix={arXiv}, 259 | primaryClass={cs.CV} 260 | } 261 | ``` -------------------------------------------------------------------------------- /asset/Aragaki.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/asset/Aragaki.mp4 -------------------------------------------------------------------------------- /asset/cxk.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/asset/cxk.mp4 -------------------------------------------------------------------------------- /asset/jijin.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/asset/jijin.mp4 -------------------------------------------------------------------------------- /asset/kara.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/asset/kara.mp4 -------------------------------------------------------------------------------- /asset/lyl.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/asset/lyl.mp4 -------------------------------------------------------------------------------- /asset/num18.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/asset/num18.mp4 -------------------------------------------------------------------------------- /asset/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/asset/pipeline.png -------------------------------------------------------------------------------- /asset/solo.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/asset/solo.mp4 -------------------------------------------------------------------------------- /asset/zhiji_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/asset/zhiji_logo.png -------------------------------------------------------------------------------- /asset/zl.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/asset/zl.mp4 -------------------------------------------------------------------------------- /configs/inference/audio/lyl.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/configs/inference/audio/lyl.wav -------------------------------------------------------------------------------- /configs/inference/head_pose_temp/pose_ref_video.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/configs/inference/head_pose_temp/pose_ref_video.mp4 -------------------------------------------------------------------------------- /configs/inference/head_pose_temp/pose_temp.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/configs/inference/head_pose_temp/pose_temp.npy -------------------------------------------------------------------------------- /configs/inference/inference_audio.yaml: -------------------------------------------------------------------------------- 1 | a2m_model: 2 | out_dim: 1404 3 | latent_dim: 512 4 | model_path: ./pretrained_model/wav2vec2-base-960h 5 | only_last_fetures: True 6 | from_pretrained: True 7 | 8 | a2p_model: 9 | out_dim: 6 10 | latent_dim: 512 11 | model_path: ./pretrained_model/wav2vec2-base-960h 12 | only_last_fetures: True 13 | from_pretrained: True 14 | 15 | pretrained_model: 16 | a2m_ckpt: ./pretrained_model/audio2mesh.pt 17 | a2p_ckpt: ./pretrained_model/audio2pose.pt 18 | -------------------------------------------------------------------------------- /configs/inference/inference_v1.yaml: -------------------------------------------------------------------------------- 1 | unet_additional_kwargs: 2 | unet_use_cross_frame_attention: false 3 | unet_use_temporal_attention: false 4 | use_motion_module: true 5 | motion_module_resolutions: [1,2,4,8] 6 | motion_module_mid_block: false 7 | motion_module_decoder_only: false 8 | motion_module_type: "Vanilla" 9 | 10 | motion_module_kwargs: 11 | num_attention_heads: 8 12 | num_transformer_block: 1 13 | attention_block_types: [ "Temporal_Self", "Temporal_Self" ] 14 | temporal_position_encoding: true 15 | temporal_position_encoding_max_len: 24 16 | temporal_attention_dim_div: 1 17 | 18 | noise_scheduler_kwargs: 19 | beta_start: 0.00085 20 | beta_end: 0.012 21 | beta_schedule: "linear" 22 | steps_offset: 1 23 | clip_sample: False -------------------------------------------------------------------------------- /configs/inference/inference_v2.yaml: -------------------------------------------------------------------------------- 1 | unet_additional_kwargs: 2 | use_inflated_groupnorm: true 3 | unet_use_cross_frame_attention: false 4 | unet_use_temporal_attention: false 5 | use_motion_module: true 6 | motion_module_resolutions: 7 | - 1 8 | - 2 9 | - 4 10 | - 8 11 | motion_module_mid_block: true 12 | motion_module_decoder_only: false 13 | motion_module_type: Vanilla 14 | motion_module_kwargs: 15 | num_attention_heads: 8 16 | num_transformer_block: 1 17 | attention_block_types: 18 | - Temporal_Self 19 | - Temporal_Self 20 | temporal_position_encoding: true 21 | temporal_position_encoding_max_len: 32 22 | temporal_attention_dim_div: 1 23 | 24 | noise_scheduler_kwargs: 25 | beta_start: 0.00085 26 | beta_end: 0.012 27 | beta_schedule: "linear" 28 | clip_sample: false 29 | steps_offset: 1 30 | ### Zero-SNR params 31 | prediction_type: "v_prediction" 32 | rescale_betas_zero_snr: True 33 | timestep_spacing: "trailing" 34 | 35 | sampler: DDIM 
-------------------------------------------------------------------------------- /configs/inference/pose_videos/solo_pose.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/configs/inference/pose_videos/solo_pose.mp4 -------------------------------------------------------------------------------- /configs/inference/ref_images/Aragaki.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/configs/inference/ref_images/Aragaki.png -------------------------------------------------------------------------------- /configs/inference/ref_images/lyl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/configs/inference/ref_images/lyl.png -------------------------------------------------------------------------------- /configs/inference/ref_images/solo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/configs/inference/ref_images/solo.png -------------------------------------------------------------------------------- /configs/inference/video/Aragaki_song.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/configs/inference/video/Aragaki_song.mp4 -------------------------------------------------------------------------------- /configs/prompts/animation.yaml: -------------------------------------------------------------------------------- 1 | pretrained_base_model_path: './pretrained_model/stable-diffusion-v1-5' 2 | pretrained_vae_path: './pretrained_model/sd-vae-ft-mse' 3 | image_encoder_path: './pretrained_model/image_encoder' 4 | 5 | denoising_unet_path: "./pretrained_model/denoising_unet.pth" 6 | reference_unet_path: "./pretrained_model/reference_unet.pth" 7 | pose_guider_path: "./pretrained_model/pose_guider.pth" 8 | motion_module_path: "./pretrained_model/motion_module.pth" 9 | 10 | inference_config: "./configs/inference/inference_v2.yaml" 11 | weight_dtype: 'fp16' 12 | 13 | test_cases: 14 | "./configs/inference/ref_images/solo.png": 15 | - "./configs/inference/pose_videos/solo_pose.mp4" 16 | -------------------------------------------------------------------------------- /configs/prompts/animation_audio.yaml: -------------------------------------------------------------------------------- 1 | pretrained_base_model_path: './pretrained_model/stable-diffusion-v1-5' 2 | pretrained_vae_path: './pretrained_model/sd-vae-ft-mse' 3 | image_encoder_path: './pretrained_model/image_encoder' 4 | 5 | denoising_unet_path: "./pretrained_model/denoising_unet.pth" 6 | reference_unet_path: "./pretrained_model/reference_unet.pth" 7 | pose_guider_path: "./pretrained_model/pose_guider.pth" 8 | motion_module_path: "./pretrained_model/motion_module.pth" 9 | 10 | audio_inference_config: "./configs/inference/inference_audio.yaml" 11 | inference_config: "./configs/inference/inference_v2.yaml" 12 | weight_dtype: 'fp16' 13 | 14 | # path of your custom head pose template 15 | # pose_temp: "./configs/inference/head_pose_temp/pose_temp.npy" 16 | 17 | test_cases: 18 | 
"./configs/inference/ref_images/lyl.png": 19 | - "./configs/inference/audio/lyl.wav" 20 | -------------------------------------------------------------------------------- /configs/prompts/animation_facereenac.yaml: -------------------------------------------------------------------------------- 1 | pretrained_base_model_path: './pretrained_model/stable-diffusion-v1-5' 2 | pretrained_vae_path: './pretrained_model/sd-vae-ft-mse' 3 | image_encoder_path: './pretrained_model/image_encoder' 4 | 5 | denoising_unet_path: "./pretrained_model/denoising_unet.pth" 6 | reference_unet_path: "./pretrained_model/reference_unet.pth" 7 | pose_guider_path: "./pretrained_model/pose_guider.pth" 8 | motion_module_path: "./pretrained_model/motion_module.pth" 9 | 10 | inference_config: "./configs/inference/inference_v2.yaml" 11 | weight_dtype: 'fp16' 12 | 13 | test_cases: 14 | "./configs/inference/ref_images/Aragaki.png": 15 | - "./configs/inference/video/Aragaki_song.mp4" 16 | -------------------------------------------------------------------------------- /configs/prompts/test_cases.py: -------------------------------------------------------------------------------- 1 | TestCasesDict = { 2 | 0: [ 3 | { 4 | "./configs/inference/ref_images/Aragaki.png": [ 5 | "./configs/inference/pose_videos/Aragaki_pose.mp4", 6 | "./configs/inference/pose_videos/solo_pose.mp4", 7 | ] 8 | }, 9 | { 10 | "./configs/inference/ref_images/solo.png": [ 11 | "./configs/inference/pose_videos/solo_pose.mp4", 12 | "./configs/inference/pose_videos/Aragaki_pose.mp4", 13 | ] 14 | }, 15 | ], 16 | } 17 | -------------------------------------------------------------------------------- /configs/train/stage1.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | json_path: "/data/VFHQ/training_data.json" 3 | sample_size: [512, 512] 4 | sample_stride: 4 5 | sample_n_frames: 16 6 | 7 | solver: 8 | gradient_accumulation_steps: 1 9 | mixed_precision: 'fp16' 10 | enable_xformers_memory_efficient_attention: True 11 | gradient_checkpointing: False 12 | max_train_steps: 300000 13 | max_grad_norm: 1.0 14 | # lr 15 | learning_rate: 1.0e-5 16 | scale_lr: False 17 | lr_warmup_steps: 1 18 | lr_scheduler: 'constant' 19 | 20 | # optimizer 21 | use_8bit_adam: True 22 | adam_beta1: 0.9 23 | adam_beta2: 0.999 24 | adam_weight_decay: 1.0e-2 25 | adam_epsilon: 1.0e-8 26 | 27 | val: 28 | validation_steps: 500 29 | validation_steps_tuple: [1, 10, 50, 100, 200, 500] 30 | 31 | 32 | noise_scheduler_kwargs: 33 | num_train_timesteps: 1000 34 | beta_start: 0.00085 35 | beta_end: 0.012 36 | beta_schedule: "scaled_linear" 37 | steps_offset: 1 38 | clip_sample: false 39 | 40 | base_model_path: './pretrained_model/stable-diffusion-v1-5' 41 | vae_model_path: './pretrained_model/sd-vae-ft-mse' 42 | image_encoder_path: './pretrained_model/image_encoder' 43 | controlnet_openpose_path: '' 44 | 45 | train_bs: 2 46 | 47 | weight_dtype: 'fp16' # [fp16, fp32] 48 | uncond_ratio: 0.1 49 | noise_offset: 0.05 50 | snr_gamma: 5.0 51 | enable_zero_snr: True 52 | pose_guider_pretrain: True 53 | 54 | seed: 12580 55 | resume_from_checkpoint: '' 56 | checkpointing_steps: 2000 57 | save_model_epoch_interval: 5 58 | exp_name: 'stage1' 59 | output_dir: './exp_output' -------------------------------------------------------------------------------- /configs/train/stage2.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | json_path: "/data/VFHQ/training_data.json" 3 | sample_size: [512, 512] 4 | sample_stride: 
1 5 | sample_n_frames: 16 6 | sample_stride_aug: True 7 | 8 | 9 | solver: 10 | gradient_accumulation_steps: 1 11 | mixed_precision: 'fp16' 12 | enable_xformers_memory_efficient_attention: True 13 | gradient_checkpointing: True 14 | max_train_steps: 40000 15 | max_grad_norm: 1.0 16 | # lr 17 | learning_rate: 1e-5 18 | scale_lr: False 19 | lr_warmup_steps: 1 20 | lr_scheduler: 'constant' 21 | 22 | # optimizer 23 | use_8bit_adam: True 24 | adam_beta1: 0.9 25 | adam_beta2: 0.999 26 | adam_weight_decay: 1.0e-2 27 | adam_epsilon: 1.0e-8 28 | 29 | val: 30 | validation_steps: 500 31 | validation_steps_tuple: [1, 10, 20, 50, 80] 32 | 33 | 34 | noise_scheduler_kwargs: 35 | num_train_timesteps: 1000 36 | beta_start: 0.00085 37 | beta_end: 0.012 38 | beta_schedule: "linear" 39 | steps_offset: 1 40 | clip_sample: false 41 | 42 | base_model_path: './pretrained_model/stable-diffusion-v1-5' 43 | vae_model_path: './pretrained_model/sd-vae-ft-mse' 44 | image_encoder_path: './pretrained_model/image_encoder' 45 | mm_path: './pretrained_model/mm_sd_v15_v2.ckpt' 46 | 47 | train_bs: 1 48 | 49 | weight_dtype: 'fp16' # [fp16, fp32] 50 | uncond_ratio: 0.1 51 | noise_offset: 0.05 52 | snr_gamma: 5.0 53 | enable_zero_snr: True 54 | stage1_ckpt_dir: './exp_output/stage1' 55 | stage1_ckpt_step: 300000 56 | 57 | seed: 12580 58 | resume_from_checkpoint: '' 59 | checkpointing_steps: 2000 60 | exp_name: 'stage2' 61 | output_dir: './exp_output' -------------------------------------------------------------------------------- /pretrained_model/Put pre-trained weights here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/pretrained_model/Put pre-trained weights here.txt -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.21.0 2 | av==11.0.0 3 | clip @ https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip#sha256=b5842c25da441d6c581b53a5c60e0c2127ebafe0f746f8e15561a006c6c3be6a 4 | decord==0.6.0 5 | diffusers==0.24.0 6 | einops==0.4.1 7 | gradio==4.24.0 8 | gradio_client==0.14.0 9 | imageio==2.33.0 10 | imageio-ffmpeg==0.4.9 11 | numpy==1.24.4 12 | omegaconf==2.2.3 13 | onnxruntime-gpu==1.16.3 14 | open-clip-torch==2.20.0 15 | opencv-contrib-python==4.8.1.78 16 | opencv-python==4.8.1.78 17 | Pillow==9.5.0 18 | scikit-image==0.21.0 19 | scikit-learn==1.3.2 20 | scipy==1.11.4 21 | torch==2.0.1 22 | torchdiffeq==0.2.3 23 | torchmetrics==1.2.1 24 | torchsde==0.2.5 25 | torchvision==0.15.2 26 | tqdm==4.66.1 27 | transformers==4.30.2 28 | xformers==0.0.22 29 | controlnet-aux==0.0.7 30 | mediapipe==0.10.11 31 | librosa==0.9.2 32 | ffmpeg-python==0.2.0 -------------------------------------------------------------------------------- /scripts/audio2vid.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import ffmpeg 4 | import random 5 | from datetime import datetime 6 | from pathlib import Path 7 | from typing import List 8 | import subprocess 9 | import av 10 | import numpy as np 11 | import cv2 12 | import torch 13 | import torchvision 14 | from diffusers import AutoencoderKL, DDIMScheduler 15 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline 16 | from einops import repeat 17 | from omegaconf import OmegaConf 18 | from PIL import Image 
19 | from torchvision import transforms 20 | from transformers import CLIPVisionModelWithProjection 21 | 22 | from configs.prompts.test_cases import TestCasesDict 23 | from src.models.pose_guider import PoseGuider 24 | from src.models.unet_2d_condition import UNet2DConditionModel 25 | from src.models.unet_3d import UNet3DConditionModel 26 | from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline 27 | from src.utils.util import get_fps, read_frames, save_videos_grid 28 | 29 | from src.audio_models.model import Audio2MeshModel 30 | from src.audio_models.pose_model import Audio2PoseModel 31 | from src.utils.audio_util import prepare_audio_feature 32 | from src.utils.mp_utils import LMKExtractor 33 | from src.utils.draw_util import FaceMeshVisualizer 34 | from src.utils.pose_util import project_points, smooth_pose_seq 35 | from src.utils.frame_interpolation import init_frame_interpolation_model, batch_images_interpolation_tool 36 | 37 | 38 | def parse_args(): 39 | parser = argparse.ArgumentParser() 40 | parser.add_argument("--config", type=str, default='./configs/prompts/animation_audio.yaml') 41 | parser.add_argument("-W", type=int, default=512) 42 | parser.add_argument("-H", type=int, default=512) 43 | parser.add_argument("-L", type=int) 44 | parser.add_argument("--seed", type=int, default=42) 45 | parser.add_argument("--cfg", type=float, default=3.5) 46 | parser.add_argument("--steps", type=int, default=25) 47 | parser.add_argument("--fps", type=int, default=30) 48 | parser.add_argument("-acc", "--accelerate", action='store_true') 49 | parser.add_argument("--fi_step", type=int, default=3) 50 | args = parser.parse_args() 51 | 52 | return args 53 | 54 | def main(): 55 | args = parse_args() 56 | 57 | config = OmegaConf.load(args.config) 58 | 59 | if config.weight_dtype == "fp16": 60 | weight_dtype = torch.float16 61 | else: 62 | weight_dtype = torch.float32 63 | 64 | audio_infer_config = OmegaConf.load(config.audio_inference_config) 65 | # prepare model 66 | a2m_model = Audio2MeshModel(audio_infer_config['a2m_model']) 67 | a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt']), strict=False) 68 | a2m_model.cuda().eval() 69 | 70 | a2p_model = Audio2PoseModel(audio_infer_config['a2p_model']) 71 | a2p_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2p_ckpt']), strict=False) 72 | a2p_model.cuda().eval() 73 | 74 | vae = AutoencoderKL.from_pretrained( 75 | config.pretrained_vae_path, 76 | ).to("cuda", dtype=weight_dtype) 77 | 78 | reference_unet = UNet2DConditionModel.from_pretrained( 79 | config.pretrained_base_model_path, 80 | subfolder="unet", 81 | ).to(dtype=weight_dtype, device="cuda") 82 | 83 | inference_config_path = config.inference_config 84 | infer_config = OmegaConf.load(inference_config_path) 85 | denoising_unet = UNet3DConditionModel.from_pretrained_2d( 86 | config.pretrained_base_model_path, 87 | config.motion_module_path, 88 | subfolder="unet", 89 | unet_additional_kwargs=infer_config.unet_additional_kwargs, 90 | ).to(dtype=weight_dtype, device="cuda") 91 | 92 | 93 | pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype) # not use cross attention 94 | 95 | image_enc = CLIPVisionModelWithProjection.from_pretrained( 96 | config.image_encoder_path 97 | ).to(dtype=weight_dtype, device="cuda") 98 | 99 | sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs) 100 | scheduler = DDIMScheduler(**sched_kwargs) 101 | 102 | generator = 
torch.manual_seed(args.seed) 103 | 104 | width, height = args.W, args.H 105 | 106 | # load pretrained weights 107 | denoising_unet.load_state_dict( 108 | torch.load(config.denoising_unet_path, map_location="cpu"), 109 | strict=False, 110 | ) 111 | reference_unet.load_state_dict( 112 | torch.load(config.reference_unet_path, map_location="cpu"), 113 | ) 114 | pose_guider.load_state_dict( 115 | torch.load(config.pose_guider_path, map_location="cpu"), 116 | ) 117 | 118 | pipe = Pose2VideoPipeline( 119 | vae=vae, 120 | image_encoder=image_enc, 121 | reference_unet=reference_unet, 122 | denoising_unet=denoising_unet, 123 | pose_guider=pose_guider, 124 | scheduler=scheduler, 125 | ) 126 | pipe = pipe.to("cuda", dtype=weight_dtype) 127 | 128 | date_str = datetime.now().strftime("%Y%m%d") 129 | time_str = datetime.now().strftime("%H%M") 130 | save_dir_name = f"{time_str}--seed_{args.seed}-{args.W}x{args.H}" 131 | 132 | save_dir = Path(f"output/{date_str}/{save_dir_name}") 133 | save_dir.mkdir(exist_ok=True, parents=True) 134 | 135 | 136 | lmk_extractor = LMKExtractor() 137 | vis = FaceMeshVisualizer(forehead_edge=False) 138 | 139 | if args.accelerate: 140 | frame_inter_model = init_frame_interpolation_model() 141 | 142 | for ref_image_path in config["test_cases"].keys(): 143 | # Each ref_image may correspond to multiple actions 144 | for audio_path in config["test_cases"][ref_image_path]: 145 | ref_name = Path(ref_image_path).stem 146 | audio_name = Path(audio_path).stem 147 | 148 | ref_image_pil = Image.open(ref_image_path).convert("RGB") 149 | ref_image_np = cv2.cvtColor(np.array(ref_image_pil), cv2.COLOR_RGB2BGR) 150 | ref_image_np = cv2.resize(ref_image_np, (args.H, args.W)) 151 | 152 | face_result = lmk_extractor(ref_image_np) 153 | assert face_result is not None, "No face detected." 154 | lmks = face_result['lmks'].astype(np.float32) 155 | ref_pose = vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True) 156 | 157 | sample = prepare_audio_feature(audio_path, wav2vec_model_path=audio_infer_config['a2m_model']['model_path']) 158 | sample['audio_feature'] = torch.from_numpy(sample['audio_feature']).float().cuda() 159 | sample['audio_feature'] = sample['audio_feature'].unsqueeze(0) 160 | 161 | # inference 162 | pred = a2m_model.infer(sample['audio_feature'], sample['seq_len']) 163 | pred = pred.squeeze().detach().cpu().numpy() 164 | pred = pred.reshape(pred.shape[0], -1, 3) 165 | pred = pred + face_result['lmks3d'] 166 | 167 | if 'pose_temp' in config and config['pose_temp'] is not None: 168 | pose_seq = np.load(config['pose_temp']) 169 | mirrored_pose_seq = np.concatenate((pose_seq, pose_seq[-2:0:-1]), axis=0) 170 | pose_seq = np.tile(mirrored_pose_seq, (sample['seq_len'] // len(mirrored_pose_seq) + 1, 1))[:sample['seq_len']] 171 | else: 172 | id_seed = random.randint(0, 99) 173 | id_seed = torch.LongTensor([id_seed]).cuda() 174 | 175 | # Currently, only inference up to a maximum length of 10 seconds is supported. 
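                # Note: the chunking below splits the audio feature into 5-second pieces for
                # audio2pose inference; the trailing remainder (shorter than 5 s) is merged into
                # the previous chunk, and the per-chunk pose sequences are concatenated and smoothed.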
176 | chunk_duration = 5 # 5 seconds 177 | sr = 16000 178 | fps = 30 179 | chunk_size = sr * chunk_duration 180 | 181 | audio_chunks = list(sample['audio_feature'].split(chunk_size, dim=1)) 182 | seq_len_list = [chunk_duration*fps] * (len(audio_chunks) - 1) + [sample['seq_len'] % (chunk_duration*fps)] # 30 fps 183 | 184 | audio_chunks[-2] = torch.cat((audio_chunks[-2], audio_chunks[-1]), dim=1) 185 | seq_len_list[-2] = seq_len_list[-2] + seq_len_list[-1] 186 | del audio_chunks[-1] 187 | del seq_len_list[-1] 188 | 189 | pose_seq = [] 190 | for audio, seq_len in zip(audio_chunks, seq_len_list): 191 | pose_seq_chunk = a2p_model.infer(audio, seq_len, id_seed) 192 | pose_seq_chunk = pose_seq_chunk.squeeze().detach().cpu().numpy() 193 | pose_seq_chunk[:, :3] *= 0.5 194 | pose_seq.append(pose_seq_chunk) 195 | 196 | pose_seq = np.concatenate(pose_seq, 0) 197 | pose_seq = smooth_pose_seq(pose_seq, 7) 198 | 199 | # project 3D mesh to 2D landmark 200 | projected_vertices = project_points(pred, face_result['trans_mat'], pose_seq, [height, width]) 201 | 202 | pose_images = [] 203 | for i, verts in enumerate(projected_vertices): 204 | lmk_img = vis.draw_landmarks((width, height), verts, normed=False) 205 | pose_images.append(lmk_img) 206 | 207 | pose_list = [] 208 | pose_tensor_list = [] 209 | print(f"pose video has {len(pose_images)} frames, with {args.fps} fps") 210 | pose_transform = transforms.Compose( 211 | [transforms.Resize((height, width)), transforms.ToTensor()] 212 | ) 213 | args_L = len(pose_images) if args.L is None else args.L 214 | for pose_image_np in pose_images[: args_L]: 215 | pose_image_pil = Image.fromarray(cv2.cvtColor(pose_image_np, cv2.COLOR_BGR2RGB)) 216 | pose_tensor_list.append(pose_transform(pose_image_pil)) 217 | sub_step = args.fi_step if args.accelerate else 1 218 | for pose_image_np in pose_images[: args_L: sub_step]: 219 | pose_image_np = cv2.resize(pose_image_np, (width, height)) 220 | pose_list.append(pose_image_np) 221 | 222 | pose_list = np.array(pose_list) 223 | 224 | video_length = len(pose_list) 225 | 226 | pose_tensor = torch.stack(pose_tensor_list, dim=0) # (f, c, h, w) 227 | pose_tensor = pose_tensor.transpose(0, 1) 228 | pose_tensor = pose_tensor.unsqueeze(0) 229 | 230 | video = pipe( 231 | ref_image_pil, 232 | pose_list, 233 | ref_pose, 234 | width, 235 | height, 236 | video_length, 237 | args.steps, 238 | args.cfg, 239 | generator=generator, 240 | ).videos 241 | 242 | if args.accelerate: 243 | video = batch_images_interpolation_tool(video, frame_inter_model, inter_frames=args.fi_step-1) 244 | 245 | ref_image_tensor = pose_transform(ref_image_pil) # (c, h, w) 246 | ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze( 247 | 0 248 | ) # (1, c, 1, h, w) 249 | ref_image_tensor = repeat( 250 | ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=video.shape[2] 251 | ) 252 | 253 | video = torch.cat([ref_image_tensor, pose_tensor[:,:,:video.shape[2]], video], dim=0) 254 | save_path = f"{save_dir}/{ref_name}_{audio_name}_{args.H}x{args.W}_{int(args.cfg)}_{time_str}_noaudio.mp4" 255 | save_videos_grid( 256 | video, 257 | save_path, 258 | n_rows=3, 259 | fps=args.fps, 260 | ) 261 | 262 | stream = ffmpeg.input(save_path) 263 | audio = ffmpeg.input(audio_path) 264 | ffmpeg.output(stream.video, audio.audio, save_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac', shortest=None).run() 265 | os.remove(save_path) 266 | 267 | if __name__ == "__main__": 268 | main() 269 | 
-------------------------------------------------------------------------------- /scripts/generate_ref_pose.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import argparse 4 | 5 | import numpy as np 6 | import cv2 7 | from tqdm import tqdm 8 | from scipy.spatial.transform import Rotation as R 9 | from scipy.interpolate import interp1d 10 | 11 | from src.utils.mp_utils import LMKExtractor 12 | from src.utils.pose_util import smooth_pose_seq, matrix_to_euler_and_translation 13 | 14 | 15 | def parse_args(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--ref_video", type=str, default='', help='path of input video') 18 | parser.add_argument("--save_path", type=str, default='', help='path to save pose') 19 | args = parser.parse_args() 20 | return args 21 | 22 | 23 | def main(): 24 | args = parse_args() 25 | 26 | lmk_extractor = LMKExtractor() 27 | 28 | cap = cv2.VideoCapture(args.ref_video) 29 | 30 | total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 31 | fps = cap.get(cv2.CAP_PROP_FPS) 32 | 33 | pbar = tqdm(range(total_frames), desc="processing ...") 34 | 35 | trans_mat_list = [] 36 | while cap.isOpened(): 37 | ret, frame = cap.read() 38 | if not ret: 39 | break 40 | 41 | pbar.update(1) 42 | result = lmk_extractor(frame) 43 | if result is None: 44 | break 45 | trans_mat_list.append(result['trans_mat'].astype(np.float32)) 46 | cap.release() 47 | 48 | total_frames = len(trans_mat_list) 49 | 50 | trans_mat_arr = np.array(trans_mat_list) 51 | 52 | # compute delta pose 53 | trans_mat_inv_frame_0 = np.linalg.inv(trans_mat_arr[0]) 54 | pose_arr = np.zeros([trans_mat_arr.shape[0], 6]) 55 | 56 | for i in range(pose_arr.shape[0]): 57 | pose_mat = trans_mat_inv_frame_0 @ trans_mat_arr[i] 58 | euler_angles, translation_vector = matrix_to_euler_and_translation(pose_mat) 59 | pose_arr[i, :3] = euler_angles 60 | pose_arr[i, 3:6] = translation_vector 61 | 62 | # interpolate to 30 fps 63 | new_fps = 30 64 | old_time = np.linspace(0, total_frames / fps, total_frames) 65 | new_time = np.linspace(0, total_frames / fps, int(total_frames * new_fps / fps)) 66 | 67 | pose_arr_interp = np.zeros((len(new_time), 6)) 68 | for i in range(6): 69 | interp_func = interp1d(old_time, pose_arr[:, i]) 70 | pose_arr_interp[:, i] = interp_func(new_time) 71 | 72 | pose_arr_smooth = smooth_pose_seq(pose_arr_interp) 73 | np.save(args.save_path, pose_arr_smooth) 74 | 75 | 76 | if __name__ == "__main__": 77 | main() 78 | 79 | -------------------------------------------------------------------------------- /scripts/pose2vid.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import ffmpeg 4 | from datetime import datetime 5 | from pathlib import Path 6 | from typing import List 7 | import subprocess 8 | import av 9 | import numpy as np 10 | import cv2 11 | import torch 12 | import torchvision 13 | from diffusers import AutoencoderKL, DDIMScheduler 14 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline 15 | from einops import repeat 16 | from omegaconf import OmegaConf 17 | from PIL import Image 18 | from torchvision import transforms 19 | from transformers import CLIPVisionModelWithProjection 20 | 21 | from configs.prompts.test_cases import TestCasesDict 22 | from src.models.pose_guider import PoseGuider 23 | from src.models.unet_2d_condition import UNet2DConditionModel 24 | from src.models.unet_3d import UNet3DConditionModel 25 | from 
src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline 26 | from src.utils.util import get_fps, read_frames, save_videos_grid 27 | from src.utils.frame_interpolation import init_frame_interpolation_model, batch_images_interpolation_tool 28 | 29 | from src.utils.mp_utils import LMKExtractor 30 | from src.utils.draw_util import FaceMeshVisualizer 31 | 32 | 33 | def parse_args(): 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument("--config", type=str, default='./configs/prompts/animation.yaml') 36 | parser.add_argument("-W", type=int, default=512) 37 | parser.add_argument("-H", type=int, default=512) 38 | parser.add_argument("-L", type=int) 39 | parser.add_argument("--seed", type=int, default=42) 40 | parser.add_argument("--cfg", type=float, default=3.5) 41 | parser.add_argument("--steps", type=int, default=25) 42 | parser.add_argument("--fps", type=int) 43 | parser.add_argument("-acc", "--accelerate", action='store_true') 44 | parser.add_argument("--fi_step", type=int, default=3) 45 | args = parser.parse_args() 46 | 47 | return args 48 | 49 | def main(): 50 | args = parse_args() 51 | 52 | config = OmegaConf.load(args.config) 53 | 54 | if config.weight_dtype == "fp16": 55 | weight_dtype = torch.float16 56 | else: 57 | weight_dtype = torch.float32 58 | 59 | vae = AutoencoderKL.from_pretrained( 60 | config.pretrained_vae_path, 61 | ).to("cuda", dtype=weight_dtype) 62 | 63 | reference_unet = UNet2DConditionModel.from_pretrained( 64 | config.pretrained_base_model_path, 65 | subfolder="unet", 66 | ).to(dtype=weight_dtype, device="cuda") 67 | 68 | inference_config_path = config.inference_config 69 | infer_config = OmegaConf.load(inference_config_path) 70 | denoising_unet = UNet3DConditionModel.from_pretrained_2d( 71 | config.pretrained_base_model_path, 72 | config.motion_module_path, 73 | subfolder="unet", 74 | unet_additional_kwargs=infer_config.unet_additional_kwargs, 75 | ).to(dtype=weight_dtype, device="cuda") 76 | 77 | pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype) # not use cross attention 78 | 79 | image_enc = CLIPVisionModelWithProjection.from_pretrained( 80 | config.image_encoder_path 81 | ).to(dtype=weight_dtype, device="cuda") 82 | 83 | sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs) 84 | scheduler = DDIMScheduler(**sched_kwargs) 85 | 86 | generator = torch.manual_seed(args.seed) 87 | 88 | width, height = args.W, args.H 89 | 90 | # load pretrained weights 91 | denoising_unet.load_state_dict( 92 | torch.load(config.denoising_unet_path, map_location="cpu"), 93 | strict=False, 94 | ) 95 | reference_unet.load_state_dict( 96 | torch.load(config.reference_unet_path, map_location="cpu"), 97 | ) 98 | pose_guider.load_state_dict( 99 | torch.load(config.pose_guider_path, map_location="cpu"), 100 | ) 101 | 102 | pipe = Pose2VideoPipeline( 103 | vae=vae, 104 | image_encoder=image_enc, 105 | reference_unet=reference_unet, 106 | denoising_unet=denoising_unet, 107 | pose_guider=pose_guider, 108 | scheduler=scheduler, 109 | ) 110 | pipe = pipe.to("cuda", dtype=weight_dtype) 111 | 112 | date_str = datetime.now().strftime("%Y%m%d") 113 | time_str = datetime.now().strftime("%H%M") 114 | save_dir_name = f"{time_str}--seed_{args.seed}-{args.W}x{args.H}" 115 | 116 | save_dir = Path(f"output/{date_str}/{save_dir_name}") 117 | save_dir.mkdir(exist_ok=True, parents=True) 118 | 119 | 120 | lmk_extractor = LMKExtractor() 121 | vis = FaceMeshVisualizer(forehead_edge=False) 122 | 123 | if args.accelerate: 124 | 
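        # -acc mode: load the FILM frame-interpolation model once; pose frames are sampled
        # every `fi_step` frames below, and the skipped frames are filled back in by
        # batch_images_interpolation_tool after the diffusion pass.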
frame_inter_model = init_frame_interpolation_model() 125 | 126 | for ref_image_path in config["test_cases"].keys(): 127 | # Each ref_image may correspond to multiple actions 128 | for pose_video_path in config["test_cases"][ref_image_path]: 129 | ref_name = Path(ref_image_path).stem 130 | pose_name = Path(pose_video_path).stem 131 | 132 | ref_image_pil = Image.open(ref_image_path).convert("RGB") 133 | ref_image_np = cv2.cvtColor(np.array(ref_image_pil), cv2.COLOR_RGB2BGR) 134 | ref_image_np = cv2.resize(ref_image_np, (args.H, args.W)) 135 | 136 | face_result = lmk_extractor(ref_image_np) 137 | assert face_result is not None, "Can not detect a face in the reference image." 138 | lmks = face_result['lmks'].astype(np.float32) 139 | ref_pose = vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True) 140 | 141 | pose_list = [] 142 | pose_tensor_list = [] 143 | pose_images = read_frames(pose_video_path) 144 | src_fps = get_fps(pose_video_path) 145 | print(f"pose video has {len(pose_images)} frames, with {src_fps} fps") 146 | pose_transform = transforms.Compose( 147 | [transforms.Resize((height, width)), transforms.ToTensor()] 148 | ) 149 | args_L = len(pose_images) if args.L is None else args.L 150 | for pose_image_pil in pose_images[: args_L]: 151 | pose_tensor_list.append(pose_transform(pose_image_pil)) 152 | sub_step = args.fi_step if args.accelerate else 1 153 | for pose_image_pil in pose_images[: args.L: sub_step]: 154 | pose_image_np = cv2.cvtColor(np.array(pose_image_pil), cv2.COLOR_RGB2BGR) 155 | pose_image_np = cv2.resize(pose_image_np, (width, height)) 156 | pose_list.append(pose_image_np) 157 | 158 | pose_list = np.array(pose_list) 159 | 160 | video_length = len(pose_list) 161 | 162 | pose_tensor = torch.stack(pose_tensor_list, dim=0) # (f, c, h, w) 163 | pose_tensor = pose_tensor.transpose(0, 1) 164 | pose_tensor = pose_tensor.unsqueeze(0) 165 | 166 | video = pipe( 167 | ref_image_pil, 168 | pose_list, 169 | ref_pose, 170 | width, 171 | height, 172 | video_length, 173 | args.steps, 174 | args.cfg, 175 | generator=generator, 176 | ).videos 177 | 178 | if args.accelerate: 179 | video = batch_images_interpolation_tool(video, frame_inter_model, inter_frames=args.fi_step-1) 180 | 181 | ref_image_tensor = pose_transform(ref_image_pil) # (c, h, w) 182 | ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze( 183 | 0 184 | ) # (1, c, 1, h, w) 185 | ref_image_tensor = repeat( 186 | ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=video.shape[2] 187 | ) 188 | 189 | video = torch.cat([ref_image_tensor, pose_tensor[:,:,:video.shape[2]], video], dim=0) 190 | save_path = f"{save_dir}/{ref_name}_{pose_name}_{args.H}x{args.W}_{int(args.cfg)}_{time_str}_noaudio.mp4" 191 | save_videos_grid( 192 | video, 193 | save_path, 194 | n_rows=3, 195 | fps=src_fps if args.fps is None else args.fps, 196 | ) 197 | 198 | audio_output = 'audio_from_video.aac' 199 | # extract audio 200 | ffmpeg.input(pose_video_path).output(audio_output, acodec='copy').run() 201 | # merge audio and video 202 | stream = ffmpeg.input(save_path) 203 | audio = ffmpeg.input(audio_output) 204 | ffmpeg.output(stream.video, audio.audio, save_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac', shortest=None).run() 205 | 206 | os.remove(save_path) 207 | os.remove(audio_output) 208 | 209 | if __name__ == "__main__": 210 | main() 211 | -------------------------------------------------------------------------------- /scripts/prepare_video.py: 
-------------------------------------------------------------------------------- 1 | import subprocess 2 | from src.utils.mp_utils import LMKExtractor 3 | from src.utils.draw_util import FaceMeshVisualizer 4 | 5 | import os 6 | import numpy as np 7 | import cv2 8 | 9 | 10 | def crop_video(input_file, output_file): 11 | width = 800 12 | height = 800 13 | x = 550 14 | y = 50 15 | 16 | ffmpeg_cmd = f'ffmpeg -i {input_file} -filter:v "crop={width}:{height}:{x}:{y}" -c:a copy {output_file}' 17 | subprocess.call(ffmpeg_cmd, shell=True) 18 | 19 | 20 | 21 | def extract_and_draw_lmks(input_file, output_file): 22 | lmk_extractor = LMKExtractor() 23 | vis = FaceMeshVisualizer() 24 | 25 | cap = cv2.VideoCapture(input_file) 26 | frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 27 | fps = cap.get(cv2.CAP_PROP_FPS) 28 | width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 29 | height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 30 | 31 | # Define the codec and create VideoWriter object 32 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') 33 | out = cv2.VideoWriter(output_file, fourcc, fps, (width, height)) 34 | 35 | write_ref_img = False 36 | for i in range(200): 37 | ret, frame = cap.read() 38 | 39 | if not ret: 40 | break 41 | 42 | if not write_ref_img: 43 | write_ref_img = True 44 | cv2.imwrite(os.path.join(os.path.dirname(output_file), "ref_img.jpg"), frame) 45 | 46 | result = lmk_extractor(frame) 47 | 48 | if result is not None: 49 | lmks = result['lmks'].astype(np.float32) 50 | lmk_img = vis.draw_landmarks((frame.shape[1], frame.shape[0]), lmks, normed=True) 51 | out.write(lmk_img) 52 | else: 53 | print('multiple faces in the frame') 54 | 55 | 56 | if __name__ == "__main__": 57 | 58 | input_file = "./Moore-AnimateAnyone/examples/video.mp4" 59 | lmk_video_path = "./Moore-AnimateAnyone/examples/pose.mp4" 60 | 61 | # crop video 62 | # crop_video(input_file, output_file) 63 | 64 | # extract and draw lmks 65 | extract_and_draw_lmks(input_file, lmk_video_path) -------------------------------------------------------------------------------- /scripts/preprocess_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import numpy as np 5 | import cv2 6 | from tqdm import tqdm 7 | import glob 8 | import json 9 | 10 | from src.utils.mp_utils import LMKExtractor 11 | 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--input_dir", type=str, default='', help='path of dataset') 16 | parser.add_argument("--output_dir", type=str, default='', help='path to save extracted annotations') 17 | parser.add_argument("--training_json", type=str, default='', help='path to save training json') 18 | args = parser.parse_args() 19 | return args 20 | 21 | 22 | def generate_training_json_mesh(video_dir, face_info_dir, res_json_path, min_clip_length=30): 23 | video_name_list = sorted(os.listdir(face_info_dir)) 24 | res_data_dic = {} 25 | 26 | pbar = tqdm(range(len(video_name_list))) 27 | 28 | 29 | for video_index, video_name in enumerate(video_name_list): 30 | pbar.update(1) 31 | 32 | tem_dic = {} 33 | tem_tem_dic = {} 34 | video_clip_dir = os.path.join(video_dir, video_name) 35 | lmks_clip_dir = os.path.join(face_info_dir, video_name) 36 | 37 | video_clip_num = 1 38 | video_data_list = [] 39 | 40 | frame_path_list = sorted(glob.glob(os.path.join(video_clip_dir, '*.png'))) 41 | lmks_path_list = sorted(glob.glob(os.path.join(lmks_clip_dir, '*lmks.npy'))) 42 | 43 | min_len = min(len(frame_path_list), len(lmks_path_list)) 44 
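# The JSON written at the end of this function has roughly the following shape:
# one entry per video directory, holding the aligned frame/landmark lists
# produced by main() below (paths shortened for illustration):
#
#   {
#     "video_0001": {
#       "video_clip_num": 1,
#       "clip_data_list": [
#         {
#           "frame_name_list": ["video_0001/0000.png", "..."],
#           "frame_path_list": ["<input_dir>/video_0001/0000.png", "..."],
#           "lmks_list": ["<output_dir>/video_0001/0000_lmks.npy", "..."]
#         }
#       ]
#     }
#   }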
| frame_path_list = frame_path_list[:min_len] 45 | lmks_path_list = lmks_path_list[:min_len] 46 | 47 | 48 | if min_len < min_clip_length: 49 | info = 'min length: {} {}'.format(video_name, min_len) 50 | video_clip_num -= 1 51 | continue 52 | 53 | first_frame_basename = os.path.basename(frame_path_list[0]).split('.')[0] 54 | first_lmks_basename = os.path.basename(lmks_path_list[0]).split('_')[0] 55 | last_frame_basename = os.path.basename(frame_path_list[-1]).split('.')[0] 56 | last_lmks_basename = os.path.basename(lmks_path_list[-1]).split('_')[0] 57 | 58 | if (first_frame_basename != first_lmks_basename) or (last_frame_basename != last_lmks_basename): 59 | info = 'different length skip: {} , length {}/{}, frame/lmks'.format(video_name, len(frame_path_list), len(lmks_path_list)) 60 | video_clip_num -= 1 61 | continue 62 | 63 | frame_name_list = [os.path.join(video_name, os.path.basename(item)) for item in frame_path_list] 64 | 65 | tem_tem_dic['frame_name_list'] = frame_name_list 66 | tem_tem_dic['frame_path_list'] = frame_path_list 67 | tem_tem_dic['lmks_list'] = lmks_path_list 68 | video_data_list.append(tem_tem_dic) 69 | 70 | tem_dic['video_clip_num'] = video_clip_num 71 | tem_dic['clip_data_list'] = video_data_list 72 | res_data_dic[video_name] = tem_dic 73 | 74 | with open(res_json_path, 'w') as f: 75 | json.dump(res_data_dic, f) 76 | 77 | 78 | def main(): 79 | args = parse_args() 80 | 81 | os.makedirs(args.output_dir, exist_ok=True) 82 | folders = [f.path for f in os.scandir(args.input_dir) if f.is_dir()] 83 | folders.sort() 84 | 85 | lmk_extractor = LMKExtractor() 86 | 87 | pbar = tqdm(range(len(folders)), desc="processing ...") 88 | for folder in folders: 89 | pbar.update(1) 90 | output_subdir = os.path.join(args.output_dir, os.path.basename(folder)) 91 | os.makedirs(output_subdir, exist_ok=True) 92 | for img_file in sorted(glob.glob(os.path.join(folder, "*.png"))): 93 | base = os.path.basename(img_file) 94 | lmks_output_file = os.path.join(output_subdir, os.path.splitext(base)[0] + "_lmks.npy") 95 | lmks3d_output_file = os.path.join(output_subdir, os.path.splitext(base)[0] + "_lmks3d.npy") 96 | trans_mat_output_file = os.path.join(output_subdir, os.path.splitext(base)[0] + "_trans_mat.npy") 97 | bs_output_file = os.path.join(output_subdir, os.path.splitext(base)[0] + "_bs.npy") 98 | 99 | img = cv2.imread(img_file) 100 | result = lmk_extractor(img) 101 | 102 | if result is not None: 103 | np.save(lmks_output_file, result['lmks'].astype(np.float32)) 104 | np.save(lmks3d_output_file, result['lmks3d'].astype(np.float32)) 105 | np.save(trans_mat_output_file, result['trans_mat'].astype(np.float32)) 106 | np.save(bs_output_file, np.array(result['bs']).astype(np.float32)) 107 | 108 | # write json 109 | generate_training_json_mesh(args.input_dir, args.output_dir, args.training_json, min_clip_length=30) 110 | 111 | 112 | if __name__ == "__main__": 113 | main() 114 | 115 | 116 | -------------------------------------------------------------------------------- /scripts/vid2pose.py: -------------------------------------------------------------------------------- 1 | import os 2 | import ffmpeg 3 | from PIL import Image 4 | import cv2 5 | from tqdm import tqdm 6 | 7 | from src.utils.util import get_fps, read_frames, save_videos_from_pil 8 | import numpy as np 9 | from src.utils.draw_util import FaceMeshVisualizer 10 | from src.utils.mp_utils import LMKExtractor 11 | 12 | if __name__ == "__main__": 13 | import argparse 14 | 15 | parser = argparse.ArgumentParser() 16 | 
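preprocess_dataset.py above and vid2pose.py here lean on the same two project helpers, LMKExtractor and FaceMeshVisualizer. A minimal standalone sketch of that shared pattern (the image paths are illustrative):

import cv2
import numpy as np

from src.utils.draw_util import FaceMeshVisualizer
from src.utils.mp_utils import LMKExtractor

lmk_extractor = LMKExtractor()
vis = FaceMeshVisualizer(forehead_edge=False)

frame = cv2.imread("ref_img.jpg")                 # BGR image, as in the scripts above
result = lmk_extractor(frame)                     # None when no usable face is found
if result is not None:
    lmks = result["lmks"].astype(np.float32)      # normalized 2D landmarks
    pose_img = vis.draw_landmarks((frame.shape[1], frame.shape[0]), lmks, normed=True)
    cv2.imwrite("ref_img_pose.jpg", pose_img)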
parser.add_argument("--video_path", type=str) 17 | args = parser.parse_args() 18 | 19 | if not os.path.exists(args.video_path): 20 | raise ValueError(f"Path: {args.video_path} not exists") 21 | 22 | dir_path, video_name = ( 23 | os.path.dirname(args.video_path), 24 | os.path.splitext(os.path.basename(args.video_path))[0], 25 | ) 26 | out_path = os.path.join(dir_path, video_name + "_kps_noaudio.mp4") 27 | 28 | lmk_extractor = LMKExtractor() 29 | vis = FaceMeshVisualizer(forehead_edge=False) 30 | 31 | width = 512 32 | height = 512 33 | 34 | fps = get_fps(args.video_path) 35 | frames = read_frames(args.video_path) 36 | kps_results = [] 37 | for i, frame_pil in enumerate(tqdm(frames)): 38 | image_np = cv2.cvtColor(np.array(frame_pil), cv2.COLOR_RGB2BGR) 39 | image_np = cv2.resize(image_np, (height, width)) 40 | face_result = lmk_extractor(image_np) 41 | try: 42 | lmks = face_result['lmks'].astype(np.float32) 43 | pose_img = vis.draw_landmarks((image_np.shape[1], image_np.shape[0]), lmks, normed=True) 44 | pose_img = Image.fromarray(cv2.cvtColor(pose_img, cv2.COLOR_BGR2RGB)) 45 | except: 46 | pose_img = kps_results[-1] 47 | 48 | kps_results.append(pose_img) 49 | 50 | print(out_path.replace('_noaudio.mp4', '.mp4')) 51 | save_videos_from_pil(kps_results, out_path, fps=fps) 52 | 53 | audio_output = 'audio_from_video.aac' 54 | ffmpeg.input(args.video_path).output(audio_output, acodec='copy').run() 55 | stream = ffmpeg.input(out_path) 56 | audio = ffmpeg.input(audio_output) 57 | ffmpeg.output(stream.video, audio.audio, out_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac').run() 58 | os.remove(out_path) 59 | os.remove(audio_output) 60 | -------------------------------------------------------------------------------- /scripts/vid2vid.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import ffmpeg 4 | from datetime import datetime 5 | from pathlib import Path 6 | from typing import List 7 | import subprocess 8 | import av 9 | import numpy as np 10 | import cv2 11 | import torch 12 | import torchvision 13 | from diffusers import AutoencoderKL, DDIMScheduler 14 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline 15 | from einops import repeat 16 | from omegaconf import OmegaConf 17 | from PIL import Image 18 | from torchvision import transforms 19 | from transformers import CLIPVisionModelWithProjection 20 | 21 | from configs.prompts.test_cases import TestCasesDict 22 | from src.models.pose_guider import PoseGuider 23 | from src.models.unet_2d_condition import UNet2DConditionModel 24 | from src.models.unet_3d import UNet3DConditionModel 25 | from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline 26 | from src.utils.util import get_fps, read_frames, save_videos_grid 27 | 28 | from src.utils.mp_utils import LMKExtractor 29 | from src.utils.draw_util import FaceMeshVisualizer 30 | from src.utils.pose_util import project_points_with_trans, matrix_to_euler_and_translation, euler_and_translation_to_matrix, smooth_pose_seq 31 | from src.utils.frame_interpolation import init_frame_interpolation_model, batch_images_interpolation_tool 32 | 33 | 34 | def parse_args(): 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument("--config", type=str, default='./configs/prompts/animation_facereenac.yaml') 37 | parser.add_argument("-W", type=int, default=512) 38 | parser.add_argument("-H", type=int, default=512) 39 | parser.add_argument("-L", type=int) 40 | parser.add_argument("--seed", 
type=int, default=42) 41 | parser.add_argument("--cfg", type=float, default=3.5) 42 | parser.add_argument("--steps", type=int, default=25) 43 | parser.add_argument("--fps", type=int) 44 | parser.add_argument("-acc", "--accelerate", action='store_true') 45 | parser.add_argument("--fi_step", type=int, default=3) 46 | args = parser.parse_args() 47 | 48 | return args 49 | 50 | def main(): 51 | args = parse_args() 52 | 53 | config = OmegaConf.load(args.config) 54 | 55 | if config.weight_dtype == "fp16": 56 | weight_dtype = torch.float16 57 | else: 58 | weight_dtype = torch.float32 59 | 60 | vae = AutoencoderKL.from_pretrained( 61 | config.pretrained_vae_path, 62 | ).to("cuda", dtype=weight_dtype) 63 | 64 | reference_unet = UNet2DConditionModel.from_pretrained( 65 | config.pretrained_base_model_path, 66 | subfolder="unet", 67 | ).to(dtype=weight_dtype, device="cuda") 68 | 69 | inference_config_path = config.inference_config 70 | infer_config = OmegaConf.load(inference_config_path) 71 | denoising_unet = UNet3DConditionModel.from_pretrained_2d( 72 | config.pretrained_base_model_path, 73 | config.motion_module_path, 74 | subfolder="unet", 75 | unet_additional_kwargs=infer_config.unet_additional_kwargs, 76 | ).to(dtype=weight_dtype, device="cuda") 77 | 78 | pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype) # not use cross attention 79 | 80 | image_enc = CLIPVisionModelWithProjection.from_pretrained( 81 | config.image_encoder_path 82 | ).to(dtype=weight_dtype, device="cuda") 83 | 84 | sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs) 85 | scheduler = DDIMScheduler(**sched_kwargs) 86 | 87 | generator = torch.manual_seed(args.seed) 88 | 89 | width, height = args.W, args.H 90 | 91 | # load pretrained weights 92 | denoising_unet.load_state_dict( 93 | torch.load(config.denoising_unet_path, map_location="cpu"), 94 | strict=False, 95 | ) 96 | reference_unet.load_state_dict( 97 | torch.load(config.reference_unet_path, map_location="cpu"), 98 | ) 99 | pose_guider.load_state_dict( 100 | torch.load(config.pose_guider_path, map_location="cpu"), 101 | ) 102 | 103 | pipe = Pose2VideoPipeline( 104 | vae=vae, 105 | image_encoder=image_enc, 106 | reference_unet=reference_unet, 107 | denoising_unet=denoising_unet, 108 | pose_guider=pose_guider, 109 | scheduler=scheduler, 110 | ) 111 | pipe = pipe.to("cuda", dtype=weight_dtype) 112 | 113 | date_str = datetime.now().strftime("%Y%m%d") 114 | time_str = datetime.now().strftime("%H%M") 115 | save_dir_name = f"{time_str}--seed_{args.seed}-{args.W}x{args.H}" 116 | 117 | save_dir = Path(f"output/{date_str}/{save_dir_name}") 118 | save_dir.mkdir(exist_ok=True, parents=True) 119 | 120 | 121 | lmk_extractor = LMKExtractor() 122 | vis = FaceMeshVisualizer(forehead_edge=False) 123 | 124 | if args.accelerate: 125 | frame_inter_model = init_frame_interpolation_model() 126 | 127 | for ref_image_path in config["test_cases"].keys(): 128 | # Each ref_image may correspond to multiple actions 129 | for source_video_path in config["test_cases"][ref_image_path]: 130 | ref_name = Path(ref_image_path).stem 131 | pose_name = Path(source_video_path).stem 132 | 133 | ref_image_pil = Image.open(ref_image_path).convert("RGB") 134 | ref_image_np = cv2.cvtColor(np.array(ref_image_pil), cv2.COLOR_RGB2BGR) 135 | ref_image_np = cv2.resize(ref_image_np, (args.H, args.W)) 136 | 137 | face_result = lmk_extractor(ref_image_np) 138 | assert face_result is not None, "Can not detect a face in the reference image." 
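The distinctive part of vid2vid.py is the head-pose retargeting that follows: each source frame's 4x4 transform is converted to a 6-D (euler angles, translation) vector, the translations are re-anchored at the reference face's own position, and the sequence is smoothed before being converted back to matrices. A condensed sketch of that step, reusing the repo's pose_util helpers exactly as they are called below (the wrapper function name is illustrative):

import numpy as np

from src.utils.pose_util import (euler_and_translation_to_matrix,
                                 matrix_to_euler_and_translation,
                                 smooth_pose_seq)


def retarget_head_pose(src_trans_mats, ref_trans_mat, window_size=3):
    """Map the source clip's head motion onto the reference identity."""
    pose = np.zeros([len(src_trans_mats), 6])
    for i, mat in enumerate(src_trans_mats):
        euler_angles, translation = matrix_to_euler_and_translation(mat)
        pose[i, :3] = euler_angles
        pose[i, 3:6] = translation

    # keep the source's relative translation, but start from the reference
    # image's own head position
    pose[:, 3:6] = pose[:, 3:6] - pose[0, 3:6] + ref_trans_mat[:3, 3]

    pose = smooth_pose_seq(pose, window_size=window_size)
    return np.array([euler_and_translation_to_matrix(p[:3], p[3:6]) for p in pose])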
139 | lmks = face_result['lmks'].astype(np.float32) 140 | ref_pose = vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True) 141 | 142 | 143 | 144 | source_images = read_frames(source_video_path) 145 | src_fps = get_fps(source_video_path) 146 | print(f"source video has {len(source_images)} frames, with {src_fps} fps") 147 | pose_transform = transforms.Compose( 148 | [transforms.Resize((height, width)), transforms.ToTensor()] 149 | ) 150 | 151 | step = 1 152 | if src_fps == 60: 153 | src_fps = 30 154 | step = 2 155 | 156 | pose_trans_list = [] 157 | verts_list = [] 158 | bs_list = [] 159 | src_tensor_list = [] 160 | args_L = len(source_images) if args.L is None else args.L*step 161 | for src_image_pil in source_images[: args_L: step]: 162 | src_tensor_list.append(pose_transform(src_image_pil)) 163 | sub_step = step*args.fi_step if args.accelerate else step 164 | for src_image_pil in source_images[: args_L: sub_step]: 165 | src_img_np = cv2.cvtColor(np.array(src_image_pil), cv2.COLOR_RGB2BGR) 166 | frame_height, frame_width, _ = src_img_np.shape 167 | src_img_result = lmk_extractor(src_img_np) 168 | if src_img_result is None: 169 | break 170 | pose_trans_list.append(src_img_result['trans_mat']) 171 | verts_list.append(src_img_result['lmks3d']) 172 | bs_list.append(src_img_result['bs']) 173 | 174 | trans_mat_arr = np.array(pose_trans_list) 175 | verts_arr = np.array(verts_list) 176 | bs_arr = np.array(bs_list) 177 | min_bs_idx = np.argmin(bs_arr.sum(1)) 178 | 179 | # compute delta pose 180 | pose_arr = np.zeros([trans_mat_arr.shape[0], 6]) 181 | 182 | for i in range(pose_arr.shape[0]): 183 | euler_angles, translation_vector = matrix_to_euler_and_translation(trans_mat_arr[i]) # real pose of source 184 | pose_arr[i, :3] = euler_angles 185 | pose_arr[i, 3:6] = translation_vector 186 | 187 | init_tran_vec = face_result['trans_mat'][:3, 3] # init translation of tgt 188 | pose_arr[:, 3:6] = pose_arr[:, 3:6] - pose_arr[0, 3:6] + init_tran_vec # (relative translation of source) + (init translation of tgt) 189 | 190 | pose_arr_smooth = smooth_pose_seq(pose_arr, window_size=3) 191 | pose_mat_smooth = [euler_and_translation_to_matrix(pose_arr_smooth[i][:3], pose_arr_smooth[i][3:6]) for i in range(pose_arr_smooth.shape[0])] 192 | pose_mat_smooth = np.array(pose_mat_smooth) 193 | 194 | # face retarget 195 | verts_arr = verts_arr - verts_arr[min_bs_idx] + face_result['lmks3d'] 196 | # project 3D mesh to 2D landmark 197 | projected_vertices = project_points_with_trans(verts_arr, pose_mat_smooth, [frame_height, frame_width]) 198 | 199 | pose_list = [] 200 | for i, verts in enumerate(projected_vertices): 201 | lmk_img = vis.draw_landmarks((frame_width, frame_height), verts, normed=False) 202 | pose_image_np = cv2.resize(lmk_img, (width, height)) 203 | pose_list.append(pose_image_np) 204 | 205 | pose_list = np.array(pose_list) 206 | 207 | video_length = len(pose_list) 208 | 209 | src_tensor = torch.stack(src_tensor_list, dim=0) # (f, c, h, w) 210 | src_tensor = src_tensor.transpose(0, 1) 211 | src_tensor = src_tensor.unsqueeze(0) 212 | 213 | video = pipe( 214 | ref_image_pil, 215 | pose_list, 216 | ref_pose, 217 | width, 218 | height, 219 | video_length, 220 | args.steps, 221 | args.cfg, 222 | generator=generator, 223 | ).videos 224 | 225 | if args.accelerate: 226 | video = batch_images_interpolation_tool(video, frame_inter_model, inter_frames=args.fi_step-1) 227 | 228 | ref_image_tensor = pose_transform(ref_image_pil) # (c, h, w) 229 | ref_image_tensor = 
ref_image_tensor.unsqueeze(1).unsqueeze( 230 | 0 231 | ) # (1, c, 1, h, w) 232 | ref_image_tensor = repeat( 233 | ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=video.shape[2] 234 | ) 235 | 236 | video = torch.cat([ref_image_tensor, video, src_tensor[:,:,:video.shape[2]]], dim=0) 237 | save_path = f"{save_dir}/{ref_name}_{pose_name}_{args.H}x{args.W}_{int(args.cfg)}_{time_str}_noaudio.mp4" 238 | save_videos_grid( 239 | video, 240 | save_path, 241 | n_rows=3, 242 | fps=src_fps if args.fps is None else args.fps, 243 | ) 244 | 245 | audio_output = 'audio_from_video.aac' 246 | # extract audio 247 | ffmpeg.input(source_video_path).output(audio_output, acodec='copy').run() 248 | # merge audio and video 249 | stream = ffmpeg.input(save_path) 250 | audio = ffmpeg.input(audio_output) 251 | ffmpeg.output(stream.video, audio.audio, save_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac', shortest=None).run() 252 | 253 | os.remove(save_path) 254 | os.remove(audio_output) 255 | 256 | if __name__ == "__main__": 257 | main() 258 | -------------------------------------------------------------------------------- /src/audio_models/mish.py: -------------------------------------------------------------------------------- 1 | """ 2 | Applies the mish function element-wise: 3 | mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))) 4 | """ 5 | 6 | # import pytorch 7 | import torch 8 | import torch.nn.functional as F 9 | from torch import nn 10 | 11 | @torch.jit.script 12 | def mish(input): 13 | """ 14 | Applies the mish function element-wise: 15 | mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))) 16 | See additional documentation for mish class. 17 | """ 18 | return input * torch.tanh(F.softplus(input)) 19 | 20 | class Mish(nn.Module): 21 | """ 22 | Applies the mish function element-wise: 23 | mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))) 24 | 25 | Shape: 26 | - Input: (N, *) where * means, any number of additional 27 | dimensions 28 | - Output: (N, *), same shape as the input 29 | 30 | Examples: 31 | >>> m = Mish() 32 | >>> input = torch.randn(2) 33 | >>> output = m(input) 34 | 35 | Reference: https://pytorch.org/docs/stable/generated/torch.nn.Mish.html 36 | """ 37 | 38 | def __init__(self): 39 | """ 40 | Init method. 41 | """ 42 | super().__init__() 43 | 44 | def forward(self, input): 45 | """ 46 | Forward pass of the function. 
47 | """ 48 | if torch.__version__ >= "1.9": 49 | return F.mish(input) 50 | else: 51 | return mish(input) -------------------------------------------------------------------------------- /src/audio_models/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from transformers import Wav2Vec2Config 6 | 7 | from .torch_utils import get_mask_from_lengths 8 | from .wav2vec2 import Wav2Vec2Model 9 | 10 | 11 | class Audio2MeshModel(nn.Module): 12 | def __init__( 13 | self, 14 | config 15 | ): 16 | super().__init__() 17 | out_dim = config['out_dim'] 18 | latent_dim = config['latent_dim'] 19 | model_path = config['model_path'] 20 | only_last_fetures = config['only_last_fetures'] 21 | from_pretrained = config['from_pretrained'] 22 | 23 | self._only_last_features = only_last_fetures 24 | 25 | self.audio_encoder_config = Wav2Vec2Config.from_pretrained(model_path, local_files_only=True) 26 | if from_pretrained: 27 | self.audio_encoder = Wav2Vec2Model.from_pretrained(model_path, local_files_only=True) 28 | else: 29 | self.audio_encoder = Wav2Vec2Model(self.audio_encoder_config) 30 | self.audio_encoder.feature_extractor._freeze_parameters() 31 | 32 | hidden_size = self.audio_encoder_config.hidden_size 33 | 34 | self.in_fn = nn.Linear(hidden_size, latent_dim) 35 | 36 | self.out_fn = nn.Linear(latent_dim, out_dim) 37 | nn.init.constant_(self.out_fn.weight, 0) 38 | nn.init.constant_(self.out_fn.bias, 0) 39 | 40 | def forward(self, audio, label, audio_len=None): 41 | attention_mask = ~get_mask_from_lengths(audio_len) if audio_len else None 42 | 43 | seq_len = label.shape[1] 44 | 45 | embeddings = self.audio_encoder(audio, seq_len=seq_len, output_hidden_states=True, 46 | attention_mask=attention_mask) 47 | 48 | if self._only_last_features: 49 | hidden_states = embeddings.last_hidden_state 50 | else: 51 | hidden_states = sum(embeddings.hidden_states) / len(embeddings.hidden_states) 52 | 53 | layer_in = self.in_fn(hidden_states) 54 | out = self.out_fn(layer_in) 55 | 56 | return out, None 57 | 58 | def infer(self, input_value, seq_len): 59 | embeddings = self.audio_encoder(input_value, seq_len=seq_len, output_hidden_states=True) 60 | 61 | if self._only_last_features: 62 | hidden_states = embeddings.last_hidden_state 63 | else: 64 | hidden_states = sum(embeddings.hidden_states) / len(embeddings.hidden_states) 65 | 66 | layer_in = self.in_fn(hidden_states) 67 | out = self.out_fn(layer_in) 68 | 69 | return out 70 | 71 | 72 | -------------------------------------------------------------------------------- /src/audio_models/pose_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | from transformers import Wav2Vec2Config 6 | 7 | from .torch_utils import get_mask_from_lengths 8 | from .wav2vec2 import Wav2Vec2Model 9 | 10 | 11 | def init_biased_mask(n_head, max_seq_len, period): 12 | def get_slopes(n): 13 | def get_slopes_power_of_2(n): 14 | start = (2**(-2**-(math.log2(n)-3))) 15 | ratio = start 16 | return [start*ratio**i for i in range(n)] 17 | if math.log2(n).is_integer(): 18 | return get_slopes_power_of_2(n) 19 | else: 20 | closest_power_of_2 = 2**math.floor(math.log2(n)) 21 | return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2*closest_power_of_2)[0::2][:n-closest_power_of_2] 22 | slopes = torch.Tensor(get_slopes(n_head)) 23 | bias = torch.arange(start=0, 
end=max_seq_len, step=period).unsqueeze(1).repeat(1,period).view(-1)//(period) 24 | bias = - torch.flip(bias,dims=[0]) 25 | alibi = torch.zeros(max_seq_len, max_seq_len) 26 | for i in range(max_seq_len): 27 | alibi[i, :i+1] = bias[-(i+1):] 28 | alibi = slopes.unsqueeze(1).unsqueeze(1) * alibi.unsqueeze(0) 29 | mask = (torch.triu(torch.ones(max_seq_len, max_seq_len)) == 1).transpose(0, 1) 30 | mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) 31 | mask = mask.unsqueeze(0) + alibi 32 | return mask 33 | 34 | 35 | def enc_dec_mask(device, T, S): 36 | mask = torch.ones(T, S) 37 | for i in range(T): 38 | mask[i, i] = 0 39 | return (mask==1).to(device=device) 40 | 41 | 42 | class PositionalEncoding(nn.Module): 43 | def __init__(self, d_model, max_len=600): 44 | super(PositionalEncoding, self).__init__() 45 | pe = torch.zeros(max_len, d_model) 46 | position = torch.arange(0, max_len).unsqueeze(1).float() 47 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)) 48 | pe[:, 0::2] = torch.sin(position * div_term) 49 | pe[:, 1::2] = torch.cos(position * div_term) 50 | pe = pe.unsqueeze(0) 51 | self.register_buffer('pe', pe) 52 | 53 | def forward(self, x): 54 | x = x + self.pe[:, :x.size(1)] 55 | return x 56 | 57 | 58 | class Audio2PoseModel(nn.Module): 59 | def __init__( 60 | self, 61 | config 62 | ): 63 | 64 | super().__init__() 65 | 66 | latent_dim = config['latent_dim'] 67 | model_path = config['model_path'] 68 | only_last_fetures = config['only_last_fetures'] 69 | from_pretrained = config['from_pretrained'] 70 | out_dim = config['out_dim'] 71 | 72 | self.out_dim = out_dim 73 | 74 | self._only_last_features = only_last_fetures 75 | 76 | self.audio_encoder_config = Wav2Vec2Config.from_pretrained(model_path, local_files_only=True) 77 | if from_pretrained: 78 | self.audio_encoder = Wav2Vec2Model.from_pretrained(model_path, local_files_only=True) 79 | else: 80 | self.audio_encoder = Wav2Vec2Model(self.audio_encoder_config) 81 | self.audio_encoder.feature_extractor._freeze_parameters() 82 | 83 | hidden_size = self.audio_encoder_config.hidden_size 84 | 85 | self.pose_map = nn.Linear(out_dim, latent_dim) 86 | self.in_fn = nn.Linear(hidden_size, latent_dim) 87 | 88 | self.PPE = PositionalEncoding(latent_dim) 89 | self.biased_mask = init_biased_mask(n_head = 8, max_seq_len = 600, period=1) 90 | decoder_layer = nn.TransformerDecoderLayer(d_model=latent_dim, nhead=8, dim_feedforward=2*latent_dim, batch_first=True) 91 | self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=8) 92 | self.pose_map_r = nn.Linear(latent_dim, out_dim) 93 | 94 | self.id_embed = nn.Embedding(100, latent_dim) # 100 ids 95 | 96 | 97 | def infer(self, input_value, seq_len, id_seed=None): 98 | embeddings = self.audio_encoder(input_value, seq_len=seq_len, output_hidden_states=True) 99 | 100 | if self._only_last_features: 101 | hidden_states = embeddings.last_hidden_state 102 | else: 103 | hidden_states = sum(embeddings.hidden_states) / len(embeddings.hidden_states) 104 | 105 | hidden_states = self.in_fn(hidden_states) 106 | 107 | id_embedding = self.id_embed(id_seed).unsqueeze(1) 108 | 109 | init_pose = torch.zeros([hidden_states.shape[0], 1, self.out_dim]).to(hidden_states.device) 110 | for i in range(seq_len): 111 | if i==0: 112 | pose_emb = self.pose_map(init_pose) 113 | pose_input = self.PPE(pose_emb) 114 | else: 115 | pose_input = self.PPE(pose_emb) 116 | 117 | pose_input = pose_input + id_embedding 118 | tgt_mask = 
self.biased_mask[:, :pose_input.shape[1], :pose_input.shape[1]].clone().detach().to(hidden_states.device) 119 | memory_mask = enc_dec_mask(hidden_states.device, pose_input.shape[1], hidden_states.shape[1]) 120 | pose_out = self.transformer_decoder(pose_input, hidden_states, tgt_mask=tgt_mask, memory_mask=memory_mask) 121 | pose_out = self.pose_map_r(pose_out) 122 | new_output = self.pose_map(pose_out[:,-1,:]).unsqueeze(1) 123 | pose_emb = torch.cat((pose_emb, new_output), 1) 124 | return pose_out 125 | 126 | -------------------------------------------------------------------------------- /src/audio_models/torch_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def get_mask_from_lengths(lengths, max_len=None): 6 | lengths = lengths.to(torch.long) 7 | if max_len is None: 8 | max_len = torch.max(lengths).item() 9 | 10 | ids = torch.arange(0, max_len).unsqueeze(0).expand(lengths.shape[0], -1).to(lengths.device) 11 | mask = ids < lengths.unsqueeze(1).expand(-1, max_len) 12 | 13 | return mask 14 | 15 | 16 | def linear_interpolation(features, seq_len): 17 | features = features.transpose(1, 2) 18 | output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear') 19 | return output_features.transpose(1, 2) 20 | 21 | 22 | if __name__ == "__main__": 23 | import numpy as np 24 | mask = ~get_mask_from_lengths(torch.from_numpy(np.array([4,6]))) 25 | import pdb; pdb.set_trace() -------------------------------------------------------------------------------- /src/audio_models/wav2vec2.py: -------------------------------------------------------------------------------- 1 | from transformers import Wav2Vec2Config, Wav2Vec2Model 2 | from transformers.modeling_outputs import BaseModelOutput 3 | 4 | from .torch_utils import linear_interpolation 5 | 6 | # the implementation of Wav2Vec2Model is borrowed from 7 | # https://github.com/huggingface/transformers/blob/HEAD/src/transformers/models/wav2vec2/modeling_wav2vec2.py 8 | # initialize our encoder with the pre-trained wav2vec 2.0 weights. 
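A key adaptation in the subclass below is that the CNN features are resampled to the requested seq_len before the transformer encoder runs, so the audio-to-mesh and audio-to-pose models above receive one feature vector per output frame. The helper doing that lives in torch_utils.py above and boils down to the following (shapes are illustrative; wav2vec 2.0 emits roughly 50 feature frames per second of 16 kHz audio):

import torch
import torch.nn.functional as F


def linear_interpolation(features, seq_len):
    # (batch, time, channels) -> (batch, seq_len, channels)
    features = features.transpose(1, 2)
    features = F.interpolate(features, size=seq_len, align_corners=True, mode="linear")
    return features.transpose(1, 2)


audio_feats = torch.randn(1, 249, 768)                        # ~5 s of wav2vec features
frame_feats = linear_interpolation(audio_feats, seq_len=150)  # 5 s at 30 fps
print(frame_feats.shape)                                      # torch.Size([1, 150, 768])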
9 | class Wav2Vec2Model(Wav2Vec2Model): 10 | def __init__(self, config: Wav2Vec2Config): 11 | super().__init__(config) 12 | 13 | def forward( 14 | self, 15 | input_values, 16 | seq_len, 17 | attention_mask=None, 18 | mask_time_indices=None, 19 | output_attentions=None, 20 | output_hidden_states=None, 21 | return_dict=None, 22 | ): 23 | self.config.output_attentions = True 24 | 25 | output_hidden_states = ( 26 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 27 | ) 28 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 29 | 30 | extract_features = self.feature_extractor(input_values) 31 | extract_features = extract_features.transpose(1, 2) 32 | extract_features = linear_interpolation(extract_features, seq_len=seq_len) 33 | 34 | if attention_mask is not None: 35 | # compute reduced attention_mask corresponding to feature vectors 36 | attention_mask = self._get_feature_vector_attention_mask( 37 | extract_features.shape[1], attention_mask, add_adapter=False 38 | ) 39 | 40 | hidden_states, extract_features = self.feature_projection(extract_features) 41 | hidden_states = self._mask_hidden_states( 42 | hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask 43 | ) 44 | 45 | encoder_outputs = self.encoder( 46 | hidden_states, 47 | attention_mask=attention_mask, 48 | output_attentions=output_attentions, 49 | output_hidden_states=output_hidden_states, 50 | return_dict=return_dict, 51 | ) 52 | 53 | hidden_states = encoder_outputs[0] 54 | 55 | if self.adapter is not None: 56 | hidden_states = self.adapter(hidden_states) 57 | 58 | if not return_dict: 59 | return (hidden_states, ) + encoder_outputs[1:] 60 | return BaseModelOutput( 61 | last_hidden_state=hidden_states, 62 | hidden_states=encoder_outputs.hidden_states, 63 | attentions=encoder_outputs.attentions, 64 | ) 65 | 66 | 67 | def feature_extract( 68 | self, 69 | input_values, 70 | seq_len, 71 | ): 72 | extract_features = self.feature_extractor(input_values) 73 | extract_features = extract_features.transpose(1, 2) 74 | extract_features = linear_interpolation(extract_features, seq_len=seq_len) 75 | 76 | return extract_features 77 | 78 | def encode( 79 | self, 80 | extract_features, 81 | attention_mask=None, 82 | mask_time_indices=None, 83 | output_attentions=None, 84 | output_hidden_states=None, 85 | return_dict=None, 86 | ): 87 | self.config.output_attentions = True 88 | 89 | output_hidden_states = ( 90 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 91 | ) 92 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 93 | 94 | if attention_mask is not None: 95 | # compute reduced attention_mask corresponding to feature vectors 96 | attention_mask = self._get_feature_vector_attention_mask( 97 | extract_features.shape[1], attention_mask, add_adapter=False 98 | ) 99 | 100 | 101 | hidden_states, extract_features = self.feature_projection(extract_features) 102 | hidden_states = self._mask_hidden_states( 103 | hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask 104 | ) 105 | 106 | encoder_outputs = self.encoder( 107 | hidden_states, 108 | attention_mask=attention_mask, 109 | output_attentions=output_attentions, 110 | output_hidden_states=output_hidden_states, 111 | return_dict=return_dict, 112 | ) 113 | 114 | hidden_states = encoder_outputs[0] 115 | 116 | if self.adapter is not None: 117 | hidden_states = self.adapter(hidden_states) 118 | 
119 | if not return_dict: 120 | return (hidden_states, ) + encoder_outputs[1:] 121 | return BaseModelOutput( 122 | last_hidden_state=hidden_states, 123 | hidden_states=encoder_outputs.hidden_states, 124 | attentions=encoder_outputs.attentions, 125 | ) 126 | -------------------------------------------------------------------------------- /src/dataset/dataset_face.py: -------------------------------------------------------------------------------- 1 | import os, io, csv, math, random, pdb 2 | import cv2 3 | import numpy as np 4 | import json 5 | from PIL import Image 6 | from einops import rearrange 7 | 8 | import torch 9 | import torchvision.transforms as transforms 10 | from torch.utils.data.dataset import Dataset 11 | from transformers import CLIPImageProcessor 12 | import torch.distributed as dist 13 | 14 | 15 | from src.utils.draw_util import FaceMeshVisualizer 16 | 17 | def zero_rank_print(s): 18 | if (not dist.is_initialized()) and (dist.is_initialized() and dist.get_rank() == 0): print("### " + s) 19 | 20 | 21 | 22 | class FaceDatasetValid(Dataset): 23 | def __init__( 24 | self, 25 | json_path, 26 | extra_json_path=None, 27 | sample_size=[512, 512], sample_stride=4, sample_n_frames=16, 28 | is_image=False, 29 | sample_stride_aug=False 30 | ): 31 | zero_rank_print(f"loading annotations from {json_path} ...") 32 | self.data_dic_name_list, self.data_dic = self.get_data(json_path, extra_json_path) 33 | 34 | self.length = len(self.data_dic_name_list) 35 | zero_rank_print(f"data scale: {self.length}") 36 | 37 | self.sample_stride = sample_stride 38 | self.sample_n_frames = sample_n_frames 39 | 40 | self.sample_stride_aug = sample_stride_aug 41 | 42 | self.sample_size = sample_size 43 | self.resize = transforms.Resize((sample_size[0], sample_size[1])) 44 | 45 | 46 | self.pixel_transforms = transforms.Compose([ 47 | transforms.Resize([sample_size[1], sample_size[0]]), 48 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 49 | ]) 50 | 51 | self.visualizer = FaceMeshVisualizer(forehead_edge=False) 52 | self.clip_image_processor = CLIPImageProcessor() 53 | self.is_image = is_image 54 | 55 | def get_data(self, json_name, extra_json_name, augment_num=1): 56 | zero_rank_print(f"start loading data: {json_name}") 57 | with open(json_name,'r') as f: 58 | data_dic = json.load(f) 59 | 60 | data_dic_name_list = [] 61 | for augment_index in range(augment_num): 62 | for video_name in data_dic.keys(): 63 | data_dic_name_list.append(video_name) 64 | 65 | invalid_video_name_list = [] 66 | for video_name in data_dic_name_list: 67 | video_clip_num = len(data_dic[video_name]['clip_data_list']) 68 | if video_clip_num < 1: 69 | invalid_video_name_list.append(video_name) 70 | for name in invalid_video_name_list: 71 | data_dic_name_list.remove(name) 72 | 73 | 74 | if extra_json_name is not None: 75 | zero_rank_print(f"start loading data: {extra_json_name}") 76 | with open(extra_json_name,'r') as f: 77 | extra_data_dic = json.load(f) 78 | data_dic.update(extra_data_dic) 79 | for augment_index in range(3*augment_num): 80 | for video_name in extra_data_dic.keys(): 81 | data_dic_name_list.append(video_name) 82 | random.shuffle(data_dic_name_list) 83 | zero_rank_print("finish loading") 84 | return data_dic_name_list, data_dic 85 | 86 | def __len__(self): 87 | return len(self.data_dic_name_list) 88 | 89 | def get_batch_wo_pose(self, index): 90 | video_name = self.data_dic_name_list[index] 91 | video_clip_num = len(self.data_dic[video_name]['clip_data_list']) 92 | 93 | source_anchor = 
random.sample(range(video_clip_num), 1)[0] 94 | source_image_path_list = self.data_dic[video_name]['clip_data_list'][source_anchor]['frame_path_list'] 95 | source_mesh2d_path_list = self.data_dic[video_name]['clip_data_list'][source_anchor]['lmks_list'] 96 | 97 | video_length = len(source_image_path_list) 98 | 99 | if self.sample_stride_aug: 100 | tmp_sample_stride = self.sample_stride if random.random() > 0.5 else 4 101 | else: 102 | tmp_sample_stride = self.sample_stride 103 | 104 | if not self.is_image: 105 | clip_length = min(video_length, (self.sample_n_frames - 1) * tmp_sample_stride + 1) 106 | start_idx = random.randint(0, video_length - clip_length) 107 | batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int) 108 | else: 109 | batch_index = [random.randint(0, video_length - 1)] 110 | 111 | ref_img_idx = random.randint(0, video_length - 1) 112 | 113 | ref_img = cv2.imread(source_image_path_list[ref_img_idx]) 114 | ref_img = cv2.cvtColor(ref_img, cv2.COLOR_BGR2RGB) 115 | ref_img = self.contrast_normalization(ref_img) 116 | 117 | ref_mesh2d_clip = np.load(source_mesh2d_path_list[ref_img_idx]).astype(float) 118 | ref_pose_image = self.visualizer.draw_landmarks(self.sample_size, ref_mesh2d_clip, normed=True) 119 | 120 | images = [cv2.imread(source_image_path_list[idx]) for idx in batch_index] 121 | images = [cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB) for bgr_image in images] 122 | image_np = np.array([self.contrast_normalization(img) for img in images]) 123 | 124 | pixel_values = torch.from_numpy(image_np).permute(0, 3, 1, 2).contiguous() 125 | pixel_values = pixel_values / 255. 126 | 127 | mesh2d_clip = np.array([np.load(source_mesh2d_path_list[idx]).astype(float) for idx in batch_index]) 128 | 129 | pixel_values_pose = [] 130 | for frame_id in range(mesh2d_clip.shape[0]): 131 | normed_mesh2d = mesh2d_clip[frame_id] 132 | 133 | pose_image = self.visualizer.draw_landmarks(self.sample_size, normed_mesh2d, normed=True) 134 | pixel_values_pose.append(pose_image) 135 | pixel_values_pose = np.array(pixel_values_pose) 136 | 137 | if self.is_image: 138 | pixel_values = pixel_values[0] 139 | pixel_values_pose = pixel_values_pose[0] 140 | image_np = image_np[0] 141 | 142 | return ref_img, pixel_values_pose, image_np, ref_pose_image 143 | 144 | def contrast_normalization(self, image, lower_bound=0, upper_bound=255): 145 | # convert input image to float32 146 | image = image.astype(np.float32) 147 | 148 | # normalize the image 149 | normalized_image = image * (upper_bound - lower_bound) / 255 + lower_bound 150 | 151 | # convert to uint8 152 | normalized_image = normalized_image.astype(np.uint8) 153 | 154 | return normalized_image 155 | 156 | def __getitem__(self, idx): 157 | ref_img, pixel_values_pose, tar_gt, pixel_values_ref_pose = self.get_batch_wo_pose(idx) 158 | 159 | sample = dict( 160 | pixel_values_pose=pixel_values_pose, 161 | ref_img=ref_img, 162 | tar_gt=tar_gt, 163 | pixel_values_ref_pose=pixel_values_ref_pose, 164 | ) 165 | 166 | return sample 167 | 168 | 169 | 170 | class FaceDataset(Dataset): 171 | def __init__( 172 | self, 173 | json_path, 174 | extra_json_path=None, 175 | sample_size=[512, 512], sample_stride=4, sample_n_frames=16, 176 | is_image=False, 177 | sample_stride_aug=False 178 | ): 179 | zero_rank_print(f"loading annotations from {json_path} ...") 180 | self.data_dic_name_list, self.data_dic = self.get_data(json_path, extra_json_path) 181 | 182 | self.length = len(self.data_dic_name_list) 183 | zero_rank_print(f"data scale: 
{self.length}") 184 | 185 | self.sample_stride = sample_stride 186 | self.sample_n_frames = sample_n_frames 187 | 188 | self.sample_stride_aug = sample_stride_aug 189 | 190 | self.sample_size = sample_size 191 | self.resize = transforms.Resize((sample_size[0], sample_size[1])) 192 | 193 | 194 | self.pixel_transforms = transforms.Compose([ 195 | transforms.Resize([sample_size[1], sample_size[0]]), 196 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 197 | ]) 198 | 199 | self.visualizer = FaceMeshVisualizer(forehead_edge=False) 200 | self.clip_image_processor = CLIPImageProcessor() 201 | self.is_image = is_image 202 | 203 | def get_data(self, json_name, extra_json_name, augment_num=1): 204 | zero_rank_print(f"start loading data: {json_name}") 205 | with open(json_name,'r') as f: 206 | data_dic = json.load(f) 207 | 208 | data_dic_name_list = [] 209 | for augment_index in range(augment_num): 210 | for video_name in data_dic.keys(): 211 | data_dic_name_list.append(video_name) 212 | 213 | invalid_video_name_list = [] 214 | for video_name in data_dic_name_list: 215 | video_clip_num = len(data_dic[video_name]['clip_data_list']) 216 | if video_clip_num < 1: 217 | invalid_video_name_list.append(video_name) 218 | for name in invalid_video_name_list: 219 | data_dic_name_list.remove(name) 220 | 221 | 222 | if extra_json_name is not None: 223 | zero_rank_print(f"start loading data: {extra_json_name}") 224 | with open(extra_json_name,'r') as f: 225 | extra_data_dic = json.load(f) 226 | data_dic.update(extra_data_dic) 227 | for augment_index in range(3*augment_num): 228 | for video_name in extra_data_dic.keys(): 229 | data_dic_name_list.append(video_name) 230 | random.shuffle(data_dic_name_list) 231 | zero_rank_print("finish loading") 232 | return data_dic_name_list, data_dic 233 | 234 | def __len__(self): 235 | return len(self.data_dic_name_list) 236 | 237 | 238 | def get_batch_wo_pose(self, index): 239 | video_name = self.data_dic_name_list[index] 240 | video_clip_num = len(self.data_dic[video_name]['clip_data_list']) 241 | 242 | source_anchor = random.sample(range(video_clip_num), 1)[0] 243 | source_image_path_list = self.data_dic[video_name]['clip_data_list'][source_anchor]['frame_path_list'] 244 | source_mesh2d_path_list = self.data_dic[video_name]['clip_data_list'][source_anchor]['lmks_list'] 245 | 246 | video_length = len(source_image_path_list) 247 | 248 | if self.sample_stride_aug: 249 | tmp_sample_stride = self.sample_stride if random.random() > 0.5 else 4 250 | else: 251 | tmp_sample_stride = self.sample_stride 252 | 253 | if not self.is_image: 254 | clip_length = min(video_length, (self.sample_n_frames - 1) * tmp_sample_stride + 1) 255 | start_idx = random.randint(0, video_length - clip_length) 256 | batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int) 257 | else: 258 | batch_index = [random.randint(0, video_length - 1)] 259 | 260 | ref_img_idx = random.randint(0, video_length - 1) 261 | ref_img = cv2.imread(source_image_path_list[ref_img_idx]) 262 | ref_img = cv2.cvtColor(ref_img, cv2.COLOR_BGR2RGB) 263 | ref_img = self.contrast_normalization(ref_img) 264 | ref_img_pil = Image.fromarray(ref_img) 265 | 266 | clip_ref_image = self.clip_image_processor(images=ref_img_pil, return_tensors="pt").pixel_values 267 | 268 | pixel_values_ref_img = torch.from_numpy(ref_img).permute(2, 0, 1).contiguous() 269 | pixel_values_ref_img = pixel_values_ref_img / 255. 
270 | 271 | ref_mesh2d_clip = np.load(source_mesh2d_path_list[ref_img_idx]).astype(float) 272 | ref_pose_image = self.visualizer.draw_landmarks(self.sample_size, ref_mesh2d_clip, normed=True) 273 | pixel_values_ref_pose = torch.from_numpy(ref_pose_image).permute(2, 0, 1).contiguous() 274 | pixel_values_ref_pose = pixel_values_ref_pose / 255. 275 | 276 | images = [cv2.imread(source_image_path_list[idx]) for idx in batch_index] 277 | images = [cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB) for bgr_image in images] 278 | image_np = np.array([self.contrast_normalization(img) for img in images]) 279 | 280 | pixel_values = torch.from_numpy(image_np).permute(0, 3, 1, 2).contiguous() 281 | pixel_values = pixel_values / 255. 282 | 283 | mesh2d_clip = np.array([np.load(source_mesh2d_path_list[idx]).astype(float) for idx in batch_index]) 284 | 285 | pixel_values_pose = [] 286 | for frame_id in range(mesh2d_clip.shape[0]): 287 | normed_mesh2d = mesh2d_clip[frame_id] 288 | 289 | pose_image = self.visualizer.draw_landmarks(self.sample_size, normed_mesh2d, normed=True) 290 | 291 | pixel_values_pose.append(pose_image) 292 | 293 | pixel_values_pose = np.array(pixel_values_pose) 294 | pixel_values_pose = torch.from_numpy(pixel_values_pose).permute(0, 3, 1, 2).contiguous() 295 | pixel_values_pose = pixel_values_pose / 255. 296 | 297 | if self.is_image: 298 | pixel_values = pixel_values[0] 299 | pixel_values_pose = pixel_values_pose[0] 300 | 301 | return pixel_values, pixel_values_pose, clip_ref_image, pixel_values_ref_img, pixel_values_ref_pose 302 | 303 | def contrast_normalization(self, image, lower_bound=0, upper_bound=255): 304 | image = image.astype(np.float32) 305 | normalized_image = image * (upper_bound - lower_bound) / 255 + lower_bound 306 | normalized_image = normalized_image.astype(np.uint8) 307 | 308 | return normalized_image 309 | 310 | def __getitem__(self, idx): 311 | pixel_values, pixel_values_pose, clip_ref_image, pixel_values_ref_img, pixel_values_ref_pose = self.get_batch_wo_pose(idx) 312 | 313 | pixel_values = self.pixel_transforms(pixel_values) 314 | pixel_values_pose = self.pixel_transforms(pixel_values_pose) 315 | 316 | pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0) 317 | pixel_values_ref_img = self.pixel_transforms(pixel_values_ref_img) 318 | pixel_values_ref_img = pixel_values_ref_img.squeeze(0) 319 | 320 | pixel_values_ref_pose = pixel_values_ref_pose.unsqueeze(0) 321 | pixel_values_ref_pose = self.pixel_transforms(pixel_values_ref_pose) 322 | pixel_values_ref_pose = pixel_values_ref_pose.squeeze(0) 323 | 324 | drop_image_embeds = 1 if random.random() < 0.1 else 0 325 | 326 | sample = dict( 327 | pixel_values=pixel_values, 328 | pixel_values_pose=pixel_values_pose, 329 | clip_ref_image=clip_ref_image, 330 | pixel_values_ref_img=pixel_values_ref_img, 331 | drop_image_embeds=drop_image_embeds, 332 | pixel_values_ref_pose=pixel_values_ref_pose, 333 | ) 334 | 335 | return sample 336 | 337 | def collate_fn(data): 338 | pixel_values = torch.stack([example["pixel_values"] for example in data]) 339 | pixel_values_pose = torch.stack([example["pixel_values_pose"] for example in data]) 340 | clip_ref_image = torch.cat([example["clip_ref_image"] for example in data]) 341 | pixel_values_ref_img = torch.stack([example["pixel_values_ref_img"] for example in data]) 342 | drop_image_embeds = [example["drop_image_embeds"] for example in data] 343 | drop_image_embeds = torch.Tensor(drop_image_embeds) 344 | pixel_values_ref_pose = torch.stack([example["pixel_values_ref_pose"] for example in 
data]) 345 | 346 | return { 347 | "pixel_values": pixel_values, 348 | "pixel_values_pose": pixel_values_pose, 349 | "clip_ref_image": clip_ref_image, 350 | "pixel_values_ref_img": pixel_values_ref_img, 351 | "drop_image_embeds": drop_image_embeds, 352 | "pixel_values_ref_pose": pixel_values_ref_pose, 353 | } 354 | 355 | -------------------------------------------------------------------------------- /src/models/motion_module.py: -------------------------------------------------------------------------------- 1 | # Adapt from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py 2 | import math 3 | from dataclasses import dataclass 4 | from typing import Callable, Optional 5 | 6 | import torch 7 | from diffusers.models.attention import FeedForward 8 | from diffusers.models.attention_processor import Attention, AttnProcessor 9 | from diffusers.utils import BaseOutput 10 | from diffusers.utils.import_utils import is_xformers_available 11 | from einops import rearrange, repeat 12 | from torch import nn 13 | 14 | 15 | def zero_module(module): 16 | # Zero out the parameters of a module and return it. 17 | for p in module.parameters(): 18 | p.detach().zero_() 19 | return module 20 | 21 | 22 | @dataclass 23 | class TemporalTransformer3DModelOutput(BaseOutput): 24 | sample: torch.FloatTensor 25 | 26 | 27 | if is_xformers_available(): 28 | import xformers 29 | import xformers.ops 30 | else: 31 | xformers = None 32 | 33 | 34 | def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict): 35 | if motion_module_type == "Vanilla": 36 | return VanillaTemporalModule( 37 | in_channels=in_channels, 38 | **motion_module_kwargs, 39 | ) 40 | else: 41 | raise ValueError 42 | 43 | 44 | class VanillaTemporalModule(nn.Module): 45 | def __init__( 46 | self, 47 | in_channels, 48 | num_attention_heads=8, 49 | num_transformer_block=2, 50 | attention_block_types=("Temporal_Self", "Temporal_Self"), 51 | cross_frame_attention_mode=None, 52 | temporal_position_encoding=False, 53 | temporal_position_encoding_max_len=24, 54 | temporal_attention_dim_div=1, 55 | zero_initialize=True, 56 | ): 57 | super().__init__() 58 | 59 | self.temporal_transformer = TemporalTransformer3DModel( 60 | in_channels=in_channels, 61 | num_attention_heads=num_attention_heads, 62 | attention_head_dim=in_channels 63 | // num_attention_heads 64 | // temporal_attention_dim_div, 65 | num_layers=num_transformer_block, 66 | attention_block_types=attention_block_types, 67 | cross_frame_attention_mode=cross_frame_attention_mode, 68 | temporal_position_encoding=temporal_position_encoding, 69 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 70 | ) 71 | 72 | if zero_initialize: 73 | self.temporal_transformer.proj_out = zero_module( 74 | self.temporal_transformer.proj_out 75 | ) 76 | 77 | def forward( 78 | self, 79 | input_tensor, 80 | temb, 81 | encoder_hidden_states, 82 | attention_mask=None, 83 | anchor_frame_idx=None, 84 | ): 85 | hidden_states = input_tensor 86 | hidden_states = self.temporal_transformer( 87 | hidden_states, encoder_hidden_states, attention_mask 88 | ) 89 | 90 | output = hidden_states 91 | return output 92 | 93 | 94 | class TemporalTransformer3DModel(nn.Module): 95 | def __init__( 96 | self, 97 | in_channels, 98 | num_attention_heads, 99 | attention_head_dim, 100 | num_layers, 101 | attention_block_types=( 102 | "Temporal_Self", 103 | "Temporal_Self", 104 | ), 105 | dropout=0.0, 106 | norm_num_groups=32, 107 | cross_attention_dim=768, 108 | 
activation_fn="geglu", 109 | attention_bias=False, 110 | upcast_attention=False, 111 | cross_frame_attention_mode=None, 112 | temporal_position_encoding=False, 113 | temporal_position_encoding_max_len=24, 114 | ): 115 | super().__init__() 116 | 117 | inner_dim = num_attention_heads * attention_head_dim 118 | 119 | self.norm = torch.nn.GroupNorm( 120 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True 121 | ) 122 | self.proj_in = nn.Linear(in_channels, inner_dim) 123 | 124 | self.transformer_blocks = nn.ModuleList( 125 | [ 126 | TemporalTransformerBlock( 127 | dim=inner_dim, 128 | num_attention_heads=num_attention_heads, 129 | attention_head_dim=attention_head_dim, 130 | attention_block_types=attention_block_types, 131 | dropout=dropout, 132 | norm_num_groups=norm_num_groups, 133 | cross_attention_dim=cross_attention_dim, 134 | activation_fn=activation_fn, 135 | attention_bias=attention_bias, 136 | upcast_attention=upcast_attention, 137 | cross_frame_attention_mode=cross_frame_attention_mode, 138 | temporal_position_encoding=temporal_position_encoding, 139 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 140 | ) 141 | for d in range(num_layers) 142 | ] 143 | ) 144 | self.proj_out = nn.Linear(inner_dim, in_channels) 145 | 146 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): 147 | assert ( 148 | hidden_states.dim() == 5 149 | ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 150 | video_length = hidden_states.shape[2] 151 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 152 | 153 | batch, channel, height, weight = hidden_states.shape 154 | residual = hidden_states 155 | 156 | hidden_states = self.norm(hidden_states) 157 | inner_dim = hidden_states.shape[1] 158 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( 159 | batch, height * weight, inner_dim 160 | ) 161 | hidden_states = self.proj_in(hidden_states) 162 | 163 | # Transformer Blocks 164 | for block in self.transformer_blocks: 165 | hidden_states = block( 166 | hidden_states, 167 | encoder_hidden_states=encoder_hidden_states, 168 | video_length=video_length, 169 | ) 170 | 171 | # output 172 | hidden_states = self.proj_out(hidden_states) 173 | hidden_states = ( 174 | hidden_states.reshape(batch, height, weight, inner_dim) 175 | .permute(0, 3, 1, 2) 176 | .contiguous() 177 | ) 178 | 179 | output = hidden_states + residual 180 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 181 | 182 | return output 183 | 184 | 185 | class TemporalTransformerBlock(nn.Module): 186 | def __init__( 187 | self, 188 | dim, 189 | num_attention_heads, 190 | attention_head_dim, 191 | attention_block_types=( 192 | "Temporal_Self", 193 | "Temporal_Self", 194 | ), 195 | dropout=0.0, 196 | norm_num_groups=32, 197 | cross_attention_dim=768, 198 | activation_fn="geglu", 199 | attention_bias=False, 200 | upcast_attention=False, 201 | cross_frame_attention_mode=None, 202 | temporal_position_encoding=False, 203 | temporal_position_encoding_max_len=24, 204 | ): 205 | super().__init__() 206 | 207 | attention_blocks = [] 208 | norms = [] 209 | 210 | for block_name in attention_block_types: 211 | attention_blocks.append( 212 | VersatileAttention( 213 | attention_mode=block_name.split("_")[0], 214 | cross_attention_dim=cross_attention_dim 215 | if block_name.endswith("_Cross") 216 | else None, 217 | query_dim=dim, 218 | heads=num_attention_heads, 219 | dim_head=attention_head_dim, 220 | dropout=dropout, 221 
| bias=attention_bias, 222 | upcast_attention=upcast_attention, 223 | cross_frame_attention_mode=cross_frame_attention_mode, 224 | temporal_position_encoding=temporal_position_encoding, 225 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 226 | ) 227 | ) 228 | norms.append(nn.LayerNorm(dim)) 229 | 230 | self.attention_blocks = nn.ModuleList(attention_blocks) 231 | self.norms = nn.ModuleList(norms) 232 | 233 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) 234 | self.ff_norm = nn.LayerNorm(dim) 235 | 236 | def forward( 237 | self, 238 | hidden_states, 239 | encoder_hidden_states=None, 240 | attention_mask=None, 241 | video_length=None, 242 | ): 243 | for attention_block, norm in zip(self.attention_blocks, self.norms): 244 | norm_hidden_states = norm(hidden_states) 245 | hidden_states = ( 246 | attention_block( 247 | norm_hidden_states, 248 | encoder_hidden_states=encoder_hidden_states 249 | if attention_block.is_cross_attention 250 | else None, 251 | video_length=video_length, 252 | ) 253 | + hidden_states 254 | ) 255 | 256 | hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states 257 | 258 | output = hidden_states 259 | return output 260 | 261 | 262 | class PositionalEncoding(nn.Module): 263 | def __init__(self, d_model, dropout=0.0, max_len=24): 264 | super().__init__() 265 | self.dropout = nn.Dropout(p=dropout) 266 | position = torch.arange(max_len).unsqueeze(1) 267 | div_term = torch.exp( 268 | torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model) 269 | ) 270 | pe = torch.zeros(1, max_len, d_model) 271 | pe[0, :, 0::2] = torch.sin(position * div_term) 272 | pe[0, :, 1::2] = torch.cos(position * div_term) 273 | self.register_buffer("pe", pe) 274 | 275 | def forward(self, x): 276 | x = x + self.pe[:, : x.size(1)] 277 | return self.dropout(x) 278 | 279 | 280 | class VersatileAttention(Attention): 281 | def __init__( 282 | self, 283 | attention_mode=None, 284 | cross_frame_attention_mode=None, 285 | temporal_position_encoding=False, 286 | temporal_position_encoding_max_len=24, 287 | *args, 288 | **kwargs, 289 | ): 290 | super().__init__(*args, **kwargs) 291 | assert attention_mode == "Temporal" 292 | 293 | self.attention_mode = attention_mode 294 | self.is_cross_attention = kwargs["cross_attention_dim"] is not None 295 | 296 | self.pos_encoder = ( 297 | PositionalEncoding( 298 | kwargs["query_dim"], 299 | dropout=0.0, 300 | max_len=temporal_position_encoding_max_len, 301 | ) 302 | if (temporal_position_encoding and attention_mode == "Temporal") 303 | else None 304 | ) 305 | 306 | def extra_repr(self): 307 | return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}" 308 | 309 | def set_use_memory_efficient_attention_xformers( 310 | self, 311 | use_memory_efficient_attention_xformers: bool, 312 | attention_op: Optional[Callable] = None, 313 | ): 314 | if use_memory_efficient_attention_xformers: 315 | if not is_xformers_available(): 316 | raise ModuleNotFoundError( 317 | ( 318 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install" 319 | " xformers" 320 | ), 321 | name="xformers", 322 | ) 323 | elif not torch.cuda.is_available(): 324 | raise ValueError( 325 | "torch.cuda.is_available() should be True but is False. 
xformers' memory efficient attention is" 326 | " only available for GPU " 327 | ) 328 | else: 329 | try: 330 | # Make sure we can run the memory efficient attention 331 | _ = xformers.ops.memory_efficient_attention( 332 | torch.randn((1, 2, 40), device="cuda"), 333 | torch.randn((1, 2, 40), device="cuda"), 334 | torch.randn((1, 2, 40), device="cuda"), 335 | ) 336 | except Exception as e: 337 | raise e 338 | 339 | # XFormersAttnProcessor corrupts video generation and work with Pytorch 1.13. 340 | # Pytorch 2.0.1 AttnProcessor works the same as XFormersAttnProcessor in Pytorch 1.13. 341 | # You don't need XFormersAttnProcessor here. 342 | # processor = XFormersAttnProcessor( 343 | # attention_op=attention_op, 344 | # ) 345 | processor = AttnProcessor() 346 | else: 347 | processor = AttnProcessor() 348 | 349 | self.set_processor(processor) 350 | 351 | def forward( 352 | self, 353 | hidden_states, 354 | encoder_hidden_states=None, 355 | attention_mask=None, 356 | video_length=None, 357 | **cross_attention_kwargs, 358 | ): 359 | if self.attention_mode == "Temporal": 360 | d = hidden_states.shape[1] # d means HxW 361 | hidden_states = rearrange( 362 | hidden_states, "(b f) d c -> (b d) f c", f=video_length 363 | ) 364 | 365 | if self.pos_encoder is not None: 366 | hidden_states = self.pos_encoder(hidden_states) 367 | 368 | encoder_hidden_states = ( 369 | repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) 370 | if encoder_hidden_states is not None 371 | else encoder_hidden_states 372 | ) 373 | 374 | else: 375 | raise NotImplementedError 376 | 377 | hidden_states = self.processor( 378 | self, 379 | hidden_states, 380 | encoder_hidden_states=encoder_hidden_states, 381 | attention_mask=attention_mask, 382 | **cross_attention_kwargs, 383 | ) 384 | 385 | if self.attention_mode == "Temporal": 386 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) 387 | 388 | return hidden_states 389 | -------------------------------------------------------------------------------- /src/models/mutual_self_attention.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/models/mutual_self_attention.py 2 | from typing import Any, Dict, Optional 3 | 4 | import torch 5 | from einops import rearrange 6 | 7 | from src.models.attention import TemporalBasicTransformerBlock 8 | 9 | from .attention import BasicTransformerBlock 10 | 11 | 12 | def torch_dfs(model: torch.nn.Module): 13 | result = [model] 14 | for child in model.children(): 15 | result += torch_dfs(child) 16 | return result 17 | 18 | 19 | class ReferenceAttentionControl: 20 | def __init__( 21 | self, 22 | unet, 23 | mode="write", 24 | do_classifier_free_guidance=False, 25 | attention_auto_machine_weight=float("inf"), 26 | gn_auto_machine_weight=1.0, 27 | style_fidelity=1.0, 28 | reference_attn=True, 29 | reference_adain=False, 30 | fusion_blocks="midup", 31 | batch_size=1, 32 | ) -> None: 33 | # 10. 
Modify self attention and group norm 34 | self.unet = unet 35 | assert mode in ["read", "write"] 36 | assert fusion_blocks in ["midup", "full"] 37 | self.reference_attn = reference_attn 38 | self.reference_adain = reference_adain 39 | self.fusion_blocks = fusion_blocks 40 | self.register_reference_hooks( 41 | mode, 42 | do_classifier_free_guidance, 43 | attention_auto_machine_weight, 44 | gn_auto_machine_weight, 45 | style_fidelity, 46 | reference_attn, 47 | reference_adain, 48 | fusion_blocks, 49 | batch_size=batch_size, 50 | ) 51 | 52 | def register_reference_hooks( 53 | self, 54 | mode, 55 | do_classifier_free_guidance, 56 | attention_auto_machine_weight, 57 | gn_auto_machine_weight, 58 | style_fidelity, 59 | reference_attn, 60 | reference_adain, 61 | dtype=torch.float16, 62 | batch_size=1, 63 | num_images_per_prompt=1, 64 | device=torch.device("cpu"), 65 | fusion_blocks="midup", 66 | ): 67 | MODE = mode 68 | do_classifier_free_guidance = do_classifier_free_guidance 69 | attention_auto_machine_weight = attention_auto_machine_weight 70 | gn_auto_machine_weight = gn_auto_machine_weight 71 | style_fidelity = style_fidelity 72 | reference_attn = reference_attn 73 | reference_adain = reference_adain 74 | fusion_blocks = fusion_blocks 75 | num_images_per_prompt = num_images_per_prompt 76 | dtype = dtype 77 | if do_classifier_free_guidance: 78 | uc_mask = ( 79 | torch.Tensor( 80 | [1] * batch_size * num_images_per_prompt * 16 81 | + [0] * batch_size * num_images_per_prompt * 16 82 | ) 83 | .to(device) 84 | .bool() 85 | ) 86 | else: 87 | uc_mask = ( 88 | torch.Tensor([0] * batch_size * num_images_per_prompt * 2) 89 | .to(device) 90 | .bool() 91 | ) 92 | 93 | def hacked_basic_transformer_inner_forward( 94 | self, 95 | hidden_states: torch.FloatTensor, 96 | attention_mask: Optional[torch.FloatTensor] = None, 97 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 98 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 99 | timestep: Optional[torch.LongTensor] = None, 100 | cross_attention_kwargs: Dict[str, Any] = None, 101 | class_labels: Optional[torch.LongTensor] = None, 102 | video_length=None, 103 | ): 104 | if self.use_ada_layer_norm: # False 105 | norm_hidden_states = self.norm1(hidden_states, timestep) 106 | elif self.use_ada_layer_norm_zero: 107 | ( 108 | norm_hidden_states, 109 | gate_msa, 110 | shift_mlp, 111 | scale_mlp, 112 | gate_mlp, 113 | ) = self.norm1( 114 | hidden_states, 115 | timestep, 116 | class_labels, 117 | hidden_dtype=hidden_states.dtype, 118 | ) 119 | else: 120 | norm_hidden_states = self.norm1(hidden_states) 121 | 122 | # 1. 
Self-Attention 123 | # self.only_cross_attention = False 124 | cross_attention_kwargs = ( 125 | cross_attention_kwargs if cross_attention_kwargs is not None else {} 126 | ) 127 | if self.only_cross_attention: 128 | attn_output = self.attn1( 129 | norm_hidden_states, 130 | encoder_hidden_states=encoder_hidden_states 131 | if self.only_cross_attention 132 | else None, 133 | attention_mask=attention_mask, 134 | **cross_attention_kwargs, 135 | ) 136 | else: 137 | if MODE == "write": 138 | self.bank.append(norm_hidden_states.clone()) 139 | attn_output = self.attn1( 140 | norm_hidden_states, 141 | encoder_hidden_states=encoder_hidden_states 142 | if self.only_cross_attention 143 | else None, 144 | attention_mask=attention_mask, 145 | **cross_attention_kwargs, 146 | ) 147 | if MODE == "read": 148 | bank_fea = [ 149 | rearrange( 150 | d.unsqueeze(1).repeat(1, video_length, 1, 1), 151 | "b t l c -> (b t) l c", 152 | ) 153 | for d in self.bank 154 | ] 155 | modify_norm_hidden_states = torch.cat( 156 | [norm_hidden_states] + bank_fea, dim=1 157 | ) 158 | hidden_states_uc = ( 159 | self.attn1( 160 | norm_hidden_states, 161 | encoder_hidden_states=modify_norm_hidden_states, 162 | attention_mask=attention_mask, 163 | ) 164 | + hidden_states 165 | ) 166 | if do_classifier_free_guidance: 167 | hidden_states_c = hidden_states_uc.clone() 168 | _uc_mask = uc_mask.clone() 169 | if hidden_states.shape[0] != _uc_mask.shape[0]: 170 | _uc_mask = ( 171 | torch.Tensor( 172 | [1] * (hidden_states.shape[0] // 2) 173 | + [0] * (hidden_states.shape[0] // 2) 174 | ) 175 | .to(device) 176 | .bool() 177 | ) 178 | hidden_states_c[_uc_mask] = ( 179 | self.attn1( 180 | norm_hidden_states[_uc_mask], 181 | encoder_hidden_states=norm_hidden_states[_uc_mask], 182 | attention_mask=attention_mask, 183 | ) 184 | + hidden_states[_uc_mask] 185 | ) 186 | hidden_states = hidden_states_c.clone() 187 | else: 188 | hidden_states = hidden_states_uc 189 | 190 | # self.bank.clear() 191 | if self.attn2 is not None: 192 | # Cross-Attention 193 | norm_hidden_states = ( 194 | self.norm2(hidden_states, timestep) 195 | if self.use_ada_layer_norm 196 | else self.norm2(hidden_states) 197 | ) 198 | hidden_states = ( 199 | self.attn2( 200 | norm_hidden_states, 201 | encoder_hidden_states=encoder_hidden_states, 202 | attention_mask=attention_mask, 203 | ) 204 | + hidden_states 205 | ) 206 | 207 | # Feed-forward 208 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states 209 | 210 | # Temporal-Attention 211 | if self.unet_use_temporal_attention: 212 | d = hidden_states.shape[1] 213 | hidden_states = rearrange( 214 | hidden_states, "(b f) d c -> (b d) f c", f=video_length 215 | ) 216 | norm_hidden_states = ( 217 | self.norm_temp(hidden_states, timestep) 218 | if self.use_ada_layer_norm 219 | else self.norm_temp(hidden_states) 220 | ) 221 | hidden_states = ( 222 | self.attn_temp(norm_hidden_states) + hidden_states 223 | ) 224 | hidden_states = rearrange( 225 | hidden_states, "(b d) f c -> (b f) d c", d=d 226 | ) 227 | 228 | return hidden_states 229 | 230 | if self.use_ada_layer_norm_zero: 231 | attn_output = gate_msa.unsqueeze(1) * attn_output 232 | hidden_states = attn_output + hidden_states 233 | 234 | if self.attn2 is not None: 235 | norm_hidden_states = ( 236 | self.norm2(hidden_states, timestep) 237 | if self.use_ada_layer_norm 238 | else self.norm2(hidden_states) 239 | ) 240 | 241 | # 2. 
Cross-Attention 242 | attn_output = self.attn2( 243 | norm_hidden_states, 244 | encoder_hidden_states=encoder_hidden_states, 245 | attention_mask=encoder_attention_mask, 246 | **cross_attention_kwargs, 247 | ) 248 | hidden_states = attn_output + hidden_states 249 | 250 | # 3. Feed-forward 251 | norm_hidden_states = self.norm3(hidden_states) 252 | 253 | if self.use_ada_layer_norm_zero: 254 | norm_hidden_states = ( 255 | norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 256 | ) 257 | 258 | ff_output = self.ff(norm_hidden_states) 259 | 260 | if self.use_ada_layer_norm_zero: 261 | ff_output = gate_mlp.unsqueeze(1) * ff_output 262 | 263 | hidden_states = ff_output + hidden_states 264 | 265 | return hidden_states 266 | 267 | if self.reference_attn: 268 | if self.fusion_blocks == "midup": 269 | attn_modules = [ 270 | module 271 | for module in ( 272 | torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks) 273 | ) 274 | if isinstance(module, BasicTransformerBlock) 275 | or isinstance(module, TemporalBasicTransformerBlock) 276 | ] 277 | elif self.fusion_blocks == "full": 278 | attn_modules = [ 279 | module 280 | for module in torch_dfs(self.unet) 281 | if isinstance(module, BasicTransformerBlock) 282 | or isinstance(module, TemporalBasicTransformerBlock) 283 | ] 284 | attn_modules = sorted( 285 | attn_modules, key=lambda x: -x.norm1.normalized_shape[0] 286 | ) 287 | 288 | for i, module in enumerate(attn_modules): 289 | module._original_inner_forward = module.forward 290 | if isinstance(module, BasicTransformerBlock): 291 | module.forward = hacked_basic_transformer_inner_forward.__get__( 292 | module, BasicTransformerBlock 293 | ) 294 | if isinstance(module, TemporalBasicTransformerBlock): 295 | module.forward = hacked_basic_transformer_inner_forward.__get__( 296 | module, TemporalBasicTransformerBlock 297 | ) 298 | 299 | module.bank = [] 300 | module.attn_weight = float(i) / float(len(attn_modules)) 301 | 302 | def update(self, writer, dtype=torch.float16): 303 | if self.reference_attn: 304 | if self.fusion_blocks == "midup": 305 | reader_attn_modules = [ 306 | module 307 | for module in ( 308 | torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks) 309 | ) 310 | if isinstance(module, TemporalBasicTransformerBlock) 311 | ] 312 | writer_attn_modules = [ 313 | module 314 | for module in ( 315 | torch_dfs(writer.unet.mid_block) 316 | + torch_dfs(writer.unet.up_blocks) 317 | ) 318 | if isinstance(module, BasicTransformerBlock) 319 | ] 320 | elif self.fusion_blocks == "full": 321 | reader_attn_modules = [ 322 | module 323 | for module in torch_dfs(self.unet) 324 | if isinstance(module, TemporalBasicTransformerBlock) 325 | ] 326 | writer_attn_modules = [ 327 | module 328 | for module in torch_dfs(writer.unet) 329 | if isinstance(module, BasicTransformerBlock) 330 | ] 331 | reader_attn_modules = sorted( 332 | reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0] 333 | ) 334 | writer_attn_modules = sorted( 335 | writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0] 336 | ) 337 | for r, w in zip(reader_attn_modules, writer_attn_modules): 338 | r.bank = [v.clone().to(dtype) for v in w.bank] 339 | # w.bank.clear() 340 | 341 | def clear(self): 342 | if self.reference_attn: 343 | if self.fusion_blocks == "midup": 344 | reader_attn_modules = [ 345 | module 346 | for module in ( 347 | torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks) 348 | ) 349 | if isinstance(module, BasicTransformerBlock) 350 | or isinstance(module, 
TemporalBasicTransformerBlock) 351 | ] 352 | elif self.fusion_blocks == "full": 353 | reader_attn_modules = [ 354 | module 355 | for module in torch_dfs(self.unet) 356 | if isinstance(module, BasicTransformerBlock) 357 | or isinstance(module, TemporalBasicTransformerBlock) 358 | ] 359 | reader_attn_modules = sorted( 360 | reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0] 361 | ) 362 | for r in reader_attn_modules: 363 | r.bank.clear() 364 | -------------------------------------------------------------------------------- /src/models/pose_guider.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.init as init 5 | from einops import rearrange 6 | import numpy as np 7 | from diffusers.models.modeling_utils import ModelMixin 8 | 9 | from typing import Any, Dict, Optional 10 | from src.models.attention import BasicTransformerBlock 11 | 12 | 13 | class PoseGuider(ModelMixin): 14 | def __init__(self, noise_latent_channels=320, use_ca=True): 15 | super(PoseGuider, self).__init__() 16 | 17 | self.use_ca = use_ca 18 | 19 | self.conv_layers = nn.Sequential( 20 | nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1), 21 | nn.BatchNorm2d(3), 22 | nn.ReLU(), 23 | nn.Conv2d(in_channels=3, out_channels=16, kernel_size=4, stride=2, padding=1), 24 | nn.BatchNorm2d(16), 25 | nn.ReLU(), 26 | 27 | nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1), 28 | nn.BatchNorm2d(16), 29 | nn.ReLU(), 30 | nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=2, padding=1), 31 | nn.BatchNorm2d(32), 32 | nn.ReLU(), 33 | 34 | nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1), 35 | nn.BatchNorm2d(32), 36 | nn.ReLU(), 37 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1), 38 | nn.BatchNorm2d(64), 39 | nn.ReLU(), 40 | 41 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1), 42 | nn.BatchNorm2d(64), 43 | nn.ReLU(), 44 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1), 45 | nn.BatchNorm2d(128), 46 | nn.ReLU() 47 | ) 48 | 49 | # Final projection layer 50 | self.final_proj = nn.Conv2d(in_channels=128, out_channels=noise_latent_channels, kernel_size=1) 51 | 52 | self.conv_layers_1 = nn.Sequential( 53 | nn.Conv2d(in_channels=noise_latent_channels, out_channels=noise_latent_channels, kernel_size=3, padding=1), 54 | nn.BatchNorm2d(noise_latent_channels), 55 | nn.ReLU(), 56 | nn.Conv2d(in_channels=noise_latent_channels, out_channels=noise_latent_channels, kernel_size=3, stride=2, padding=1), 57 | nn.BatchNorm2d(noise_latent_channels), 58 | nn.ReLU(), 59 | ) 60 | 61 | self.conv_layers_2 = nn.Sequential( 62 | nn.Conv2d(in_channels=noise_latent_channels, out_channels=noise_latent_channels, kernel_size=3, padding=1), 63 | nn.BatchNorm2d(noise_latent_channels), 64 | nn.ReLU(), 65 | nn.Conv2d(in_channels=noise_latent_channels, out_channels=noise_latent_channels*2, kernel_size=3, stride=2, padding=1), 66 | nn.BatchNorm2d(noise_latent_channels*2), 67 | nn.ReLU(), 68 | ) 69 | 70 | self.conv_layers_3 = nn.Sequential( 71 | nn.Conv2d(in_channels=noise_latent_channels*2, out_channels=noise_latent_channels*2, kernel_size=3, padding=1), 72 | nn.BatchNorm2d(noise_latent_channels*2), 73 | nn.ReLU(), 74 | nn.Conv2d(in_channels=noise_latent_channels*2, out_channels=noise_latent_channels*4, kernel_size=3, stride=2, padding=1), 75 | nn.BatchNorm2d(noise_latent_channels*4), 76 | nn.ReLU(), 77 | ) 78 
| 79 | self.conv_layers_4 = nn.Sequential( 80 | nn.Conv2d(in_channels=noise_latent_channels*4, out_channels=noise_latent_channels*4, kernel_size=3, padding=1), 81 | nn.BatchNorm2d(noise_latent_channels*4), 82 | nn.ReLU(), 83 | ) 84 | 85 | if self.use_ca: 86 | self.cross_attn1 = Transformer2DModel(in_channels=noise_latent_channels) 87 | self.cross_attn2 = Transformer2DModel(in_channels=noise_latent_channels*2) 88 | self.cross_attn3 = Transformer2DModel(in_channels=noise_latent_channels*4) 89 | self.cross_attn4 = Transformer2DModel(in_channels=noise_latent_channels*4) 90 | 91 | # Initialize layers 92 | self._initialize_weights() 93 | 94 | self.scale = nn.Parameter(torch.ones(1) * 2) 95 | 96 | # def _initialize_weights(self): 97 | # # Initialize weights with Gaussian distribution and zero out the final layer 98 | # for m in self.conv_layers: 99 | # if isinstance(m, nn.Conv2d): 100 | # init.normal_(m.weight, mean=0.0, std=0.02) 101 | # if m.bias is not None: 102 | # init.zeros_(m.bias) 103 | 104 | # init.zeros_(self.final_proj.weight) 105 | # if self.final_proj.bias is not None: 106 | # init.zeros_(self.final_proj.bias) 107 | 108 | def _initialize_weights(self): 109 | # Initialize weights with He initialization and zero out the biases 110 | conv_blocks = [self.conv_layers, self.conv_layers_1, self.conv_layers_2, self.conv_layers_3, self.conv_layers_4] 111 | for block_item in conv_blocks: 112 | for m in block_item: 113 | if isinstance(m, nn.Conv2d): 114 | n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels 115 | init.normal_(m.weight, mean=0.0, std=np.sqrt(2. / n)) 116 | if m.bias is not None: 117 | init.zeros_(m.bias) 118 | 119 | # For the final projection layer, initialize weights to zero (or you may choose to use He initialization here as well) 120 | init.zeros_(self.final_proj.weight) 121 | if self.final_proj.bias is not None: 122 | init.zeros_(self.final_proj.bias) 123 | 124 | def forward(self, x, ref_x): 125 | fea = [] 126 | b = x.shape[0] 127 | 128 | x = rearrange(x, "b c f h w -> (b f) c h w") 129 | x = self.conv_layers(x) 130 | x = self.final_proj(x) 131 | x = x * self.scale 132 | # x = rearrange(x, "(b f) c h w -> b c f h w", b=b) 133 | fea.append(rearrange(x, "(b f) c h w -> b c f h w", b=b)) 134 | 135 | x = self.conv_layers_1(x) 136 | if self.use_ca: 137 | ref_x = self.conv_layers(ref_x) 138 | ref_x = self.final_proj(ref_x) 139 | ref_x = ref_x * self.scale 140 | ref_x = self.conv_layers_1(ref_x) 141 | x = self.cross_attn1(x, ref_x) 142 | fea.append(rearrange(x, "(b f) c h w -> b c f h w", b=b)) 143 | 144 | x = self.conv_layers_2(x) 145 | if self.use_ca: 146 | ref_x = self.conv_layers_2(ref_x) 147 | x = self.cross_attn2(x, ref_x) 148 | fea.append(rearrange(x, "(b f) c h w -> b c f h w", b=b)) 149 | 150 | x = self.conv_layers_3(x) 151 | if self.use_ca: 152 | ref_x = self.conv_layers_3(ref_x) 153 | x = self.cross_attn3(x, ref_x) 154 | fea.append(rearrange(x, "(b f) c h w -> b c f h w", b=b)) 155 | 156 | x = self.conv_layers_4(x) 157 | if self.use_ca: 158 | ref_x = self.conv_layers_4(ref_x) 159 | x = self.cross_attn4(x, ref_x) 160 | fea.append(rearrange(x, "(b f) c h w -> b c f h w", b=b)) 161 | 162 | return fea 163 | 164 | @classmethod 165 | def from_pretrained(cls,pretrained_model_path): 166 | if not os.path.exists(pretrained_model_path): 167 | print(f"There is no model file in {pretrained_model_path}") 168 | print(f"loaded PoseGuider's pretrained weights from {pretrained_model_path} ...") 169 | 170 | state_dict = torch.load(pretrained_model_path, map_location="cpu") 171 | model = 
Hack_PoseGuider(noise_latent_channels=320) 172 | 173 | m, u = model.load_state_dict(state_dict, strict=True) 174 | # print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") 175 | params = [p.numel() for n, p in model.named_parameters()] 176 | print(f"### PoseGuider's Parameters: {sum(params) / 1e6} M") 177 | 178 | return model 179 | 180 | 181 | class Transformer2DModel(ModelMixin): 182 | _supports_gradient_checkpointing = True 183 | def __init__( 184 | self, 185 | num_attention_heads: int = 16, 186 | attention_head_dim: int = 88, 187 | in_channels: Optional[int] = None, 188 | num_layers: int = 1, 189 | dropout: float = 0.0, 190 | norm_num_groups: int = 32, 191 | cross_attention_dim: Optional[int] = None, 192 | attention_bias: bool = False, 193 | activation_fn: str = "geglu", 194 | num_embeds_ada_norm: Optional[int] = None, 195 | use_linear_projection: bool = False, 196 | only_cross_attention: bool = False, 197 | double_self_attention: bool = False, 198 | upcast_attention: bool = False, 199 | norm_type: str = "layer_norm", 200 | norm_elementwise_affine: bool = True, 201 | norm_eps: float = 1e-5, 202 | attention_type: str = "default", 203 | ): 204 | super().__init__() 205 | self.use_linear_projection = use_linear_projection 206 | self.num_attention_heads = num_attention_heads 207 | self.attention_head_dim = attention_head_dim 208 | inner_dim = num_attention_heads * attention_head_dim 209 | 210 | self.in_channels = in_channels 211 | 212 | self.norm = torch.nn.GroupNorm( 213 | num_groups=norm_num_groups, 214 | num_channels=in_channels, 215 | eps=1e-6, 216 | affine=True, 217 | ) 218 | if use_linear_projection: 219 | self.proj_in = nn.Linear(in_channels, inner_dim) 220 | else: 221 | self.proj_in = nn.Conv2d( 222 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0 223 | ) 224 | 225 | # 3. 
Define transformers blocks 226 | self.transformer_blocks = nn.ModuleList( 227 | [ 228 | BasicTransformerBlock( 229 | inner_dim, 230 | num_attention_heads, 231 | attention_head_dim, 232 | dropout=dropout, 233 | cross_attention_dim=cross_attention_dim, 234 | activation_fn=activation_fn, 235 | num_embeds_ada_norm=num_embeds_ada_norm, 236 | attention_bias=attention_bias, 237 | only_cross_attention=only_cross_attention, 238 | double_self_attention=double_self_attention, 239 | upcast_attention=upcast_attention, 240 | norm_type=norm_type, 241 | norm_elementwise_affine=norm_elementwise_affine, 242 | norm_eps=norm_eps, 243 | attention_type=attention_type, 244 | ) 245 | for d in range(num_layers) 246 | ] 247 | ) 248 | 249 | if use_linear_projection: 250 | self.proj_out = nn.Linear(inner_dim, in_channels) 251 | else: 252 | self.proj_out = nn.Conv2d( 253 | inner_dim, in_channels, kernel_size=1, stride=1, padding=0 254 | ) 255 | 256 | self.gradient_checkpointing = False 257 | 258 | def _set_gradient_checkpointing(self, module, value=False): 259 | if hasattr(module, "gradient_checkpointing"): 260 | module.gradient_checkpointing = value 261 | 262 | def forward( 263 | self, 264 | hidden_states: torch.Tensor, 265 | encoder_hidden_states: Optional[torch.Tensor] = None, 266 | timestep: Optional[torch.LongTensor] = None, 267 | ): 268 | batch, _, height, width = hidden_states.shape 269 | residual = hidden_states 270 | 271 | hidden_states = self.norm(hidden_states) 272 | if not self.use_linear_projection: 273 | hidden_states = self.proj_in(hidden_states) 274 | inner_dim = hidden_states.shape[1] 275 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( 276 | batch, height * width, inner_dim 277 | ) 278 | else: 279 | inner_dim = hidden_states.shape[1] 280 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( 281 | batch, height * width, inner_dim 282 | ) 283 | hidden_states = self.proj_in(hidden_states) 284 | 285 | for block in self.transformer_blocks: 286 | hidden_states = block( 287 | hidden_states, 288 | encoder_hidden_states=encoder_hidden_states, 289 | timestep=timestep, 290 | ) 291 | 292 | if not self.use_linear_projection: 293 | hidden_states = ( 294 | hidden_states.reshape(batch, height, width, inner_dim) 295 | .permute(0, 3, 1, 2) 296 | .contiguous() 297 | ) 298 | hidden_states = self.proj_out(hidden_states) 299 | else: 300 | hidden_states = self.proj_out(hidden_states) 301 | hidden_states = ( 302 | hidden_states.reshape(batch, height, width, inner_dim) 303 | .permute(0, 3, 1, 2) 304 | .contiguous() 305 | ) 306 | 307 | output = hidden_states + residual 308 | return output 309 | 310 | 311 | if __name__ == '__main__': 312 | model = PoseGuider(noise_latent_channels=320).to(device="cuda") 313 | 314 | input_data = torch.randn(1,3,1,512,512).to(device="cuda") 315 | input_data1 = torch.randn(1,3,512,512).to(device="cuda") 316 | 317 | output = model(input_data, input_data1) 318 | for item in output: 319 | print(item.shape) 320 | 321 | # tf_model = Transformer2DModel( 322 | # in_channels=320 323 | # ).to('cuda') 324 | 325 | # input_data = torch.randn(4,320,32,32).to(device="cuda") 326 | # # input_emb = torch.randn(4,1,768).to(device="cuda") 327 | # input_emb = torch.randn(4,320,32,32).to(device="cuda") 328 | # o1 = tf_model(input_data, input_emb) 329 | # print(o1.shape) 330 | -------------------------------------------------------------------------------- /src/models/resnet.py: -------------------------------------------------------------------------------- 1 | # Adapted from 
https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from einops import rearrange 7 | from typing import Dict, Optional 8 | 9 | 10 | class InflatedConv3d(nn.Conv2d): 11 | def forward(self, x): 12 | video_length = x.shape[2] 13 | 14 | x = rearrange(x, "b c f h w -> (b f) c h w") 15 | x = super().forward(x) 16 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 17 | 18 | return x 19 | 20 | 21 | class InflatedGroupNorm(nn.GroupNorm): 22 | def forward(self, x): 23 | video_length = x.shape[2] 24 | 25 | x = rearrange(x, "b c f h w -> (b f) c h w") 26 | x = super().forward(x) 27 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 28 | 29 | return x 30 | 31 | 32 | class Upsample3D(nn.Module): 33 | def __init__( 34 | self, 35 | channels, 36 | use_conv=False, 37 | use_conv_transpose=False, 38 | out_channels=None, 39 | name="conv", 40 | ): 41 | super().__init__() 42 | self.channels = channels 43 | self.out_channels = out_channels or channels 44 | self.use_conv = use_conv 45 | self.use_conv_transpose = use_conv_transpose 46 | self.name = name 47 | 48 | conv = None 49 | if use_conv_transpose: 50 | raise NotImplementedError 51 | elif use_conv: 52 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) 53 | 54 | def forward(self, hidden_states, output_size=None): 55 | assert hidden_states.shape[1] == self.channels 56 | 57 | if self.use_conv_transpose: 58 | raise NotImplementedError 59 | 60 | # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 61 | dtype = hidden_states.dtype 62 | if dtype == torch.bfloat16: 63 | hidden_states = hidden_states.to(torch.float32) 64 | 65 | # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 66 | if hidden_states.shape[0] >= 64: 67 | hidden_states = hidden_states.contiguous() 68 | 69 | # if `output_size` is passed we force the interpolation output 70 | # size and do not make use of `scale_factor=2` 71 | if output_size is None: 72 | hidden_states = F.interpolate( 73 | hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest" 74 | ) 75 | else: 76 | hidden_states = F.interpolate( 77 | hidden_states, size=output_size, mode="nearest" 78 | ) 79 | 80 | # If the input is bfloat16, we cast back to bfloat16 81 | if dtype == torch.bfloat16: 82 | hidden_states = hidden_states.to(dtype) 83 | 84 | # if self.use_conv: 85 | # if self.name == "conv": 86 | # hidden_states = self.conv(hidden_states) 87 | # else: 88 | # hidden_states = self.Conv2d_0(hidden_states) 89 | hidden_states = self.conv(hidden_states) 90 | 91 | return hidden_states 92 | 93 | 94 | class Downsample3D(nn.Module): 95 | def __init__( 96 | self, channels, use_conv=False, out_channels=None, padding=1, name="conv" 97 | ): 98 | super().__init__() 99 | self.channels = channels 100 | self.out_channels = out_channels or channels 101 | self.use_conv = use_conv 102 | self.padding = padding 103 | stride = 2 104 | self.name = name 105 | 106 | if use_conv: 107 | self.conv = InflatedConv3d( 108 | self.channels, self.out_channels, 3, stride=stride, padding=padding 109 | ) 110 | else: 111 | raise NotImplementedError 112 | 113 | def forward(self, hidden_states): 114 | assert hidden_states.shape[1] == self.channels 115 | if self.use_conv and self.padding == 0: 116 | raise NotImplementedError 117 | 118 | assert hidden_states.shape[1] == self.channels 119 | hidden_states = self.conv(hidden_states) 120 | 121 | return hidden_states 122 | 123 | 124 | class ResnetBlock3D(nn.Module): 125 | def __init__( 126 | self, 127 | *, 128 | in_channels, 129 | out_channels=None, 130 | conv_shortcut=False, 131 | dropout=0.0, 132 | temb_channels=512, 133 | groups=32, 134 | groups_out=None, 135 | pre_norm=True, 136 | eps=1e-6, 137 | non_linearity="swish", 138 | time_embedding_norm="default", 139 | output_scale_factor=1.0, 140 | use_in_shortcut=None, 141 | use_inflated_groupnorm=None, 142 | ): 143 | super().__init__() 144 | self.pre_norm = pre_norm 145 | self.pre_norm = True 146 | self.in_channels = in_channels 147 | out_channels = in_channels if out_channels is None else out_channels 148 | self.out_channels = out_channels 149 | self.use_conv_shortcut = conv_shortcut 150 | self.time_embedding_norm = time_embedding_norm 151 | self.output_scale_factor = output_scale_factor 152 | 153 | if groups_out is None: 154 | groups_out = groups 155 | 156 | assert use_inflated_groupnorm != None 157 | if use_inflated_groupnorm: 158 | self.norm1 = InflatedGroupNorm( 159 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True 160 | ) 161 | else: 162 | self.norm1 = torch.nn.GroupNorm( 163 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True 164 | ) 165 | 166 | self.conv1 = InflatedConv3d( 167 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 168 | ) 169 | 170 | if temb_channels is not None: 171 | if self.time_embedding_norm == "default": 172 | time_emb_proj_out_channels = out_channels 173 | elif self.time_embedding_norm == "scale_shift": 174 | time_emb_proj_out_channels = out_channels * 2 175 | else: 176 | raise ValueError( 177 | f"unknown time_embedding_norm : {self.time_embedding_norm} " 178 | ) 179 | 180 | self.time_emb_proj = torch.nn.Linear( 181 | temb_channels, 
time_emb_proj_out_channels 182 | ) 183 | else: 184 | self.time_emb_proj = None 185 | 186 | if use_inflated_groupnorm: 187 | self.norm2 = InflatedGroupNorm( 188 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True 189 | ) 190 | else: 191 | self.norm2 = torch.nn.GroupNorm( 192 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True 193 | ) 194 | self.dropout = torch.nn.Dropout(dropout) 195 | self.conv2 = InflatedConv3d( 196 | out_channels, out_channels, kernel_size=3, stride=1, padding=1 197 | ) 198 | 199 | if non_linearity == "swish": 200 | self.nonlinearity = lambda x: F.silu(x) 201 | elif non_linearity == "mish": 202 | self.nonlinearity = Mish() 203 | elif non_linearity == "silu": 204 | self.nonlinearity = nn.SiLU() 205 | 206 | self.use_in_shortcut = ( 207 | self.in_channels != self.out_channels 208 | if use_in_shortcut is None 209 | else use_in_shortcut 210 | ) 211 | 212 | self.conv_shortcut = None 213 | if self.use_in_shortcut: 214 | self.conv_shortcut = InflatedConv3d( 215 | in_channels, out_channels, kernel_size=1, stride=1, padding=0 216 | ) 217 | 218 | def forward(self, input_tensor, temb): 219 | hidden_states = input_tensor 220 | 221 | hidden_states = self.norm1(hidden_states) 222 | hidden_states = self.nonlinearity(hidden_states) 223 | 224 | hidden_states = self.conv1(hidden_states) 225 | 226 | if temb is not None: 227 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] 228 | 229 | if temb is not None and self.time_embedding_norm == "default": 230 | hidden_states = hidden_states + temb 231 | 232 | hidden_states = self.norm2(hidden_states) 233 | 234 | if temb is not None and self.time_embedding_norm == "scale_shift": 235 | scale, shift = torch.chunk(temb, 2, dim=1) 236 | hidden_states = hidden_states * (1 + scale) + shift 237 | 238 | hidden_states = self.nonlinearity(hidden_states) 239 | 240 | hidden_states = self.dropout(hidden_states) 241 | hidden_states = self.conv2(hidden_states) 242 | 243 | if self.conv_shortcut is not None: 244 | input_tensor = self.conv_shortcut(input_tensor) 245 | 246 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 247 | 248 | return output_tensor 249 | 250 | class Mish(torch.nn.Module): 251 | def forward(self, hidden_states): 252 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) 253 | -------------------------------------------------------------------------------- /src/models/transformer_3d.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional, Dict 3 | 4 | import torch 5 | from diffusers.configuration_utils import ConfigMixin, register_to_config 6 | from diffusers.models import ModelMixin 7 | from diffusers.utils import BaseOutput 8 | from diffusers.utils.import_utils import is_xformers_available 9 | from einops import rearrange, repeat 10 | from torch import nn 11 | 12 | from .attention import TemporalBasicTransformerBlock, ResidualTemporalBasicTransformerBlock 13 | 14 | 15 | @dataclass 16 | class Transformer3DModelOutput(BaseOutput): 17 | sample: torch.FloatTensor 18 | 19 | 20 | if is_xformers_available(): 21 | import xformers 22 | import xformers.ops 23 | else: 24 | xformers = None 25 | 26 | 27 | class Transformer3DModel(ModelMixin, ConfigMixin): 28 | _supports_gradient_checkpointing = True 29 | 30 | @register_to_config 31 | def __init__( 32 | self, 33 | num_attention_heads: int = 16, 34 | attention_head_dim: int = 88, 35 | 
in_channels: Optional[int] = None, 36 | num_layers: int = 1, 37 | dropout: float = 0.0, 38 | norm_num_groups: int = 32, 39 | cross_attention_dim: Optional[int] = None, 40 | attention_bias: bool = False, 41 | activation_fn: str = "geglu", 42 | num_embeds_ada_norm: Optional[int] = None, 43 | use_linear_projection: bool = False, 44 | only_cross_attention: bool = False, 45 | upcast_attention: bool = False, 46 | unet_use_cross_frame_attention=None, 47 | unet_use_temporal_attention=None, 48 | ): 49 | super().__init__() 50 | self.use_linear_projection = use_linear_projection 51 | self.num_attention_heads = num_attention_heads 52 | self.attention_head_dim = attention_head_dim 53 | inner_dim = num_attention_heads * attention_head_dim 54 | 55 | # Define input layers 56 | self.in_channels = in_channels 57 | 58 | self.norm = torch.nn.GroupNorm( 59 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True 60 | ) 61 | if use_linear_projection: 62 | self.proj_in = nn.Linear(in_channels, inner_dim) 63 | else: 64 | self.proj_in = nn.Conv2d( 65 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0 66 | ) 67 | 68 | # Define transformers blocks 69 | self.transformer_blocks = nn.ModuleList( 70 | [ 71 | TemporalBasicTransformerBlock( 72 | inner_dim, 73 | num_attention_heads, 74 | attention_head_dim, 75 | dropout=dropout, 76 | cross_attention_dim=cross_attention_dim, 77 | activation_fn=activation_fn, 78 | num_embeds_ada_norm=num_embeds_ada_norm, 79 | attention_bias=attention_bias, 80 | only_cross_attention=only_cross_attention, 81 | upcast_attention=upcast_attention, 82 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 83 | unet_use_temporal_attention=unet_use_temporal_attention, 84 | ) 85 | for d in range(num_layers) 86 | ] 87 | ) 88 | 89 | # 4. Define output layers 90 | if use_linear_projection: 91 | self.proj_out = nn.Linear(in_channels, inner_dim) 92 | else: 93 | self.proj_out = nn.Conv2d( 94 | inner_dim, in_channels, kernel_size=1, stride=1, padding=0 95 | ) 96 | 97 | self.gradient_checkpointing = False 98 | 99 | def _set_gradient_checkpointing(self, module, value=False): 100 | if hasattr(module, "gradient_checkpointing"): 101 | module.gradient_checkpointing = value 102 | 103 | def forward( 104 | self, 105 | hidden_states, 106 | encoder_hidden_states=None, 107 | timestep=None, 108 | return_dict: bool = True, 109 | ): 110 | # Input 111 | assert ( 112 | hidden_states.dim() == 5 113 | ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 
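# Editor's note (descriptive comment, not part of the original source): hidden_states
# arrives as (batch, channels, frames, height, width). The frame axis is folded into the
# batch axis just below, so the GroupNorm and input projection operate on ordinary
# (batch*frames, c, h, w) feature maps; `video_length` is recorded first so the
# TemporalBasicTransformerBlock layers can regroup the frames when attending across time.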
114 | video_length = hidden_states.shape[2] 115 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 116 | if encoder_hidden_states.shape[0] != hidden_states.shape[0]: 117 | encoder_hidden_states = repeat( 118 | encoder_hidden_states, "b n c -> (b f) n c", f=video_length 119 | ) 120 | 121 | batch, channel, height, weight = hidden_states.shape 122 | residual = hidden_states 123 | 124 | hidden_states = self.norm(hidden_states) 125 | if not self.use_linear_projection: 126 | hidden_states = self.proj_in(hidden_states) 127 | inner_dim = hidden_states.shape[1] 128 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( 129 | batch, height * weight, inner_dim 130 | ) 131 | else: 132 | inner_dim = hidden_states.shape[1] 133 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( 134 | batch, height * weight, inner_dim 135 | ) 136 | hidden_states = self.proj_in(hidden_states) 137 | 138 | # Blocks 139 | for i, block in enumerate(self.transformer_blocks): 140 | hidden_states = block( 141 | hidden_states, 142 | encoder_hidden_states=encoder_hidden_states, 143 | timestep=timestep, 144 | video_length=video_length, 145 | ) 146 | 147 | # Output 148 | if not self.use_linear_projection: 149 | hidden_states = ( 150 | hidden_states.reshape(batch, height, weight, inner_dim) 151 | .permute(0, 3, 1, 2) 152 | .contiguous() 153 | ) 154 | hidden_states = self.proj_out(hidden_states) 155 | else: 156 | hidden_states = self.proj_out(hidden_states) 157 | hidden_states = ( 158 | hidden_states.reshape(batch, height, weight, inner_dim) 159 | .permute(0, 3, 1, 2) 160 | .contiguous() 161 | ) 162 | 163 | output = hidden_states + residual 164 | 165 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 166 | if not return_dict: 167 | return (output,) 168 | 169 | return Transformer3DModelOutput(sample=output) 170 | -------------------------------------------------------------------------------- /src/pipelines/context.py: -------------------------------------------------------------------------------- 1 | # TODO: Adapted from cli 2 | from typing import Callable, List, Optional 3 | 4 | import numpy as np 5 | 6 | 7 | def ordered_halving(val): 8 | bin_str = f"{val:064b}" 9 | bin_flip = bin_str[::-1] 10 | as_int = int(bin_flip, 2) 11 | 12 | return as_int / (1 << 64) 13 | 14 | 15 | def uniform( 16 | step: int = ..., 17 | num_steps: Optional[int] = None, 18 | num_frames: int = ..., 19 | context_size: Optional[int] = None, 20 | context_stride: int = 3, 21 | context_overlap: int = 4, 22 | closed_loop: bool = True, 23 | ): 24 | if num_frames <= context_size: 25 | yield list(range(num_frames)) 26 | return 27 | 28 | context_stride = min( 29 | context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1 30 | ) 31 | 32 | for context_step in 1 << np.arange(context_stride): 33 | pad = int(round(num_frames * ordered_halving(step))) 34 | for j in range( 35 | int(ordered_halving(step) * context_step) + pad, 36 | num_frames + pad + (0 if closed_loop else -context_overlap), 37 | (context_size * context_step - context_overlap), 38 | ): 39 | yield [ 40 | e % num_frames 41 | for e in range(j, j + context_size * context_step, context_step) 42 | ] 43 | 44 | 45 | def get_context_scheduler(name: str) -> Callable: 46 | if name == "uniform": 47 | return uniform 48 | else: 49 | raise ValueError(f"Unknown context_overlap policy {name}") 50 | 51 | 52 | def get_total_steps( 53 | scheduler, 54 | timesteps: List[int], 55 | num_steps: Optional[int] = None, 56 | num_frames: int = ..., 57 | context_size: 
Optional[int] = None, 58 | context_stride: int = 3, 59 | context_overlap: int = 4, 60 | closed_loop: bool = True, 61 | ): 62 | return sum( 63 | len( 64 | list( 65 | scheduler( 66 | i, 67 | num_steps, 68 | num_frames, 69 | context_size, 70 | context_stride, 71 | context_overlap, 72 | ) 73 | ) 74 | ) 75 | for i in range(len(timesteps)) 76 | ) 77 | -------------------------------------------------------------------------------- /src/pipelines/pipeline_pose2img.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from dataclasses import dataclass 3 | from typing import Callable, List, Optional, Union 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | import torchvision.transforms as transforms 9 | from diffusers import DiffusionPipeline 10 | from diffusers.image_processor import VaeImageProcessor 11 | from diffusers.schedulers import ( 12 | DDIMScheduler, 13 | DPMSolverMultistepScheduler, 14 | EulerAncestralDiscreteScheduler, 15 | EulerDiscreteScheduler, 16 | LMSDiscreteScheduler, 17 | PNDMScheduler, 18 | ) 19 | from diffusers.utils import BaseOutput, is_accelerate_available 20 | from diffusers.utils.torch_utils import randn_tensor 21 | from einops import rearrange 22 | from tqdm import tqdm 23 | from transformers import CLIPImageProcessor 24 | 25 | from src.models.mutual_self_attention import ReferenceAttentionControl 26 | 27 | 28 | @dataclass 29 | class Pose2ImagePipelineOutput(BaseOutput): 30 | images: Union[torch.Tensor, np.ndarray] 31 | 32 | 33 | class Pose2ImagePipeline(DiffusionPipeline): 34 | _optional_components = [] 35 | 36 | def __init__( 37 | self, 38 | vae, 39 | image_encoder, 40 | reference_unet, 41 | denoising_unet, 42 | pose_guider, 43 | scheduler: Union[ 44 | DDIMScheduler, 45 | PNDMScheduler, 46 | LMSDiscreteScheduler, 47 | EulerDiscreteScheduler, 48 | EulerAncestralDiscreteScheduler, 49 | DPMSolverMultistepScheduler, 50 | ], 51 | ): 52 | super().__init__() 53 | 54 | self.register_modules( 55 | vae=vae, 56 | image_encoder=image_encoder, 57 | reference_unet=reference_unet, 58 | denoising_unet=denoising_unet, 59 | pose_guider=pose_guider, 60 | scheduler=scheduler, 61 | ) 62 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 63 | self.clip_image_processor = CLIPImageProcessor() 64 | self.ref_image_processor = VaeImageProcessor( 65 | vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True 66 | ) 67 | self.cond_image_processor = VaeImageProcessor( 68 | vae_scale_factor=self.vae_scale_factor, 69 | do_convert_rgb=True, 70 | do_normalize=True, 71 | ) 72 | 73 | def enable_vae_slicing(self): 74 | self.vae.enable_slicing() 75 | 76 | def disable_vae_slicing(self): 77 | self.vae.disable_slicing() 78 | 79 | def enable_sequential_cpu_offload(self, gpu_id=0): 80 | if is_accelerate_available(): 81 | from accelerate import cpu_offload 82 | else: 83 | raise ImportError("Please install accelerate via `pip install accelerate`") 84 | 85 | device = torch.device(f"cuda:{gpu_id}") 86 | 87 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: 88 | if cpu_offloaded_model is not None: 89 | cpu_offload(cpu_offloaded_model, device) 90 | 91 | @property 92 | def _execution_device(self): 93 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): 94 | return self.device 95 | for module in self.unet.modules(): 96 | if ( 97 | hasattr(module, "_hf_hook") 98 | and hasattr(module._hf_hook, "execution_device") 99 | and module._hf_hook.execution_device is not 
None 100 | ): 101 | return torch.device(module._hf_hook.execution_device) 102 | return self.device 103 | 104 | def decode_latents(self, latents): 105 | video_length = latents.shape[2] 106 | latents = 1 / 0.18215 * latents 107 | latents = rearrange(latents, "b c f h w -> (b f) c h w") 108 | # video = self.vae.decode(latents).sample 109 | video = [] 110 | for frame_idx in tqdm(range(latents.shape[0])): 111 | video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample) 112 | video = torch.cat(video) 113 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) 114 | video = (video / 2 + 0.5).clamp(0, 1) 115 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 116 | video = video.cpu().float().numpy() 117 | return video 118 | 119 | def prepare_extra_step_kwargs(self, generator, eta): 120 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 121 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 122 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 123 | # and should be between [0, 1] 124 | 125 | accepts_eta = "eta" in set( 126 | inspect.signature(self.scheduler.step).parameters.keys() 127 | ) 128 | extra_step_kwargs = {} 129 | if accepts_eta: 130 | extra_step_kwargs["eta"] = eta 131 | 132 | # check if the scheduler accepts generator 133 | accepts_generator = "generator" in set( 134 | inspect.signature(self.scheduler.step).parameters.keys() 135 | ) 136 | if accepts_generator: 137 | extra_step_kwargs["generator"] = generator 138 | return extra_step_kwargs 139 | 140 | def prepare_latents( 141 | self, 142 | batch_size, 143 | num_channels_latents, 144 | width, 145 | height, 146 | dtype, 147 | device, 148 | generator, 149 | latents=None, 150 | ): 151 | shape = ( 152 | batch_size, 153 | num_channels_latents, 154 | height // self.vae_scale_factor, 155 | width // self.vae_scale_factor, 156 | ) 157 | if isinstance(generator, list) and len(generator) != batch_size: 158 | raise ValueError( 159 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 160 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
161 | ) 162 | 163 | if latents is None: 164 | latents = randn_tensor( 165 | shape, generator=generator, device=device, dtype=dtype 166 | ) 167 | else: 168 | latents = latents.to(device) 169 | 170 | # scale the initial noise by the standard deviation required by the scheduler 171 | latents = latents * self.scheduler.init_noise_sigma 172 | return latents 173 | 174 | def prepare_condition( 175 | self, 176 | cond_image, 177 | width, 178 | height, 179 | device, 180 | dtype, 181 | do_classififer_free_guidance=False, 182 | ): 183 | image = self.cond_image_processor.preprocess( 184 | cond_image, height=height, width=width 185 | ).to(dtype=torch.float32) 186 | 187 | image = image.to(device=device, dtype=dtype) 188 | 189 | 190 | if do_classififer_free_guidance: 191 | image = torch.cat([image] * 2) 192 | 193 | return image 194 | 195 | @torch.no_grad() 196 | def __call__( 197 | self, 198 | ref_image, 199 | pose_image, 200 | ref_pose_image, 201 | width, 202 | height, 203 | num_inference_steps, 204 | guidance_scale, 205 | num_images_per_prompt=1, 206 | eta: float = 0.0, 207 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 208 | output_type: Optional[str] = "tensor", 209 | return_dict: bool = True, 210 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 211 | callback_steps: Optional[int] = 1, 212 | **kwargs, 213 | ): 214 | # Default height and width to unet 215 | height = height or self.unet.config.sample_size * self.vae_scale_factor 216 | width = width or self.unet.config.sample_size * self.vae_scale_factor 217 | 218 | device = self._execution_device 219 | 220 | do_classifier_free_guidance = guidance_scale > 1.0 221 | 222 | # Prepare timesteps 223 | self.scheduler.set_timesteps(num_inference_steps, device=device) 224 | timesteps = self.scheduler.timesteps 225 | 226 | batch_size = 1 227 | 228 | # Prepare clip image embeds 229 | clip_image = self.clip_image_processor.preprocess( 230 | ref_image.resize((224, 224)), return_tensors="pt" 231 | ).pixel_values 232 | clip_image_embeds = self.image_encoder( 233 | clip_image.to(device, dtype=self.image_encoder.dtype) 234 | ).image_embeds 235 | image_prompt_embeds = clip_image_embeds.unsqueeze(1) 236 | 237 | uncond_image_prompt_embeds = torch.zeros_like(image_prompt_embeds) 238 | 239 | if do_classifier_free_guidance: 240 | image_prompt_embeds = torch.cat( 241 | [uncond_image_prompt_embeds, image_prompt_embeds], dim=0 242 | ) 243 | 244 | reference_control_writer = ReferenceAttentionControl( 245 | self.reference_unet, 246 | do_classifier_free_guidance=do_classifier_free_guidance, 247 | mode="write", 248 | batch_size=batch_size, 249 | fusion_blocks="full", 250 | ) 251 | reference_control_reader = ReferenceAttentionControl( 252 | self.denoising_unet, 253 | do_classifier_free_guidance=do_classifier_free_guidance, 254 | mode="read", 255 | batch_size=batch_size, 256 | fusion_blocks="full", 257 | ) 258 | 259 | num_channels_latents = self.denoising_unet.in_channels 260 | latents = self.prepare_latents( 261 | batch_size * num_images_per_prompt, 262 | num_channels_latents, 263 | width, 264 | height, 265 | clip_image_embeds.dtype, 266 | device, 267 | generator, 268 | ) 269 | latents = latents.unsqueeze(2) # (bs, c, 1, h', w') 270 | latents_dtype = latents.dtype 271 | 272 | # Prepare extra step kwargs. 
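# Editor's note (descriptive comment, not part of the original source):
# prepare_extra_step_kwargs(), defined earlier in this file, inspects the scheduler's
# step() signature and forwards `eta` / `generator` only when the scheduler accepts
# them, so the call below works unchanged for DDIM as well as schedulers that ignore eta.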
273 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 274 | 275 | # Prepare ref image latents 276 | ref_image_tensor = self.ref_image_processor.preprocess( 277 | ref_image, height=height, width=width 278 | ) # (bs, c, width, height) 279 | ref_image_tensor = ref_image_tensor.to( 280 | dtype=self.vae.dtype, device=self.vae.device 281 | ) 282 | ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean 283 | ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w) 284 | 285 | # Prepare pose condition image 286 | pose_cond_tensor = self.cond_image_processor.preprocess( 287 | pose_image, height=height, width=width 288 | ) 289 | 290 | pose_cond_tensor = pose_cond_tensor.unsqueeze(2) # (bs, c, 1, h, w) 291 | pose_cond_tensor = pose_cond_tensor.to( 292 | device=device, dtype=self.pose_guider.dtype 293 | ) 294 | 295 | ref_pose_tensor = self.cond_image_processor.preprocess( 296 | ref_pose_image, height=height, width=width 297 | ) 298 | ref_pose_tensor = ref_pose_tensor.to( 299 | device=device, dtype=self.pose_guider.dtype 300 | ) 301 | 302 | pose_fea = self.pose_guider(pose_cond_tensor, ref_pose_tensor) 303 | if do_classifier_free_guidance: 304 | for idxx in range(len(pose_fea)): 305 | pose_fea[idxx] = torch.cat([pose_fea[idxx]] * 2) 306 | 307 | # denoising loop 308 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 309 | with self.progress_bar(total=num_inference_steps) as progress_bar: 310 | for i, t in enumerate(timesteps): 311 | # 1. Forward reference image 312 | if i == 0: 313 | self.reference_unet( 314 | ref_image_latents.repeat( 315 | (2 if do_classifier_free_guidance else 1), 1, 1, 1 316 | ), 317 | torch.zeros_like(t), 318 | encoder_hidden_states=image_prompt_embeds, 319 | return_dict=False, 320 | ) 321 | 322 | # 2. 
Update reference unet feature into denosing net 323 | reference_control_reader.update(reference_control_writer) 324 | 325 | # 3.1 expand the latents if we are doing classifier free guidance 326 | latent_model_input = ( 327 | torch.cat([latents] * 2) if do_classifier_free_guidance else latents 328 | ) 329 | latent_model_input = self.scheduler.scale_model_input( 330 | latent_model_input, t 331 | ) 332 | 333 | noise_pred = self.denoising_unet( 334 | latent_model_input, 335 | t, 336 | encoder_hidden_states=image_prompt_embeds, 337 | pose_cond_fea=pose_fea, 338 | return_dict=False, 339 | )[0] 340 | 341 | # perform guidance 342 | if do_classifier_free_guidance: 343 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 344 | noise_pred = noise_pred_uncond + guidance_scale * ( 345 | noise_pred_text - noise_pred_uncond 346 | ) 347 | 348 | # compute the previous noisy sample x_t -> x_t-1 349 | latents = self.scheduler.step( 350 | noise_pred, t, latents, **extra_step_kwargs, return_dict=False 351 | )[0] 352 | 353 | # call the callback, if provided 354 | if i == len(timesteps) - 1 or ( 355 | (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 356 | ): 357 | progress_bar.update() 358 | if callback is not None and i % callback_steps == 0: 359 | step_idx = i // getattr(self.scheduler, "order", 1) 360 | callback(step_idx, t, latents) 361 | reference_control_reader.clear() 362 | reference_control_writer.clear() 363 | 364 | # Post-processing 365 | image = self.decode_latents(latents) # (b, c, 1, h, w) 366 | 367 | # Convert to tensor 368 | if output_type == "tensor": 369 | image = torch.from_numpy(image) 370 | 371 | if not return_dict: 372 | return image 373 | 374 | return Pose2ImagePipelineOutput(images=image) 375 | -------------------------------------------------------------------------------- /src/pipelines/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | tensor_interpolation = None 4 | 5 | 6 | def get_tensor_interpolation_method(): 7 | return tensor_interpolation 8 | 9 | 10 | def set_tensor_interpolation_method(is_slerp): 11 | global tensor_interpolation 12 | tensor_interpolation = slerp if is_slerp else linear 13 | 14 | 15 | def linear(v1, v2, t): 16 | return (1.0 - t) * v1 + t * v2 17 | 18 | 19 | def slerp( 20 | v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995 21 | ) -> torch.Tensor: 22 | u0 = v0 / v0.norm() 23 | u1 = v1 / v1.norm() 24 | dot = (u0 * u1).sum() 25 | if dot.abs() > DOT_THRESHOLD: 26 | # logger.info(f'warning: v0 and v1 close to parallel, using linear interpolation instead.') 27 | return (1.0 - t) * v0 + t * v1 28 | omega = dot.acos() 29 | return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin() 30 | -------------------------------------------------------------------------------- /src/utils/audio_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | 4 | import librosa 5 | import numpy as np 6 | from transformers import Wav2Vec2FeatureExtractor 7 | 8 | 9 | class DataProcessor: 10 | def __init__(self, sampling_rate, wav2vec_model_path): 11 | self._processor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_model_path, local_files_only=True) 12 | self._sampling_rate = sampling_rate 13 | 14 | def extract_feature(self, audio_path): 15 | speech_array, sampling_rate = librosa.load(audio_path, sr=self._sampling_rate) 16 | input_value = np.squeeze(self._processor(speech_array, 
sampling_rate=sampling_rate).input_values) 17 | return input_value 18 | 19 | 20 | def prepare_audio_feature(wav_file, fps=30, sampling_rate=16000, wav2vec_model_path=None): 21 | data_preprocessor = DataProcessor(sampling_rate, wav2vec_model_path) 22 | 23 | input_value = data_preprocessor.extract_feature(wav_file) 24 | seq_len = math.ceil(len(input_value)/sampling_rate*fps) 25 | return { 26 | "audio_feature": input_value, 27 | "seq_len": seq_len 28 | } 29 | 30 | 31 | -------------------------------------------------------------------------------- /src/utils/draw_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import mediapipe as mp 3 | import numpy as np 4 | from mediapipe.framework.formats import landmark_pb2 5 | 6 | class FaceMeshVisualizer: 7 | def __init__(self, forehead_edge=False): 8 | self.mp_drawing = mp.solutions.drawing_utils 9 | mp_face_mesh = mp.solutions.face_mesh 10 | self.mp_face_mesh = mp_face_mesh 11 | self.forehead_edge = forehead_edge 12 | 13 | DrawingSpec = mp.solutions.drawing_styles.DrawingSpec 14 | f_thick = 2 15 | f_rad = 1 16 | right_iris_draw = DrawingSpec(color=(10, 200, 250), thickness=f_thick, circle_radius=f_rad) 17 | right_eye_draw = DrawingSpec(color=(10, 200, 180), thickness=f_thick, circle_radius=f_rad) 18 | right_eyebrow_draw = DrawingSpec(color=(10, 220, 180), thickness=f_thick, circle_radius=f_rad) 19 | left_iris_draw = DrawingSpec(color=(250, 200, 10), thickness=f_thick, circle_radius=f_rad) 20 | left_eye_draw = DrawingSpec(color=(180, 200, 10), thickness=f_thick, circle_radius=f_rad) 21 | left_eyebrow_draw = DrawingSpec(color=(180, 220, 10), thickness=f_thick, circle_radius=f_rad) 22 | head_draw = DrawingSpec(color=(10, 200, 10), thickness=f_thick, circle_radius=f_rad) 23 | 24 | mouth_draw_obl = DrawingSpec(color=(10, 180, 20), thickness=f_thick, circle_radius=f_rad) 25 | mouth_draw_obr = DrawingSpec(color=(20, 10, 180), thickness=f_thick, circle_radius=f_rad) 26 | 27 | mouth_draw_ibl = DrawingSpec(color=(100, 100, 30), thickness=f_thick, circle_radius=f_rad) 28 | mouth_draw_ibr = DrawingSpec(color=(100, 150, 50), thickness=f_thick, circle_radius=f_rad) 29 | 30 | mouth_draw_otl = DrawingSpec(color=(20, 80, 100), thickness=f_thick, circle_radius=f_rad) 31 | mouth_draw_otr = DrawingSpec(color=(80, 100, 20), thickness=f_thick, circle_radius=f_rad) 32 | 33 | mouth_draw_itl = DrawingSpec(color=(120, 100, 200), thickness=f_thick, circle_radius=f_rad) 34 | mouth_draw_itr = DrawingSpec(color=(150 ,120, 100), thickness=f_thick, circle_radius=f_rad) 35 | 36 | FACEMESH_LIPS_OUTER_BOTTOM_LEFT = [(61,146),(146,91),(91,181),(181,84),(84,17)] 37 | FACEMESH_LIPS_OUTER_BOTTOM_RIGHT = [(17,314),(314,405),(405,321),(321,375),(375,291)] 38 | 39 | FACEMESH_LIPS_INNER_BOTTOM_LEFT = [(78,95),(95,88),(88,178),(178,87),(87,14)] 40 | FACEMESH_LIPS_INNER_BOTTOM_RIGHT = [(14,317),(317,402),(402,318),(318,324),(324,308)] 41 | 42 | FACEMESH_LIPS_OUTER_TOP_LEFT = [(61,185),(185,40),(40,39),(39,37),(37,0)] 43 | FACEMESH_LIPS_OUTER_TOP_RIGHT = [(0,267),(267,269),(269,270),(270,409),(409,291)] 44 | 45 | FACEMESH_LIPS_INNER_TOP_LEFT = [(78,191),(191,80),(80,81),(81,82),(82,13)] 46 | FACEMESH_LIPS_INNER_TOP_RIGHT = [(13,312),(312,311),(311,310),(310,415),(415,308)] 47 | 48 | FACEMESH_CUSTOM_FACE_OVAL = [(176, 149), (150, 136), (356, 454), (58, 132), (152, 148), (361, 288), (251, 389), (132, 93), (389, 356), (400, 377), (136, 172), (377, 152), (323, 361), (172, 58), (454, 323), (365, 379), (379, 378), (148, 176), (93, 234), 
(397, 365), (149, 150), (288, 397), (234, 127), (378, 400), (127, 162), (162, 21)] 49 | 50 | # mp_face_mesh.FACEMESH_CONTOURS has all the items we care about. 51 | face_connection_spec = {} 52 | if self.forehead_edge: 53 | for edge in mp_face_mesh.FACEMESH_FACE_OVAL: 54 | face_connection_spec[edge] = head_draw 55 | else: 56 | for edge in FACEMESH_CUSTOM_FACE_OVAL: 57 | face_connection_spec[edge] = head_draw 58 | for edge in mp_face_mesh.FACEMESH_LEFT_EYE: 59 | face_connection_spec[edge] = left_eye_draw 60 | for edge in mp_face_mesh.FACEMESH_LEFT_EYEBROW: 61 | face_connection_spec[edge] = left_eyebrow_draw 62 | # for edge in mp_face_mesh.FACEMESH_LEFT_IRIS: 63 | # face_connection_spec[edge] = left_iris_draw 64 | for edge in mp_face_mesh.FACEMESH_RIGHT_EYE: 65 | face_connection_spec[edge] = right_eye_draw 66 | for edge in mp_face_mesh.FACEMESH_RIGHT_EYEBROW: 67 | face_connection_spec[edge] = right_eyebrow_draw 68 | # for edge in mp_face_mesh.FACEMESH_RIGHT_IRIS: 69 | # face_connection_spec[edge] = right_iris_draw 70 | # for edge in mp_face_mesh.FACEMESH_LIPS: 71 | # face_connection_spec[edge] = mouth_draw 72 | 73 | for edge in FACEMESH_LIPS_OUTER_BOTTOM_LEFT: 74 | face_connection_spec[edge] = mouth_draw_obl 75 | for edge in FACEMESH_LIPS_OUTER_BOTTOM_RIGHT: 76 | face_connection_spec[edge] = mouth_draw_obr 77 | for edge in FACEMESH_LIPS_INNER_BOTTOM_LEFT: 78 | face_connection_spec[edge] = mouth_draw_ibl 79 | for edge in FACEMESH_LIPS_INNER_BOTTOM_RIGHT: 80 | face_connection_spec[edge] = mouth_draw_ibr 81 | for edge in FACEMESH_LIPS_OUTER_TOP_LEFT: 82 | face_connection_spec[edge] = mouth_draw_otl 83 | for edge in FACEMESH_LIPS_OUTER_TOP_RIGHT: 84 | face_connection_spec[edge] = mouth_draw_otr 85 | for edge in FACEMESH_LIPS_INNER_TOP_LEFT: 86 | face_connection_spec[edge] = mouth_draw_itl 87 | for edge in FACEMESH_LIPS_INNER_TOP_RIGHT: 88 | face_connection_spec[edge] = mouth_draw_itr 89 | 90 | 91 | iris_landmark_spec = {468: right_iris_draw, 473: left_iris_draw} 92 | 93 | self.face_connection_spec = face_connection_spec 94 | def draw_pupils(self, image, landmark_list, drawing_spec, halfwidth: int = 2): 95 | """We have a custom function to draw the pupils because the mp.draw_landmarks method requires a parameter for all 96 | landmarks. 
Until our PR is merged into mediapipe, we need this separate method.""" 97 | if len(image.shape) != 3: 98 | raise ValueError("Input image must be H,W,C.") 99 | image_rows, image_cols, image_channels = image.shape 100 | if image_channels != 3: # BGR channels 101 | raise ValueError('Input image must contain three channel bgr data.') 102 | for idx, landmark in enumerate(landmark_list.landmark): 103 | if ( 104 | (landmark.HasField('visibility') and landmark.visibility < 0.9) or 105 | (landmark.HasField('presence') and landmark.presence < 0.5) 106 | ): 107 | continue 108 | if landmark.x >= 1.0 or landmark.x < 0 or landmark.y >= 1.0 or landmark.y < 0: 109 | continue 110 | image_x = int(image_cols*landmark.x) 111 | image_y = int(image_rows*landmark.y) 112 | draw_color = None 113 | if isinstance(drawing_spec, Mapping): 114 | if drawing_spec.get(idx) is None: 115 | continue 116 | else: 117 | draw_color = drawing_spec[idx].color 118 | elif isinstance(drawing_spec, DrawingSpec): 119 | draw_color = drawing_spec.color 120 | image[image_y-halfwidth:image_y+halfwidth, image_x-halfwidth:image_x+halfwidth, :] = draw_color 121 | 122 | 123 | 124 | def draw_landmarks(self, image_size, keypoints, normed=False): 125 | ini_size = [512, 512] 126 | image = np.zeros([ini_size[1], ini_size[0], 3], dtype=np.uint8) 127 | new_landmarks = landmark_pb2.NormalizedLandmarkList() 128 | for i in range(keypoints.shape[0]): 129 | landmark = new_landmarks.landmark.add() 130 | if normed: 131 | landmark.x = keypoints[i, 0] 132 | landmark.y = keypoints[i, 1] 133 | else: 134 | landmark.x = keypoints[i, 0] / image_size[0] 135 | landmark.y = keypoints[i, 1] / image_size[1] 136 | landmark.z = 1.0 137 | 138 | self.mp_drawing.draw_landmarks( 139 | image=image, 140 | landmark_list=new_landmarks, 141 | connections=self.face_connection_spec.keys(), 142 | landmark_drawing_spec=None, 143 | connection_drawing_spec=self.face_connection_spec 144 | ) 145 | # draw_pupils(image, face_landmarks, iris_landmark_spec, 2) 146 | image = cv2.resize(image, (image_size[0], image_size[1])) 147 | 148 | return image 149 | 150 | -------------------------------------------------------------------------------- /src/utils/frame_interpolation.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/dajes/frame-interpolation-pytorch 2 | import os 3 | import cv2 4 | import numpy as np 5 | import torch 6 | import bisect 7 | import shutil 8 | import pdb 9 | from tqdm import tqdm 10 | 11 | def init_frame_interpolation_model(): 12 | print("Initializing frame interpolation model") 13 | checkpoint_name = os.path.join("./pretrained_model/film_net_fp16.pt") 14 | 15 | model = torch.jit.load(checkpoint_name, map_location='cpu') 16 | model.eval() 17 | model = model.half() 18 | model = model.to(device="cuda") 19 | return model 20 | 21 | 22 | def batch_images_interpolation_tool(input_tensor, model, inter_frames=1): 23 | 24 | video_tensor = [] 25 | frame_num = input_tensor.shape[2] # bs, channel, frame, height, width 26 | 27 | for idx in tqdm(range(frame_num-1)): 28 | image1 = input_tensor[:,:,idx] 29 | image2 = input_tensor[:,:,idx+1] 30 | 31 | results = [image1, image2] 32 | 33 | inter_frames = int(inter_frames) 34 | idxes = [0, inter_frames + 1] 35 | remains = list(range(1, inter_frames + 1)) 36 | 37 | splits = torch.linspace(0, 1, inter_frames + 2) 38 | 39 | for _ in range(len(remains)): 40 | starts = splits[idxes[:-1]] 41 | ends = splits[idxes[1:]] 42 | distances = ((splits[None, remains] - starts[:, None]) / 
(ends[:, None] - starts[:, None]) - .5).abs() 43 | matrix = torch.argmin(distances).item() 44 | start_i, step = np.unravel_index(matrix, distances.shape) 45 | end_i = start_i + 1 46 | 47 | x0 = results[start_i] 48 | x1 = results[end_i] 49 | 50 | x0 = x0.half() 51 | x1 = x1.half() 52 | x0 = x0.cuda() 53 | x1 = x1.cuda() 54 | 55 | dt = x0.new_full((1, 1), (splits[remains[step]] - splits[idxes[start_i]])) / (splits[idxes[end_i]] - splits[idxes[start_i]]) 56 | 57 | with torch.no_grad(): 58 | prediction = model(x0, x1, dt) 59 | insert_position = bisect.bisect_left(idxes, remains[step]) 60 | idxes.insert(insert_position, remains[step]) 61 | results.insert(insert_position, prediction.clamp(0, 1).cpu().float()) 62 | del remains[step] 63 | 64 | for sub_idx in range(len(results)-1): 65 | video_tensor.append(results[sub_idx].unsqueeze(2)) 66 | 67 | video_tensor.append(input_tensor[:,:,-1].unsqueeze(2)) 68 | video_tensor = torch.cat(video_tensor, dim=2) 69 | return video_tensor -------------------------------------------------------------------------------- /src/utils/mp_models/blaze_face_short_range.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/src/utils/mp_models/blaze_face_short_range.tflite -------------------------------------------------------------------------------- /src/utils/mp_models/face_landmarker_v2_with_blendshapes.task: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/src/utils/mp_models/face_landmarker_v2_with_blendshapes.task -------------------------------------------------------------------------------- /src/utils/mp_models/pose_landmarker_heavy.task: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zejun-Yang/AniPortrait/7f6593fd971c95b65f603c15ef8ce2c3fc7f5404/src/utils/mp_models/pose_landmarker_heavy.task -------------------------------------------------------------------------------- /src/utils/mp_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import time 5 | from tqdm import tqdm 6 | import multiprocessing 7 | import glob 8 | 9 | import mediapipe as mp 10 | from mediapipe import solutions 11 | from mediapipe.framework.formats import landmark_pb2 12 | from mediapipe.tasks import python 13 | from mediapipe.tasks.python import vision 14 | from . import face_landmark 15 | 16 | CUR_DIR = os.path.dirname(__file__) 17 | 18 | 19 | class LMKExtractor(): 20 | def __init__(self, FPS=25): 21 | # Create an FaceLandmarker object. 
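        # Added note (not in the original source): the landmarker configured below
        # runs in IMAGE mode by default, with blendshape scores and the 4x4 facial
        # transformation matrix enabled, limited to a single face on the CPU
        # delegate. The separate BlazeFace short-range detector created afterwards
        # is only used as a one-face gate when the running mode is switched to VIDEO.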
22 | self.mode = mp.tasks.vision.FaceDetectorOptions.running_mode.IMAGE 23 | base_options = python.BaseOptions(model_asset_path=os.path.join(CUR_DIR, 'mp_models/face_landmarker_v2_with_blendshapes.task')) 24 | base_options.delegate = mp.tasks.BaseOptions.Delegate.CPU 25 | options = vision.FaceLandmarkerOptions(base_options=base_options, 26 | running_mode=self.mode, 27 | output_face_blendshapes=True, 28 | output_facial_transformation_matrixes=True, 29 | num_faces=1) 30 | self.detector = face_landmark.FaceLandmarker.create_from_options(options) 31 | self.last_ts = 0 32 | self.frame_ms = int(1000 / FPS) 33 | 34 | det_base_options = python.BaseOptions(model_asset_path=os.path.join(CUR_DIR, 'mp_models/blaze_face_short_range.tflite')) 35 | det_options = vision.FaceDetectorOptions(base_options=det_base_options) 36 | self.det_detector = vision.FaceDetector.create_from_options(det_options) 37 | 38 | 39 | def __call__(self, img): 40 | frame = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 41 | image = mp.Image(image_format=mp.ImageFormat.SRGB, data=frame) 42 | t0 = time.time() 43 | if self.mode == mp.tasks.vision.FaceDetectorOptions.running_mode.VIDEO: 44 | det_result = self.det_detector.detect(image) 45 | if len(det_result.detections) != 1: 46 | return None 47 | self.last_ts += self.frame_ms 48 | try: 49 | detection_result, mesh3d = self.detector.detect_for_video(image, timestamp_ms=self.last_ts) 50 | except: 51 | return None 52 | elif self.mode == mp.tasks.vision.FaceDetectorOptions.running_mode.IMAGE: 53 | # det_result = self.det_detector.detect(image) 54 | 55 | # if len(det_result.detections) != 1: 56 | # return None 57 | try: 58 | detection_result, mesh3d = self.detector.detect(image) 59 | except: 60 | return None 61 | 62 | 63 | bs_list = detection_result.face_blendshapes 64 | if len(bs_list) == 1: 65 | bs = bs_list[0] 66 | bs_values = [] 67 | for index in range(len(bs)): 68 | bs_values.append(bs[index].score) 69 | bs_values = bs_values[1:] # remove neutral 70 | trans_mat = detection_result.facial_transformation_matrixes[0] 71 | face_landmarks_list = detection_result.face_landmarks 72 | face_landmarks = face_landmarks_list[0] 73 | lmks = [] 74 | for index in range(len(face_landmarks)): 75 | x = face_landmarks[index].x 76 | y = face_landmarks[index].y 77 | z = face_landmarks[index].z 78 | lmks.append([x, y, z]) 79 | lmks = np.array(lmks) 80 | 81 | lmks3d = np.array(mesh3d.vertex_buffer) 82 | lmks3d = lmks3d.reshape(-1, 5)[:, :3] 83 | mp_tris = np.array(mesh3d.index_buffer).reshape(-1, 3) + 1 84 | 85 | return { 86 | "lmks": lmks, 87 | 'lmks3d': lmks3d, 88 | "trans_mat": trans_mat, 89 | 'faces': mp_tris, 90 | "bs": bs_values 91 | } 92 | else: 93 | # print('multiple faces in the image: {}'.format(img_path)) 94 | return None 95 | -------------------------------------------------------------------------------- /src/utils/pose_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | from scipy.spatial.transform import Rotation as R 5 | 6 | 7 | def create_perspective_matrix(aspect_ratio): 8 | kDegreesToRadians = np.pi / 180. 9 | near = 1 10 | far = 10000 11 | perspective_matrix = np.zeros(16, dtype=np.float32) 12 | 13 | # Standard perspective projection matrix calculations. 14 | f = 1.0 / np.tan(kDegreesToRadians * 63 / 2.) 15 | 16 | denom = 1.0 / (near - far) 17 | perspective_matrix[0] = f / aspect_ratio 18 | perspective_matrix[5] = f 19 | perspective_matrix[10] = (near + far) * denom 20 | perspective_matrix[11] = -1. 
21 | perspective_matrix[14] = 1. * far * near * denom 22 | 23 | # If the environment's origin point location is in the top left corner, 24 | # then skip additional flip along Y-axis is required to render correctly. 25 | 26 | perspective_matrix[5] *= -1. 27 | return perspective_matrix 28 | 29 | 30 | def project_points(points_3d, transformation_matrix, pose_vectors, image_shape): 31 | P = create_perspective_matrix(image_shape[1] / image_shape[0]).reshape(4, 4).T 32 | L, N, _ = points_3d.shape 33 | projected_points = np.zeros((L, N, 2)) 34 | for i in range(L): 35 | points_3d_frame = points_3d[i] 36 | ones = np.ones((points_3d_frame.shape[0], 1)) 37 | points_3d_homogeneous = np.hstack([points_3d_frame, ones]) 38 | transformed_points = points_3d_homogeneous @ (transformation_matrix @ euler_and_translation_to_matrix(pose_vectors[i][:3], pose_vectors[i][3:])).T @ P 39 | projected_points_frame = transformed_points[:, :2] / transformed_points[:, 3, np.newaxis] # -1 ~ 1 40 | projected_points_frame[:, 0] = (projected_points_frame[:, 0] + 1) * 0.5 * image_shape[1] 41 | projected_points_frame[:, 1] = (projected_points_frame[:, 1] + 1) * 0.5 * image_shape[0] 42 | projected_points[i] = projected_points_frame 43 | return projected_points 44 | 45 | 46 | def project_points_with_trans(points_3d, transformation_matrix, image_shape): 47 | P = create_perspective_matrix(image_shape[1] / image_shape[0]).reshape(4, 4).T 48 | L, N, _ = points_3d.shape 49 | projected_points = np.zeros((L, N, 2)) 50 | for i in range(L): 51 | points_3d_frame = points_3d[i] 52 | ones = np.ones((points_3d_frame.shape[0], 1)) 53 | points_3d_homogeneous = np.hstack([points_3d_frame, ones]) 54 | transformed_points = points_3d_homogeneous @ transformation_matrix[i].T @ P 55 | projected_points_frame = transformed_points[:, :2] / transformed_points[:, 3, np.newaxis] # -1 ~ 1 56 | projected_points_frame[:, 0] = (projected_points_frame[:, 0] + 1) * 0.5 * image_shape[1] 57 | projected_points_frame[:, 1] = (projected_points_frame[:, 1] + 1) * 0.5 * image_shape[0] 58 | projected_points[i] = projected_points_frame 59 | return projected_points 60 | 61 | 62 | def euler_and_translation_to_matrix(euler_angles, translation_vector): 63 | rotation = R.from_euler('xyz', euler_angles, degrees=True) 64 | rotation_matrix = rotation.as_matrix() 65 | 66 | matrix = np.eye(4) 67 | matrix[:3, :3] = rotation_matrix 68 | matrix[:3, 3] = translation_vector 69 | 70 | return matrix 71 | 72 | 73 | def matrix_to_euler_and_translation(matrix): 74 | rotation_matrix = matrix[:3, :3] 75 | translation_vector = matrix[:3, 3] 76 | rotation = R.from_matrix(rotation_matrix) 77 | euler_angles = rotation.as_euler('xyz', degrees=True) 78 | return euler_angles, translation_vector 79 | 80 | 81 | def smooth_pose_seq(pose_seq, window_size=5): 82 | smoothed_pose_seq = np.zeros_like(pose_seq) 83 | 84 | for i in range(len(pose_seq)): 85 | start = max(0, i - window_size // 2) 86 | end = min(len(pose_seq), i + window_size // 2 + 1) 87 | smoothed_pose_seq[i] = np.mean(pose_seq[start:end], axis=0) 88 | 89 | return smoothed_pose_seq -------------------------------------------------------------------------------- /src/utils/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | import os.path as osp 4 | import shutil 5 | import sys 6 | import cv2 7 | from pathlib import Path 8 | 9 | import av 10 | import numpy as np 11 | import torch 12 | import torchvision 13 | from einops import rearrange 14 | from PIL import Image 15 | 
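# Added note (not in the original source): utility helpers shared by the training
# and inference scripts: RNG seeding, dynamic module import, checkpoint pruning,
# video I/O via PyAV, and reference-face cropping. A quick, illustrative round
# trip using the bundled demo video (the output path is an assumption):
#
#     frames = read_frames("configs/inference/video/Aragaki_song.mp4")
#     fps = get_fps("configs/inference/video/Aragaki_song.mp4")
#     save_videos_from_pil(frames, "output/Aragaki_song_copy.mp4", fps=fps)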
16 | 17 | def seed_everything(seed): 18 | import random 19 | 20 | import numpy as np 21 | 22 | torch.manual_seed(seed) 23 | torch.cuda.manual_seed_all(seed) 24 | np.random.seed(seed % (2**32)) 25 | random.seed(seed) 26 | 27 | 28 | def import_filename(filename): 29 | spec = importlib.util.spec_from_file_location("mymodule", filename) 30 | module = importlib.util.module_from_spec(spec) 31 | sys.modules[spec.name] = module 32 | spec.loader.exec_module(module) 33 | return module 34 | 35 | 36 | def delete_additional_ckpt(base_path, num_keep): 37 | dirs = [] 38 | for d in os.listdir(base_path): 39 | if d.startswith("checkpoint-"): 40 | dirs.append(d) 41 | num_tot = len(dirs) 42 | if num_tot <= num_keep: 43 | return 44 | # ensure ckpt is sorted and delete the ealier! 45 | del_dirs = sorted(dirs, key=lambda x: int(x.split("-")[-1]))[: num_tot - num_keep] 46 | for d in del_dirs: 47 | path_to_dir = osp.join(base_path, d) 48 | if osp.exists(path_to_dir): 49 | shutil.rmtree(path_to_dir) 50 | 51 | 52 | def save_videos_from_pil(pil_images, path, fps=8): 53 | import av 54 | 55 | save_fmt = Path(path).suffix 56 | os.makedirs(os.path.dirname(path), exist_ok=True) 57 | width, height = pil_images[0].size 58 | 59 | if save_fmt == ".mp4": 60 | codec = "libx264" 61 | container = av.open(path, "w") 62 | stream = container.add_stream(codec, rate=fps) 63 | 64 | stream.width = width 65 | stream.height = height 66 | 67 | for pil_image in pil_images: 68 | # pil_image = Image.fromarray(image_arr).convert("RGB") 69 | av_frame = av.VideoFrame.from_image(pil_image) 70 | container.mux(stream.encode(av_frame)) 71 | container.mux(stream.encode()) 72 | container.close() 73 | 74 | elif save_fmt == ".gif": 75 | pil_images[0].save( 76 | fp=path, 77 | format="GIF", 78 | append_images=pil_images[1:], 79 | save_all=True, 80 | duration=(1 / fps * 1000), 81 | loop=0, 82 | ) 83 | else: 84 | raise ValueError("Unsupported file type. 
Use .mp4 or .gif.") 85 | 86 | 87 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8): 88 | videos = rearrange(videos, "b c t h w -> t b c h w") 89 | height, width = videos.shape[-2:] 90 | outputs = [] 91 | 92 | for x in videos: 93 | x = torchvision.utils.make_grid(x, nrow=n_rows) # (c h w) 94 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) # (h w c) 95 | if rescale: 96 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 97 | x = (x * 255).numpy().astype(np.uint8) 98 | x = Image.fromarray(x) 99 | 100 | outputs.append(x) 101 | 102 | os.makedirs(os.path.dirname(path), exist_ok=True) 103 | 104 | save_videos_from_pil(outputs, path, fps) 105 | 106 | 107 | def read_frames(video_path): 108 | container = av.open(video_path) 109 | 110 | video_stream = next(s for s in container.streams if s.type == "video") 111 | frames = [] 112 | for packet in container.demux(video_stream): 113 | for frame in packet.decode(): 114 | image = Image.frombytes( 115 | "RGB", 116 | (frame.width, frame.height), 117 | frame.to_rgb().to_ndarray(), 118 | ) 119 | frames.append(image) 120 | 121 | return frames 122 | 123 | 124 | def get_fps(video_path): 125 | container = av.open(video_path) 126 | video_stream = next(s for s in container.streams if s.type == "video") 127 | fps = video_stream.average_rate 128 | container.close() 129 | return fps 130 | 131 | def crop_face(img, lmk_extractor, expand=1.5): 132 | result = lmk_extractor(img) # cv2 BGR 133 | 134 | if result is None: 135 | return None 136 | 137 | H, W, _ = img.shape 138 | lmks = result['lmks'] 139 | lmks[:, 0] *= W 140 | lmks[:, 1] *= H 141 | 142 | x_min = np.min(lmks[:, 0]) 143 | x_max = np.max(lmks[:, 0]) 144 | y_min = np.min(lmks[:, 1]) 145 | y_max = np.max(lmks[:, 1]) 146 | 147 | width = x_max - x_min 148 | height = y_max - y_min 149 | 150 | if width*height >= W*H*0.15: 151 | if W == H: 152 | return img 153 | size = min(H, W) 154 | offset = int((max(H, W) - size)/2) 155 | if size == H: 156 | return img[:, offset:-offset] 157 | else: 158 | return img[offset:-offset, :] 159 | else: 160 | center_x = x_min + width / 2 161 | center_y = y_min + height / 2 162 | 163 | width *= expand 164 | height *= expand 165 | 166 | size = max(width, height) 167 | 168 | x_min = int(center_x - size / 2) 169 | x_max = int(center_x + size / 2) 170 | y_min = int(center_y - size / 2) 171 | y_max = int(center_y + size / 2) 172 | 173 | top = max(0, -y_min) 174 | bottom = max(0, y_max - img.shape[0]) 175 | left = max(0, -x_min) 176 | right = max(0, x_max - img.shape[1]) 177 | img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0) 178 | 179 | cropped_img = img[y_min + top:y_max + top, x_min + left:x_max + left] 180 | 181 | return cropped_img --------------------------------------------------------------------------------