├── .gitignore ├── LICENSE ├── NOTICE ├── README.md ├── app.py ├── assets └── mini_program_maliang.png ├── configs ├── inference │ ├── inference_v1.yaml │ ├── inference_v2.yaml │ ├── pose_images │ │ └── pose-1.png │ ├── pose_videos │ │ ├── anyone-video-1_kps.mp4 │ │ ├── anyone-video-2_kps.mp4 │ │ ├── anyone-video-4_kps.mp4 │ │ └── anyone-video-5_kps.mp4 │ ├── ref_images │ │ ├── anyone-1.png │ │ ├── anyone-10.png │ │ ├── anyone-11.png │ │ ├── anyone-2.png │ │ ├── anyone-3.png │ │ └── anyone-5.png │ ├── talkinghead_images │ │ ├── 1.png │ │ ├── 2.png │ │ ├── 3.png │ │ ├── 4.png │ │ └── 5.png │ └── talkinghead_videos │ │ ├── 1.mp4 │ │ ├── 2.mp4 │ │ ├── 3.mp4 │ │ └── 4.mp4 ├── prompts │ ├── animation.yaml │ ├── inference_reenact.yaml │ └── test_cases.py └── train │ ├── stage1.yaml │ └── stage2.yaml ├── requirements.txt ├── scripts ├── lmks2vid.py └── pose2vid.py ├── src ├── __init__.py ├── dataset │ ├── dance_image.py │ └── dance_video.py ├── dwpose │ ├── __init__.py │ ├── onnxdet.py │ ├── onnxpose.py │ ├── util.py │ └── wholebody.py ├── models │ ├── attention.py │ ├── motion_module.py │ ├── mutual_self_attention.py │ ├── pose_guider.py │ ├── resnet.py │ ├── transformer_2d.py │ ├── transformer_3d.py │ ├── unet_2d_blocks.py │ ├── unet_2d_condition.py │ ├── unet_3d.py │ └── unet_3d_blocks.py ├── pipelines │ ├── __init__.py │ ├── context.py │ ├── pipeline_lmks2vid_long.py │ ├── pipeline_pose2img.py │ ├── pipeline_pose2vid.py │ ├── pipeline_pose2vid_long.py │ └── utils.py └── utils │ └── util.py ├── tools ├── download_weights.py ├── extract_dwpose_from_vid.py ├── extract_meta_info.py ├── facetracker_api.py └── vid2pose.py ├── train_stage_1.py └── train_stage_2.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | pretrained_weights/ 3 | output/ 4 | .venv/ 5 | mlruns/ 6 | data/ 7 | 8 | *.pth 9 | *.pt 10 | *.pkl 11 | *.bin -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright @2023-2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved. 2 | 3 | Apache License 4 | Version 2.0, January 2004 5 | http://www.apache.org/licenses/ 6 | 7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 8 | 9 | 1. Definitions. 10 | 11 | "License" shall mean the terms and conditions for use, reproduction, 12 | and distribution as defined by Sections 1 through 9 of this document. 13 | 14 | "Licensor" shall mean the copyright owner or entity authorized by 15 | the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all 18 | other entities that control, are controlled by, or are under common 19 | control with that entity. For the purposes of this definition, 20 | "control" means (i) the power, direct or indirect, to cause the 21 | direction or management of such entity, whether by contract or 22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 23 | outstanding shares, or (iii) beneficial ownership of such entity. 24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity 26 | exercising permissions granted by this License. 27 | 28 | "Source" form shall mean the preferred form for making modifications, 29 | including but not limited to software source code, documentation 30 | source, and configuration files. 
31 | 32 | "Object" form shall mean any form resulting from mechanical 33 | transformation or translation of a Source form, including but 34 | not limited to compiled object code, generated documentation, 35 | and conversions to other media types. 36 | 37 | "Work" shall mean the work of authorship, whether in Source or 38 | Object form, made available under the License, as indicated by a 39 | copyright notice that is included in or attached to the work 40 | (an example is provided in the Appendix below). 41 | 42 | "Derivative Works" shall mean any work, whether in Source or Object 43 | form, that is based on (or derived from) the Work and for which the 44 | editorial revisions, annotations, elaborations, or other modifications 45 | represent, as a whole, an original work of authorship. For the purposes 46 | of this License, Derivative Works shall not include works that remain 47 | separable from, or merely link (or bind by name) to the interfaces of, 48 | the Work and Derivative Works thereof. 49 | 50 | "Contribution" shall mean any work of authorship, including 51 | the original version of the Work and any modifications or additions 52 | to that Work or Derivative Works thereof, that is intentionally 53 | submitted to Licensor for inclusion in the Work by the copyright owner 54 | or by an individual or Legal Entity authorized to submit on behalf of 55 | the copyright owner. For the purposes of this definition, "submitted" 56 | means any form of electronic, verbal, or written communication sent 57 | to the Licensor or its representatives, including but not limited to 58 | communication on electronic mailing lists, source code control systems, 59 | and issue tracking systems that are managed by, or on behalf of, the 60 | Licensor for the purpose of discussing and improving the Work, but 61 | excluding communication that is conspicuously marked or otherwise 62 | designated in writing by the copyright owner as "Not a Contribution." 63 | 64 | "Contributor" shall mean Licensor and any individual or Legal Entity 65 | on behalf of whom a Contribution has been received by Licensor and 66 | subsequently incorporated within the Work. 67 | 68 | 2. Grant of Copyright License. Subject to the terms and conditions of 69 | this License, each Contributor hereby grants to You a perpetual, 70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 71 | copyright license to reproduce, prepare Derivative Works of, 72 | publicly display, publicly perform, sublicense, and distribute the 73 | Work and such Derivative Works in Source or Object form. 74 | 75 | 3. Grant of Patent License. Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | (except as stated in this section) patent license to make, have made, 79 | use, offer to sell, sell, import, and otherwise transfer the Work, 80 | where such license applies only to those patent claims licensable 81 | by such Contributor that are necessarily infringed by their 82 | Contribution(s) alone or by combination of their Contribution(s) 83 | with the Work to which such Contribution(s) was submitted. 
If You 84 | institute patent litigation against any entity (including a 85 | cross-claim or counterclaim in a lawsuit) alleging that the Work 86 | or a Contribution incorporated within the Work constitutes direct 87 | or contributory patent infringement, then any patent licenses 88 | granted to You under this License for that Work shall terminate 89 | as of the date such litigation is filed. 90 | 91 | 4. Redistribution. You may reproduce and distribute copies of the 92 | Work or Derivative Works thereof in any medium, with or without 93 | modifications, and in Source or Object form, provided that You 94 | meet the following conditions: 95 | 96 | (a) You must give any other recipients of the Work or 97 | Derivative Works a copy of this License; and 98 | 99 | (b) You must cause any modified files to carry prominent notices 100 | stating that You changed the files; and 101 | 102 | (c) You must retain, in the Source form of any Derivative Works 103 | that You distribute, all copyright, patent, trademark, and 104 | attribution notices from the Source form of the Work, 105 | excluding those notices that do not pertain to any part of 106 | the Derivative Works; and 107 | 108 | (d) If the Work includes a "NOTICE" text file as part of its 109 | distribution, then any Derivative Works that You distribute must 110 | include a readable copy of the attribution notices contained 111 | within such NOTICE file, excluding those notices that do not 112 | pertain to any part of the Derivative Works, in at least one 113 | of the following places: within a NOTICE text file distributed 114 | as part of the Derivative Works; within the Source form or 115 | documentation, if provided along with the Derivative Works; or, 116 | within a display generated by the Derivative Works, if and 117 | wherever such third-party notices normally appear. The contents 118 | of the NOTICE file are for informational purposes only and 119 | do not modify the License. You may add Your own attribution 120 | notices within Derivative Works that You distribute, alongside 121 | or as an addendum to the NOTICE text from the Work, provided 122 | that such additional attribution notices cannot be construed 123 | as modifying the License. 124 | 125 | You may add Your own copyright statement to Your modifications and 126 | may provide additional or different license terms and conditions 127 | for use, reproduction, or distribution of Your modifications, or 128 | for any such Derivative Works as a whole, provided Your use, 129 | reproduction, and distribution of the Work otherwise complies with 130 | the conditions stated in this License. 131 | 132 | 5. Submission of Contributions. Unless You explicitly state otherwise, 133 | any Contribution intentionally submitted for inclusion in the Work 134 | by You to the Licensor shall be under the terms and conditions of 135 | this License, without any additional terms or conditions. 136 | Notwithstanding the above, nothing herein shall supersede or modify 137 | the terms of any separate license agreement you may have executed 138 | with Licensor regarding such Contributions. 139 | 140 | 6. Trademarks. This License does not grant permission to use the trade 141 | names, trademarks, service marks, or product names of the Licensor, 142 | except as required for reasonable and customary use in describing the 143 | origin of the Work and reproducing the content of the NOTICE file. 144 | 145 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 146 | agreed to in writing, Licensor provides the Work (and each 147 | Contributor provides its Contributions) on an "AS IS" BASIS, 148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 149 | implied, including, without limitation, any warranties or conditions 150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 151 | PARTICULAR PURPOSE. You are solely responsible for determining the 152 | appropriateness of using or redistributing the Work and assume any 153 | risks associated with Your exercise of permissions under this License. 154 | 155 | 8. Limitation of Liability. In no event and under no legal theory, 156 | whether in tort (including negligence), contract, or otherwise, 157 | unless required by applicable law (such as deliberate and grossly 158 | negligent acts) or agreed to in writing, shall any Contributor be 159 | liable to You for damages, including any direct, indirect, special, 160 | incidental, or consequential damages of any character arising as a 161 | result of this License or out of the use or inability to use the 162 | Work (including but not limited to damages for loss of goodwill, 163 | work stoppage, computer failure or malfunction, or any and all 164 | other commercial damages or losses), even if such Contributor 165 | has been advised of the possibility of such damages. 166 | 167 | 9. Accepting Warranty or Additional Liability. While redistributing 168 | the Work or Derivative Works thereof, You may choose to offer, 169 | and charge a fee for, acceptance of support, warranty, indemnity, 170 | or other liability obligations and/or rights consistent with this 171 | License. However, in accepting such obligations, You may act only 172 | on Your own behalf and on Your sole responsibility, not on behalf 173 | of any other Contributor, and only if You agree to indemnify, 174 | defend, and hold each Contributor harmless for any liability 175 | incurred by, or claims asserted against, such Contributor by reason 176 | of your accepting any such warranty or additional liability. 177 | 178 | END OF TERMS AND CONDITIONS 179 | 180 | APPENDIX: How to apply the Apache License to your work. 181 | 182 | To apply the Apache License to your work, attach the following 183 | boilerplate notice, with the fields enclosed by brackets "[]" 184 | replaced with your own identifying information. (Don't include 185 | the brackets!) The text should be enclosed in the appropriate 186 | comment syntax for the file format. We also recommend that a 187 | file or class name and description of purpose be included on the 188 | same "printed page" as the copyright notice for easier 189 | identification within third-party archives. 190 | 191 | Copyright [yyyy] [name of copyright owner] 192 | 193 | Licensed under the Apache License, Version 2.0 (the "License"); 194 | you may not use this file except in compliance with the License. 195 | You may obtain a copy of the License at 196 | 197 | http://www.apache.org/licenses/LICENSE-2.0 198 | 199 | Unless required by applicable law or agreed to in writing, software 200 | distributed under the License is distributed on an "AS IS" BASIS, 201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 202 | See the License for the specific language governing permissions and 203 | limitations under the License. 
-------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | ============================================================== 2 | This repo also contains various third-party components and some code modified from other repos under other open source licenses. The following sections contain licensing information for such third-party libraries. 3 | 4 | ----------------------------- 5 | majic-animate 6 | BSD 3-Clause License 7 | Copyright (c) Bytedance Inc. 8 | 9 | ----------------------------- 10 | animatediff 11 | Apache License, Version 2.0 12 | 13 | ----------------------------- 14 | Dwpose 15 | Apache License, Version 2.0 16 | 17 | ----------------------------- 18 | inference pipeline for animatediff-cli-prompt-travel 19 | animatediff-cli-prompt-travel 20 | Apache License, Version 2.0 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🤗 Introduction 2 | **update** 🔥🔥🔥 We propose a face reenactment method based on our AnimateAnyone pipeline: the facial landmarks of a driving video are used to control the pose of a given source image while keeping the identity of the source image. Specifically, we disentangle head attitude (including eye blinking) and mouth motion from the landmarks of the driving video, so the expression and movements of the source face can be controlled precisely. We release our inference codes and pretrained models for face reenactment!! 3 | 4 | 5 | **update** 🏋️🏋️🏋️ We release our training codes!! Now you can train your own AnimateAnyone models. See [here](#train) for more details. Have fun! 6 | 7 | **update** 🔥🔥🔥 We launch a HuggingFace Spaces demo of Moore-AnimateAnyone [here](https://huggingface.co/spaces/xunsong/Moore-AnimateAnyone)!! 8 | 9 | This repository reproduces [AnimateAnyone](https://github.com/HumanAIGC/AnimateAnyone). To align with the results demonstrated in the original paper, we adopt various approaches and tricks, which may differ somewhat from the paper and from another [implementation](https://github.com/guoqincode/Open-AnimateAnyone). 10 | 11 | It's worth noting that this is a very preliminary version, aiming to approximate the performance (roughly 80% in our tests) shown in [AnimateAnyone](https://github.com/HumanAIGC/AnimateAnyone). 12 | 13 | We will continue to develop it, and we also welcome feedback and ideas from the community. The enhanced version will also be launched on our [MoBi MaLiang](https://maliang.mthreads.com/) AIGC platform, running on our own full-featured GPU S4000 cloud computing platform. 14 | 15 | # 📝 Release Plans 16 | 17 | - [x] Inference codes and pretrained weights of AnimateAnyone 18 | - [x] Training scripts of AnimateAnyone 19 | - [x] Inference codes and pretrained weights of face reenactment 20 | - [ ] Training scripts of face reenactment 21 | - [ ] Inference scripts of audio-driven portrait video generation 22 | - [ ] Training scripts of audio-driven portrait video generation 23 | # 🎞️ Examples 24 | 25 | ## AnimateAnyone 26 | 27 | Here are some AnimateAnyone results we generated, at a resolution of 512x768. 28 | 29 | https://github.com/MooreThreads/Moore-AnimateAnyone/assets/138439222/f0454f30-6726-4ad4-80a7-5b7a15619057 30 | 31 | https://github.com/MooreThreads/Moore-AnimateAnyone/assets/138439222/337ff231-68a3-4760-a9f9-5113654acf48 32 |
54 | **Limitation**: We observe the following shortcomings in the current version: 55 | 1. Artifacts may appear in the background when the reference image has a clean background. 56 | 2. Suboptimal results may arise when there is a scale mismatch between the reference image and the keypoints. We have yet to implement the preprocessing techniques mentioned in the [paper](https://arxiv.org/pdf/2311.17117.pdf). 57 | 3. Some flickering and jittering may occur when the motion sequence is subtle or the scene is static. 58 | 59 | 60 | 61 | These issues will be addressed and improved in the near future. We appreciate your patience! 62 | 63 | ## Face Reenactment 64 | 65 | Here are some results we generated, at a resolution of 512x512. 66 |
88 | 89 | # ⚒️ Installation 90 | 91 | ## Build Environment 92 | 93 | We recommend Python `>=3.10` and CUDA `11.7`. Then build the environment as follows: 94 | 95 | ```shell 96 | # [Optional] Create a virtual env 97 | python -m venv .venv 98 | source .venv/bin/activate 99 | # Install with pip: 100 | pip install -r requirements.txt 101 | # For face landmark extraction 102 | git clone https://github.com/emilianavt/OpenSeeFace.git 103 | ``` 104 | 105 | ## Download weights 106 | 107 | **Automatic downloading**: You can run the following command to download the weights automatically: 108 | 109 | ```shell 110 | python tools/download_weights.py 111 | ``` 112 | 113 | Weights will be placed under the `./pretrained_weights` directory. The whole downloading process may take a long time. 114 | 115 | **Manual downloading**: You can also download the weights manually, which involves the following steps: 116 | 117 | 1. Download our trained AnimateAnyone [weights](https://huggingface.co/patrolli/AnimateAnyone/tree/main), which include four parts: `denoising_unet.pth`, `reference_unet.pth`, `pose_guider.pth` and `motion_module.pth`. 118 | 119 | 2. Download our trained [weights](https://pan.baidu.com/s/1lS5CynyNfYlDbjowKKfG8g?pwd=crci) for face reenactment, and place these weights under `pretrained_weights`. 120 | 121 | 3. Download the pretrained weights of the base models and other components: 122 | - [StableDiffusion V1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5) 123 | - [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse) 124 | - [image_encoder](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/tree/main/image_encoder) 125 | 126 | 4. Download the DWPose weights (`dw-ll_ucoco_384.onnx`, `yolox_l.onnx`) following [this](https://github.com/IDEA-Research/DWPose?tab=readme-ov-file#-dwpose-for-controlnet). 127 | 128 | Finally, these weights should be organized as follows: 129 | 130 | ```text 131 | ./pretrained_weights/ 132 | |-- DWPose 133 | | |-- dw-ll_ucoco_384.onnx 134 | | `-- yolox_l.onnx 135 | |-- image_encoder 136 | | |-- config.json 137 | | `-- pytorch_model.bin 138 | |-- denoising_unet.pth 139 | |-- motion_module.pth 140 | |-- pose_guider.pth 141 | |-- reference_unet.pth 142 | |-- sd-vae-ft-mse 143 | | |-- config.json 144 | | |-- diffusion_pytorch_model.bin 145 | | `-- diffusion_pytorch_model.safetensors 146 | |-- reenact 147 | | |-- denoising_unet.pth 148 | | |-- reference_unet.pth 149 | | |-- pose_guider1.pth 150 | | |-- pose_guider2.pth 151 | `-- stable-diffusion-v1-5 152 | |-- feature_extractor 153 | | `-- preprocessor_config.json 154 | |-- model_index.json 155 | |-- unet 156 | | |-- config.json 157 | | `-- diffusion_pytorch_model.bin 158 | `-- v1-inference.yaml 159 | ``` 160 | 161 | Note: If you have already installed some of the pretrained models, such as `StableDiffusion V1.5`, you can specify their paths in the config file (e.g. `./configs/prompts/animation.yaml`). 162 | 163 | # 🚀 Training and Inference 164 | 165 | ## Inference of AnimateAnyone 166 | 167 | Here is the CLI command for running the inference script: 168 | 169 | ```shell 170 | python -m scripts.pose2vid --config ./configs/prompts/animation.yaml -W 512 -H 784 -L 64 171 | ``` 172 | 173 | You can refer to the format of `animation.yaml` to add your own reference images or pose videos, as sketched below.
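For reference, here is a minimal sketch of what an edited `animation.yaml` might look like. The keys follow the config shipped in this repo, the sample image/video paths are the assets under `./configs/inference`, and `/path/to/your/stable-diffusion-v1-5` is only a placeholder for wherever your existing local checkpoint lives:

```yaml
# Sketch of ./configs/prompts/animation.yaml (keys as in the shipped config).
pretrained_base_model_path: "/path/to/your/stable-diffusion-v1-5"  # point to an already-installed SD 1.5 copy
pretrained_vae_path: "./pretrained_weights/sd-vae-ft-mse"
image_encoder_path: "./pretrained_weights/image_encoder"
denoising_unet_path: "./pretrained_weights/denoising_unet.pth"
reference_unet_path: "./pretrained_weights/reference_unet.pth"
pose_guider_path: "./pretrained_weights/pose_guider.pth"
motion_module_path: "./pretrained_weights/motion_module.pth"

inference_config: "./configs/inference/inference_v2.yaml"
weight_dtype: 'fp16'

test_cases:
  # each entry maps one reference image to a list of pose (keypoint) videos
  "./configs/inference/ref_images/anyone-1.png":
    - "./configs/inference/pose_videos/anyone-video-1_kps.mp4"
```

The `pose2vid` command above should then generate one result for every reference-image/pose-video pair listed under `test_cases`.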
To convert a raw video into a pose video (keypoint sequence), you can run the following command: 174 | 175 | ```shell 176 | python tools/vid2pose.py --video_path /path/to/your/video.mp4 177 | ``` 178 | 179 | ## Inference of Face Reenactment 180 | Here is the CLI command for running the inference script: 181 | 182 | ```shell 183 | python -m scripts.lmks2vid --config ./configs/prompts/inference_reenact.yaml --driving_video_path YOUR_OWN_DRIVING_VIDEO_PATH --source_image_path YOUR_OWN_SOURCE_IMAGE_PATH 184 | ``` 185 | We provide some face images in `./configs/inference/talkinghead_images`, and some face videos in `./configs/inference/talkinghead_videos` for inference. 186 | 187 | ## Training of AnimateAnyone 188 | 189 | Note: package dependencies have been updated; you may need to upgrade your environment via `pip install -r requirements.txt` before training. 190 | 191 | ### Data Preparation 192 | 193 | Extract keypoints from the raw videos: 194 | 195 | ```shell 196 | python tools/extract_dwpose_from_vid.py --video_root /path/to/your/video_dir 197 | ``` 198 | 199 | Extract the meta info of the dataset: 200 | 201 | ```shell 202 | python tools/extract_meta_info.py --root_path /path/to/your/video_dir --dataset_name anyone 203 | ``` 204 | 205 | Update the following lines in the training config file: 206 | 207 | ```yaml 208 | data: 209 | meta_paths: 210 | - "./data/anyone_meta.json" 211 | ``` 212 | 213 | ### Stage1 214 | 215 | Put the [openpose controlnet weights](https://huggingface.co/lllyasviel/control_v11p_sd15_openpose/tree/main) under `./pretrained_weights`; they are used to initialize the pose_guider. 216 | 217 | Put [sd-image-variation](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/tree/main) under `./pretrained_weights`; it is used to initialize the unet weights. 218 | 219 | Run the command: 220 | 221 | ```shell 222 | accelerate launch train_stage_1.py --config configs/train/stage1.yaml 223 | ``` 224 | 225 | ### Stage2 226 | 227 | Put the pretrained motion module weights `mm_sd_v15_v2.ckpt` ([download link](https://huggingface.co/guoyww/animatediff/blob/main/mm_sd_v15_v2.ckpt)) under `./pretrained_weights`. 228 | 229 | Specify the stage1 training weights in the config file `stage2.yaml`, for example: 230 | 231 | ```yaml 232 | stage1_ckpt_dir: './exp_output/stage1' 233 | stage1_ckpt_step: 30000 234 | ``` 235 | 236 | Run the command: 237 | 238 | ```shell 239 | accelerate launch train_stage_2.py --config configs/train/stage2.yaml 240 | ``` 241 | 242 | # 🎨 Gradio Demo 243 | 244 | **HuggingFace Demo**: We launch a quick preview demo of Moore-AnimateAnyone at [HuggingFace Spaces](https://huggingface.co/spaces/xunsong/Moore-AnimateAnyone)!! 245 | We appreciate the assistance provided by the HuggingFace team in setting up this demo. 246 | 247 | To reduce waiting time, we limit the size (width, height, and length) and the number of inference steps when generating videos. 248 | 249 | If you have your own GPU resources (>= 16GB VRAM), you can run a local Gradio app via the following command: 250 | 251 | `python app.py` 252 | 253 | # Community Contributions 254 | 255 | - Installation for Windows users: [Moore-AnimateAnyone-for-windows](https://github.com/sdbds/Moore-AnimateAnyone-for-windows) 256 | 257 | # 🖌️ Try on Mobi MaLiang 258 | 259 | We will launch this model on our [MoBi MaLiang](https://maliang.mthreads.com/) AIGC platform, running on our own full-featured GPU S4000 cloud computing platform. Mobi MaLiang has now integrated various AIGC applications and functionalities (e.g. text-to-image, controllable generation...). 
You can experience it by [clicking this link](https://maliang.mthreads.com/) or scanning the QR code below via WeChat! 260 | 261 |

262 | <img src="assets/mini_program_maliang.png" alt="Mobi MaLiang mini-program QR code"/> 263 | 264 |

265 | 266 | # ⚖️ Disclaimer 267 | 268 | This project is intended for academic research, and we explicitly disclaim any responsibility for user-generated content. Users are solely liable for their actions while using the generative model. The project contributors have no legal affiliation with, nor accountability for, users' behaviors. It is imperative to use the generative model responsibly, adhering to both ethical and legal standards. 269 | 270 | # 🙏🏻 Acknowledgements 271 | 272 | We first thank the authors of [AnimateAnyone](). Additionally, we would like to thank the contributors to the [majic-animate](https://github.com/magic-research/magic-animate), [animatediff](https://github.com/guoyww/AnimateDiff) and [Open-AnimateAnyone](https://github.com/guoqincode/Open-AnimateAnyone) repositories, for their open research and exploration. Furthermore, our repo incorporates some codes from [dwpose](https://github.com/IDEA-Research/DWPose) and [animatediff-cli-prompt-travel](https://github.com/s9roll7/animatediff-cli-prompt-travel/), and we extend our thanks to them as well. 273 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from datetime import datetime 4 | 5 | import gradio as gr 6 | import numpy as np 7 | import torch 8 | from diffusers import AutoencoderKL, DDIMScheduler 9 | from einops import repeat 10 | from omegaconf import OmegaConf 11 | from PIL import Image 12 | from torchvision import transforms 13 | from transformers import CLIPVisionModelWithProjection 14 | 15 | from src.models.pose_guider import PoseGuider 16 | from src.models.unet_2d_condition import UNet2DConditionModel 17 | from src.models.unet_3d import UNet3DConditionModel 18 | from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline 19 | from src.utils.util import get_fps, read_frames, save_videos_grid 20 | 21 | 22 | class AnimateController: 23 | def __init__( 24 | self, 25 | config_path="./configs/prompts/animation.yaml", 26 | weight_dtype=torch.float16, 27 | ): 28 | # Read pretrained weights path from config 29 | self.config = OmegaConf.load(config_path) 30 | self.pipeline = None 31 | self.weight_dtype = weight_dtype 32 | 33 | def animate( 34 | self, 35 | ref_image, 36 | pose_video_path, 37 | width=512, 38 | height=768, 39 | length=24, 40 | num_inference_steps=25, 41 | cfg=3.5, 42 | seed=123, 43 | ): 44 | generator = torch.manual_seed(seed) 45 | if isinstance(ref_image, np.ndarray): 46 | ref_image = Image.fromarray(ref_image) 47 | if self.pipeline is None: 48 | vae = AutoencoderKL.from_pretrained( 49 | self.config.pretrained_vae_path, 50 | ).to("cuda", dtype=self.weight_dtype) 51 | 52 | reference_unet = UNet2DConditionModel.from_pretrained( 53 | self.config.pretrained_base_model_path, 54 | subfolder="unet", 55 | ).to(dtype=self.weight_dtype, device="cuda") 56 | 57 | inference_config_path = self.config.inference_config 58 | infer_config = OmegaConf.load(inference_config_path) 59 | denoising_unet = UNet3DConditionModel.from_pretrained_2d( 60 | self.config.pretrained_base_model_path, 61 | self.config.motion_module_path, 62 | subfolder="unet", 63 | unet_additional_kwargs=infer_config.unet_additional_kwargs, 64 | ).to(dtype=self.weight_dtype, device="cuda") 65 | 66 | pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to( 67 | dtype=self.weight_dtype, device="cuda" 68 | ) 69 | 70 | image_enc = CLIPVisionModelWithProjection.from_pretrained( 71 | 
self.config.image_encoder_path 72 | ).to(dtype=self.weight_dtype, device="cuda") 73 | sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs) 74 | scheduler = DDIMScheduler(**sched_kwargs) 75 | 76 | # load pretrained weights 77 | denoising_unet.load_state_dict( 78 | torch.load(self.config.denoising_unet_path, map_location="cpu"), 79 | strict=False, 80 | ) 81 | reference_unet.load_state_dict( 82 | torch.load(self.config.reference_unet_path, map_location="cpu"), 83 | ) 84 | pose_guider.load_state_dict( 85 | torch.load(self.config.pose_guider_path, map_location="cpu"), 86 | ) 87 | 88 | pipe = Pose2VideoPipeline( 89 | vae=vae, 90 | image_encoder=image_enc, 91 | reference_unet=reference_unet, 92 | denoising_unet=denoising_unet, 93 | pose_guider=pose_guider, 94 | scheduler=scheduler, 95 | ) 96 | pipe = pipe.to("cuda", dtype=self.weight_dtype) 97 | self.pipeline = pipe 98 | 99 | pose_images = read_frames(pose_video_path) 100 | src_fps = get_fps(pose_video_path) 101 | 102 | pose_list = [] 103 | pose_tensor_list = [] 104 | pose_transform = transforms.Compose( 105 | [transforms.Resize((height, width)), transforms.ToTensor()] 106 | ) 107 | for pose_image_pil in pose_images[:length]: 108 | pose_list.append(pose_image_pil) 109 | pose_tensor_list.append(pose_transform(pose_image_pil)) 110 | 111 | video = self.pipeline( 112 | ref_image, 113 | pose_list, 114 | width=width, 115 | height=height, 116 | video_length=length, 117 | num_inference_steps=num_inference_steps, 118 | guidance_scale=cfg, 119 | generator=generator, 120 | ).videos 121 | 122 | ref_image_tensor = pose_transform(ref_image) # (c, h, w) 123 | ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(0) # (1, c, 1, h, w) 124 | ref_image_tensor = repeat( 125 | ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=length 126 | ) 127 | pose_tensor = torch.stack(pose_tensor_list, dim=0) # (f, c, h, w) 128 | pose_tensor = pose_tensor.transpose(0, 1) 129 | pose_tensor = pose_tensor.unsqueeze(0) 130 | video = torch.cat([ref_image_tensor, pose_tensor, video], dim=0) 131 | 132 | save_dir = f"./output/gradio" 133 | if not os.path.exists(save_dir): 134 | os.makedirs(save_dir, exist_ok=True) 135 | date_str = datetime.now().strftime("%Y%m%d") 136 | time_str = datetime.now().strftime("%H%M") 137 | out_path = os.path.join(save_dir, f"{date_str}T{time_str}.mp4") 138 | save_videos_grid( 139 | video, 140 | out_path, 141 | n_rows=3, 142 | fps=src_fps, 143 | ) 144 | 145 | torch.cuda.empty_cache() 146 | 147 | return out_path 148 | 149 | 150 | controller = AnimateController() 151 | 152 | 153 | def ui(): 154 | with gr.Blocks() as demo: 155 | gr.Markdown( 156 | """ 157 | # Moore-AnimateAnyone Demo 158 | """ 159 | ) 160 | animation = gr.Video( 161 | format="mp4", 162 | label="Animation Results", 163 | height=448, 164 | autoplay=True, 165 | ) 166 | 167 | with gr.Row(): 168 | reference_image = gr.Image(label="Reference Image") 169 | motion_sequence = gr.Video( 170 | format="mp4", label="Motion Sequence", height=512 171 | ) 172 | 173 | with gr.Column(): 174 | width_slider = gr.Slider( 175 | label="Width", minimum=448, maximum=768, value=512, step=64 176 | ) 177 | height_slider = gr.Slider( 178 | label="Height", minimum=512, maximum=1024, value=768, step=64 179 | ) 180 | length_slider = gr.Slider( 181 | label="Video Length", minimum=24, maximum=128, value=24, step=24 182 | ) 183 | with gr.Row(): 184 | seed_textbox = gr.Textbox(label="Seed", value=-1) 185 | seed_button = gr.Button( 186 | value="\U0001F3B2", elem_classes="toolbutton" 187 | ) 188 | 
seed_button.click( 189 | fn=lambda: gr.Textbox.update(value=random.randint(1, 1e8)), 190 | inputs=[], 191 | outputs=[seed_textbox], 192 | ) 193 | with gr.Row(): 194 | sampling_steps = gr.Slider( 195 | label="Sampling steps", 196 | value=25, 197 | info="default: 25", 198 | step=5, 199 | maximum=30, 200 | minimum=10, 201 | ) 202 | guidance_scale = gr.Slider( 203 | label="Guidance scale", 204 | value=3.5, 205 | info="default: 3.5", 206 | step=0.5, 207 | maximum=10, 208 | minimum=2.0, 209 | ) 210 | submit = gr.Button("Animate") 211 | 212 | def read_video(video): 213 | return video 214 | 215 | def read_image(image): 216 | return Image.fromarray(image) 217 | 218 | # when user uploads a new video 219 | motion_sequence.upload(read_video, motion_sequence, motion_sequence) 220 | # when `first_frame` is updated 221 | reference_image.upload(read_image, reference_image, reference_image) 222 | # when the `submit` button is clicked 223 | submit.click( 224 | controller.animate, 225 | [ 226 | reference_image, 227 | motion_sequence, 228 | width_slider, 229 | height_slider, 230 | length_slider, 231 | sampling_steps, 232 | guidance_scale, 233 | seed_textbox, 234 | ], 235 | animation, 236 | ) 237 | 238 | # Examples 239 | gr.Markdown("## Examples") 240 | gr.Examples( 241 | examples=[ 242 | [ 243 | "./configs/inference/ref_images/anyone-5.png", 244 | "./configs/inference/pose_videos/anyone-video-2_kps.mp4", 245 | ], 246 | [ 247 | "./configs/inference/ref_images/anyone-10.png", 248 | "./configs/inference/pose_videos/anyone-video-1_kps.mp4", 249 | ], 250 | [ 251 | "./configs/inference/ref_images/anyone-2.png", 252 | "./configs/inference/pose_videos/anyone-video-5_kps.mp4", 253 | ], 254 | ], 255 | inputs=[reference_image, motion_sequence], 256 | outputs=animation, 257 | ) 258 | 259 | return demo 260 | 261 | 262 | demo = ui() 263 | demo.launch(share=True) 264 | -------------------------------------------------------------------------------- /assets/mini_program_maliang.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/assets/mini_program_maliang.png -------------------------------------------------------------------------------- /configs/inference/inference_v1.yaml: -------------------------------------------------------------------------------- 1 | unet_additional_kwargs: 2 | unet_use_cross_frame_attention: false 3 | unet_use_temporal_attention: false 4 | use_motion_module: true 5 | motion_module_resolutions: [1,2,4,8] 6 | motion_module_mid_block: false 7 | motion_module_decoder_only: false 8 | motion_module_type: "Vanilla" 9 | 10 | motion_module_kwargs: 11 | num_attention_heads: 8 12 | num_transformer_block: 1 13 | attention_block_types: [ "Temporal_Self", "Temporal_Self" ] 14 | temporal_position_encoding: true 15 | temporal_position_encoding_max_len: 24 16 | temporal_attention_dim_div: 1 17 | 18 | noise_scheduler_kwargs: 19 | beta_start: 0.00085 20 | beta_end: 0.012 21 | beta_schedule: "linear" 22 | steps_offset: 1 23 | clip_sample: False -------------------------------------------------------------------------------- /configs/inference/inference_v2.yaml: -------------------------------------------------------------------------------- 1 | unet_additional_kwargs: 2 | use_inflated_groupnorm: true 3 | unet_use_cross_frame_attention: false 4 | unet_use_temporal_attention: false 5 | use_motion_module: true 6 | motion_module_resolutions: 7 | - 1 8 | - 2 9 | - 4 10 | - 8 11 | 
motion_module_mid_block: true 12 | motion_module_decoder_only: false 13 | motion_module_type: Vanilla 14 | motion_module_kwargs: 15 | num_attention_heads: 8 16 | num_transformer_block: 1 17 | attention_block_types: 18 | - Temporal_Self 19 | - Temporal_Self 20 | temporal_position_encoding: true 21 | temporal_position_encoding_max_len: 32 22 | temporal_attention_dim_div: 1 23 | 24 | noise_scheduler_kwargs: 25 | beta_start: 0.00085 26 | beta_end: 0.012 27 | beta_schedule: "linear" 28 | clip_sample: false 29 | steps_offset: 1 30 | ### Zero-SNR params 31 | prediction_type: "v_prediction" 32 | rescale_betas_zero_snr: True 33 | timestep_spacing: "trailing" 34 | 35 | sampler: DDIM -------------------------------------------------------------------------------- /configs/inference/pose_images/pose-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/pose_images/pose-1.png -------------------------------------------------------------------------------- /configs/inference/pose_videos/anyone-video-1_kps.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/pose_videos/anyone-video-1_kps.mp4 -------------------------------------------------------------------------------- /configs/inference/pose_videos/anyone-video-2_kps.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/pose_videos/anyone-video-2_kps.mp4 -------------------------------------------------------------------------------- /configs/inference/pose_videos/anyone-video-4_kps.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/pose_videos/anyone-video-4_kps.mp4 -------------------------------------------------------------------------------- /configs/inference/pose_videos/anyone-video-5_kps.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/pose_videos/anyone-video-5_kps.mp4 -------------------------------------------------------------------------------- /configs/inference/ref_images/anyone-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/ref_images/anyone-1.png -------------------------------------------------------------------------------- /configs/inference/ref_images/anyone-10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/ref_images/anyone-10.png -------------------------------------------------------------------------------- /configs/inference/ref_images/anyone-11.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/ref_images/anyone-11.png -------------------------------------------------------------------------------- /configs/inference/ref_images/anyone-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/ref_images/anyone-2.png -------------------------------------------------------------------------------- /configs/inference/ref_images/anyone-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/ref_images/anyone-3.png -------------------------------------------------------------------------------- /configs/inference/ref_images/anyone-5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/ref_images/anyone-5.png -------------------------------------------------------------------------------- /configs/inference/talkinghead_images/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/talkinghead_images/1.png -------------------------------------------------------------------------------- /configs/inference/talkinghead_images/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/talkinghead_images/2.png -------------------------------------------------------------------------------- /configs/inference/talkinghead_images/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/talkinghead_images/3.png -------------------------------------------------------------------------------- /configs/inference/talkinghead_images/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/talkinghead_images/4.png -------------------------------------------------------------------------------- /configs/inference/talkinghead_images/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/talkinghead_images/5.png -------------------------------------------------------------------------------- /configs/inference/talkinghead_videos/1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/talkinghead_videos/1.mp4 -------------------------------------------------------------------------------- /configs/inference/talkinghead_videos/2.mp4: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/talkinghead_videos/2.mp4 -------------------------------------------------------------------------------- /configs/inference/talkinghead_videos/3.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/talkinghead_videos/3.mp4 -------------------------------------------------------------------------------- /configs/inference/talkinghead_videos/4.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/talkinghead_videos/4.mp4 -------------------------------------------------------------------------------- /configs/prompts/animation.yaml: -------------------------------------------------------------------------------- 1 | pretrained_base_model_path: "./pretrained_weights/stable-diffusion-v1-5/" 2 | pretrained_vae_path: "./pretrained_weights/sd-vae-ft-mse" 3 | image_encoder_path: "./pretrained_weights/image_encoder" 4 | denoising_unet_path: "./pretrained_weights/denoising_unet.pth" 5 | reference_unet_path: "./pretrained_weights/reference_unet.pth" 6 | pose_guider_path: "./pretrained_weights/pose_guider.pth" 7 | motion_module_path: "./pretrained_weights/motion_module.pth" 8 | 9 | inference_config: "./configs/inference/inference_v2.yaml" 10 | weight_dtype: 'fp16' 11 | 12 | test_cases: 13 | "./configs/inference/ref_images/anyone-2.png": 14 | - "./configs/inference/pose_videos/anyone-video-2_kps.mp4" 15 | - "./configs/inference/pose_videos/anyone-video-5_kps.mp4" 16 | "./configs/inference/ref_images/anyone-10.png": 17 | - "./configs/inference/pose_videos/anyone-video-1_kps.mp4" 18 | - "./configs/inference/pose_videos/anyone-video-2_kps.mp4" 19 | "./configs/inference/ref_images/anyone-11.png": 20 | - "./configs/inference/pose_videos/anyone-video-1_kps.mp4" 21 | - "./configs/inference/pose_videos/anyone-video-2_kps.mp4" 22 | "./configs/inference/ref_images/anyone-3.png": 23 | - "./configs/inference/pose_videos/anyone-video-2_kps.mp4" 24 | - "./configs/inference/pose_videos/anyone-video-5_kps.mp4" 25 | "./configs/inference/ref_images/anyone-5.png": 26 | - "./configs/inference/pose_videos/anyone-video-2_kps.mp4" 27 | -------------------------------------------------------------------------------- /configs/prompts/inference_reenact.yaml: -------------------------------------------------------------------------------- 1 | pretrained_base_model_path: "./pretrained_weights/stable-diffusion-v1-5/" 2 | pretrained_vae_path: "./pretrained_weights/sd-vae-ft-mse" 3 | image_encoder_path: "./pretrained_weights/image_encoder" 4 | denoising_unet_path: "./pretrained_weights/reenact/denoising_unet.pth" 5 | reference_unet_path: "./pretrained_weights/reenact/reference_unet.pth" 6 | pose_guider1_path: "./pretrained_weights/reenact/pose_guider1.pth" 7 | pose_guider2_path: "./pretrained_weights/reenact/pose_guider2.pth" 8 | unet_additional_kwargs: 9 | task_type: "reenact" 10 | mode: "read" # "read" 11 | use_inflated_groupnorm: true 12 | unet_use_cross_frame_attention: false 13 | unet_use_temporal_attention: false 14 | use_motion_module: true 15 | motion_module_resolutions: 16 | - 1 17 | - 2 18 | - 4 19 | - 8 20 
| motion_module_mid_block: true 21 | motion_module_decoder_only: false 22 | motion_module_type: Vanilla 23 | motion_module_kwargs: 24 | num_attention_heads: 8 25 | num_transformer_block: 1 26 | attention_block_types: 27 | - Temporal_Self 28 | - Temporal_Self 29 | temporal_position_encoding: true 30 | temporal_position_encoding_max_len: 32 31 | temporal_attention_dim_div: 1 32 | 33 | noise_scheduler_kwargs: 34 | beta_start: 0.00085 35 | beta_end: 0.012 36 | beta_schedule: "linear" 37 | # beta_schedule: "scaled_linear" 38 | clip_sample: false 39 | # set_alpha_to_one: False 40 | # skip_prk_steps: true 41 | steps_offset: 1 42 | ### Zero-SNR params 43 | # prediction_type: "v_prediction" 44 | # rescale_betas_zero_snr: True 45 | # timestep_spacing: "trailing" 46 | 47 | weight_dtype: float16 48 | sampler: DDIM 49 | -------------------------------------------------------------------------------- /configs/prompts/test_cases.py: -------------------------------------------------------------------------------- 1 | TestCasesDict = { 2 | 0: [ 3 | { 4 | "./configs/inference/ref_images/anyone-2.png": [ 5 | "./configs/inference/pose_videos/anyone-video-2_kps.mp4", 6 | "./configs/inference/pose_videos/anyone-video-5_kps.mp4", 7 | ] 8 | }, 9 | { 10 | "./configs/inference/ref_images/anyone-10.png": [ 11 | "./configs/inference/pose_videos/anyone-video-1_kps.mp4", 12 | "./configs/inference/pose_videos/anyone-video-2_kps.mp4", 13 | ] 14 | }, 15 | { 16 | "./configs/inference/ref_images/anyone-11.png": [ 17 | "./configs/inference/pose_videos/anyone-video-1_kps.mp4", 18 | "./configs/inference/pose_videos/anyone-video-2_kps.mp4", 19 | ] 20 | }, 21 | { 22 | "./configs/inference/anyone-ref-3.png": [ 23 | "./configs/inference/pose_videos/anyone-video-2_kps.mp4", 24 | "./configs/inference/pose_videos/anyone-video-5_kps.mp4", 25 | ] 26 | }, 27 | { 28 | "./configs/inference/ref_images/anyone-5.png": [ 29 | "./configs/inference/pose_videos/anyone-video-2_kps.mp4" 30 | ] 31 | }, 32 | ], 33 | } 34 | -------------------------------------------------------------------------------- /configs/train/stage1.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_bs: 4 3 | train_width: 768 4 | train_height: 768 5 | meta_paths: 6 | - "./data/fashion_meta.json" 7 | # Margin of frame indexes between ref and tgt images 8 | sample_margin: 30 9 | 10 | solver: 11 | gradient_accumulation_steps: 1 12 | mixed_precision: 'fp16' 13 | enable_xformers_memory_efficient_attention: True 14 | gradient_checkpointing: False 15 | max_train_steps: 30000 16 | max_grad_norm: 1.0 17 | # lr 18 | learning_rate: 1.0e-5 19 | scale_lr: False 20 | lr_warmup_steps: 1 21 | lr_scheduler: 'constant' 22 | 23 | # optimizer 24 | use_8bit_adam: False 25 | adam_beta1: 0.9 26 | adam_beta2: 0.999 27 | adam_weight_decay: 1.0e-2 28 | adam_epsilon: 1.0e-8 29 | 30 | val: 31 | validation_steps: 200 32 | 33 | 34 | noise_scheduler_kwargs: 35 | num_train_timesteps: 1000 36 | beta_start: 0.00085 37 | beta_end: 0.012 38 | beta_schedule: "scaled_linear" 39 | steps_offset: 1 40 | clip_sample: false 41 | 42 | base_model_path: './pretrained_weights/sd-image-variations-diffusers' 43 | vae_model_path: './pretrained_weights/sd-vae-ft-mse' 44 | image_encoder_path: './pretrained_weights/sd-image-variations-diffusers/image_encoder' 45 | controlnet_openpose_path: './pretrained_weights/control_v11p_sd15_openpose/diffusion_pytorch_model.bin' 46 | 47 | weight_dtype: 'fp16' # [fp16, fp32] 48 | uncond_ratio: 0.1 49 | noise_offset: 0.05 50 | 
snr_gamma: 5.0 51 | enable_zero_snr: True 52 | pose_guider_pretrain: True 53 | 54 | seed: 12580 55 | resume_from_checkpoint: '' 56 | checkpointing_steps: 2000 57 | save_model_epoch_interval: 5 58 | exp_name: 'stage1' 59 | output_dir: './exp_output' -------------------------------------------------------------------------------- /configs/train/stage2.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_bs: 1 3 | train_width: 512 4 | train_height: 512 5 | meta_paths: 6 | - "./data/fashion_meta.json" 7 | sample_rate: 4 8 | n_sample_frames: 24 9 | 10 | solver: 11 | gradient_accumulation_steps: 1 12 | mixed_precision: 'fp16' 13 | enable_xformers_memory_efficient_attention: True 14 | gradient_checkpointing: True 15 | max_train_steps: 10000 16 | max_grad_norm: 1.0 17 | # lr 18 | learning_rate: 1e-5 19 | scale_lr: False 20 | lr_warmup_steps: 1 21 | lr_scheduler: 'constant' 22 | 23 | # optimizer 24 | use_8bit_adam: True 25 | adam_beta1: 0.9 26 | adam_beta2: 0.999 27 | adam_weight_decay: 1.0e-2 28 | adam_epsilon: 1.0e-8 29 | 30 | val: 31 | validation_steps: 20 32 | 33 | 34 | noise_scheduler_kwargs: 35 | num_train_timesteps: 1000 36 | beta_start: 0.00085 37 | beta_end: 0.012 38 | beta_schedule: "linear" 39 | steps_offset: 1 40 | clip_sample: false 41 | 42 | base_model_path: './pretrained_weights/stable-diffusion-v1-5' 43 | vae_model_path: './pretrained_weights/sd-vae-ft-mse' 44 | image_encoder_path: './pretrained_weights/sd-image-variations-diffusers/image_encoder' 45 | mm_path: './pretrained_weights/mm_sd_v15_v2.ckpt' 46 | 47 | weight_dtype: 'fp16' # [fp16, fp32] 48 | uncond_ratio: 0.1 49 | noise_offset: 0.05 50 | snr_gamma: 5.0 51 | enable_zero_snr: True 52 | stage1_ckpt_dir: './exp_output/stage1' 53 | stage1_ckpt_step: 980 54 | 55 | seed: 12580 56 | resume_from_checkpoint: '' 57 | checkpointing_steps: 2000 58 | exp_name: 'stage2' 59 | output_dir: './exp_output' -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.21.0 2 | av==11.0.0 3 | clip @ https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip#sha256=b5842c25da441d6c581b53a5c60e0c2127ebafe0f746f8e15561a006c6c3be6a 4 | decord==0.6.0 5 | diffusers==0.24.0 6 | einops==0.4.1 7 | gradio==3.41.2 8 | gradio_client==0.5.0 9 | imageio==2.33.0 10 | imageio-ffmpeg==0.4.9 11 | numpy==1.23.5 12 | omegaconf==2.2.3 13 | onnxruntime-gpu==1.16.3 14 | open-clip-torch==2.20.0 15 | opencv-contrib-python==4.8.1.78 16 | opencv-python==4.8.1.78 17 | Pillow==9.5.0 18 | scikit-image==0.21.0 19 | scikit-learn==1.3.2 20 | scipy==1.11.4 21 | torch==2.0.1 22 | torchdiffeq==0.2.3 23 | torchmetrics==1.2.1 24 | torchsde==0.2.5 25 | torchvision==0.15.2 26 | tqdm==4.66.1 27 | transformers==4.30.2 28 | mlflow==2.9.2 29 | xformers==0.0.22 30 | controlnet-aux==0.0.7 -------------------------------------------------------------------------------- /scripts/lmks2vid.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | from datetime import datetime 5 | from pathlib import Path 6 | from typing import List 7 | 8 | import av 9 | import cv2 10 | import numpy as np 11 | import torch 12 | 13 | # 初始化模型 14 | import torchvision 15 | from diffusers import AutoencoderKL, DDIMScheduler 16 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline 17 | from einops 
import rearrange, repeat 18 | from omegaconf import OmegaConf 19 | from PIL import Image 20 | from torchvision import transforms 21 | from transformers import ( 22 | CLIPImageProcessor, 23 | CLIPTextModel, 24 | CLIPTokenizer, 25 | CLIPVisionModel, 26 | CLIPVisionModelWithProjection, 27 | ) 28 | 29 | import sys 30 | from src.models.unet_3d import UNet3DConditionModel 31 | from src.pipelines.pipeline_lmks2vid_long import Pose2VideoPipeline 32 | from src.models.pose_guider import PoseGuider 33 | from src.utils.util import get_fps, read_frames, save_videos_grid 34 | from tools.facetracker_api import face_image 35 | 36 | 37 | def parse_args(): 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument( 40 | "--config", type=str, help="Path of inference configs", 41 | default="./configs/prompts/inference_reenact.yaml" 42 | ) 43 | parser.add_argument( 44 | "--save_dir", type=str, help="Path of save results", 45 | default="./output/stage2_infer" 46 | ) 47 | parser.add_argument( 48 | "--source_image_path", type=str, help="Path of source image", 49 | default="", 50 | ) 51 | parser.add_argument( 52 | "--driving_video_path", type=str, help="Path of driving video", 53 | default="", 54 | ) 55 | parser.add_argument( 56 | "--batch_size", 57 | type=int, 58 | default=320, 59 | help="Checkpoint step of pretrained model", 60 | ) 61 | parser.add_argument("--mask_ratio", type=float, default=0.55) # 0.55~0.6 62 | parser.add_argument("-W", type=int, default=512) 63 | parser.add_argument("-H", type=int, default=512) 64 | parser.add_argument("-L", type=int, default=24) 65 | parser.add_argument("--seed", type=int, default=42) 66 | parser.add_argument("--cfg", type=float, default=3.5) 67 | parser.add_argument("--steps", type=int, default=30) 68 | parser.add_argument("--fps", type=int, default=25) 69 | args = parser.parse_args() 70 | 71 | return args 72 | 73 | 74 | def lmks_vis(img, lms): 75 | # Visualize the mouth, nose, and entire face based on landmarks 76 | h, w, c = img.shape 77 | lms = lms[:, :2] 78 | mouth = lms[48:66] 79 | nose = lms[27:36] 80 | color = (0, 255, 0) 81 | # Center mouth and nose 82 | x_c, y_c = np.mean(lms[:, 0]), np.mean(lms[:, 1]) 83 | h_c, w_c = h // 2, w // 2 84 | img_face, img_mouth, img_nose = img.copy(), img.copy(), img.copy() 85 | for pt_num, (x, y) in enumerate(mouth): 86 | x = x - (x_c - w_c) 87 | y = y - (y_c - h_c) 88 | x = int(x + 0.5) 89 | y = int(y + 0.5) 90 | cv2.circle(img_mouth, (y, x), 1, color, -1) 91 | for pt_num, (x, y) in enumerate(nose): 92 | x = x - (x_c - w_c) 93 | y = y - (y_c - h_c) 94 | x = int(x + 0.5) 95 | y = int(y + 0.5) 96 | cv2.circle(img_nose, (y, x), 1, color, -1) 97 | for pt_num, (x, y) in enumerate(lms): 98 | x = int(x + 0.5) 99 | y = int(y + 0.5) 100 | if pt_num >= 66: 101 | color = (255, 255, 0) 102 | else: 103 | color = (0, 255, 0) 104 | cv2.circle(img_face, (y, x), 1, color, -1) 105 | return img_face, img_mouth, img_nose 106 | 107 | 108 | def batch_rearrange(pose_len, batch_size=24): 109 | # To rearrange the pose sequence based on batch size 110 | batch_ind_list = [] 111 | for i in range(0, pose_len, batch_size): 112 | if i + batch_size < pose_len: 113 | batch_ind_list.append(list(range(i, i + batch_size))) 114 | else: 115 | batch_ind_list.append(list(range(i, min(i + batch_size, pose_len)))) 116 | return batch_ind_list 117 | 118 | 119 | def lmks_video_extract(video_path): 120 | # To extract the landmark sequence of video (single face video) 121 | video_stream = cv2.VideoCapture(video_path) 122 | lmks_list, frames = [], [] 123 | while 1: 124 | 
still_reading, frame = video_stream.read() 125 | if not still_reading: 126 | video_stream.release() 127 | break 128 | h, w, c = frame.shape 129 | lmk_img, lmks = face_image(frame) 130 | if lmks is not None: 131 | lmks_list.append(lmks) 132 | frames.append(frame) 133 | return frames, np.array(lmks_list), [h, w] 134 | 135 | 136 | def adjust_pose(src_lms_list, src_size, ref_lms, ref_size): 137 | # To align the center of source landmarks based on reference landmark 138 | new_src_lms_list = [] 139 | ref_lms = ref_lms[:, :2] 140 | src_lms = src_lms_list[0][:, :2] 141 | ref_lms[:, 0] = ref_lms[:, 0] / ref_size[1] 142 | ref_lms[:, 1] = ref_lms[:, 1] / ref_size[0] 143 | src_lms[:, 0] = src_lms[:, 0] / src_size[1] 144 | src_lms[:, 1] = src_lms[:, 1] / src_size[0] 145 | ref_cx, ref_cy = np.mean(ref_lms[:, 0]), np.mean(ref_lms[:, 1]) 146 | src_cx, src_cy = np.mean(src_lms[:, 0]), np.mean(src_lms[:, 1]) 147 | for item in src_lms_list: 148 | item = item[:, :2] 149 | item[:, 0] = item[:, 0] - int((src_cx - ref_cx)) * src_size[1] 150 | item[:, 1] = item[:, 1] - int((src_cy - ref_cy)) * src_size[0] 151 | new_src_lms_list.append(item) 152 | return np.array(new_src_lms_list) 153 | 154 | 155 | def main(): 156 | args = parse_args() 157 | infer_config = OmegaConf.load(args.config) 158 | 159 | # base_model_path = "./pretrained_weights/huggingface-models/sd-image-variations-diffusers/" 160 | base_model_path = infer_config.pretrained_base_model_path 161 | weight_dtype = torch.float16 162 | 163 | image_enc = CLIPVisionModelWithProjection.from_pretrained( 164 | # "./pretrained_weights/huggingface-models/sd-image-variations-diffusers/image_encoder" 165 | infer_config.image_encoder_path 166 | ).to(dtype=weight_dtype, device="cuda") 167 | vae = AutoencoderKL.from_pretrained( 168 | # "./pretrained_weights/huggingface-models/sd-vae-ft-mse" 169 | infer_config.pretrained_vae_path 170 | ).to("cuda", dtype=weight_dtype) 171 | # initial reference unet, denoise unet, pose guider 172 | reference_unet = UNet3DConditionModel.from_pretrained_2d( 173 | base_model_path, 174 | "", 175 | subfolder="unet", 176 | unet_additional_kwargs={ 177 | "task_type": "reenact", 178 | "use_motion_module": False, 179 | "unet_use_temporal_attention": False, 180 | "mode": "write", 181 | }, 182 | ).to(device="cuda", dtype=weight_dtype) 183 | denoising_unet = UNet3DConditionModel.from_pretrained_2d( 184 | base_model_path, 185 | "./pretrained_weights/mm_sd_v15_v2.ckpt", 186 | subfolder="unet", 187 | unet_additional_kwargs=OmegaConf.to_container( 188 | infer_config.unet_additional_kwargs 189 | ), 190 | # mm_zero_proj_out=True, 191 | ).to(device="cuda") 192 | pose_guider1 = PoseGuider( 193 | conditioning_embedding_channels=320, block_out_channels=(16, 32, 96, 256) 194 | ).to(device="cuda", dtype=weight_dtype) 195 | pose_guider2 = PoseGuider( 196 | conditioning_embedding_channels=320, block_out_channels=(16, 32, 96, 256) 197 | ).to(device="cuda", dtype=weight_dtype) 198 | print("------------------initial all networks------------------") 199 | # load model from pretrained models 200 | denoising_unet.load_state_dict( 201 | torch.load( 202 | infer_config.denoising_unet_path, 203 | map_location="cpu", 204 | ), 205 | strict=True, 206 | ) 207 | reference_unet.load_state_dict( 208 | torch.load( 209 | infer_config.reference_unet_path, 210 | map_location="cpu", 211 | ) 212 | ) 213 | pose_guider1.load_state_dict( 214 | torch.load( 215 | infer_config.pose_guider1_path, 216 | map_location="cpu", 217 | ) 218 | ) 219 | pose_guider2.load_state_dict( 220 | torch.load( 221 
| infer_config.pose_guider2_path, 222 | map_location="cpu", 223 | ) 224 | ) 225 | print("---------load pretrained denoising unet, reference unet and pose guider----------") 226 | # scheduler 227 | enable_zero_snr = True 228 | sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs) 229 | if enable_zero_snr: 230 | sched_kwargs.update( 231 | rescale_betas_zero_snr=True, 232 | timestep_spacing="trailing", 233 | prediction_type="v_prediction", 234 | ) 235 | scheduler = DDIMScheduler(**sched_kwargs) 236 | pipe = Pose2VideoPipeline( 237 | vae=vae, 238 | image_encoder=image_enc, 239 | reference_unet=reference_unet, 240 | denoising_unet=denoising_unet, 241 | pose_guider1=pose_guider1, 242 | pose_guider2=pose_guider2, 243 | scheduler=scheduler, 244 | ) 245 | pipe = pipe.to("cuda", dtype=weight_dtype) 246 | height, width, clip_length = args.H, args.W, args.L 247 | generator = torch.manual_seed(42) 248 | date_str = datetime.now().strftime("%Y%m%d") 249 | save_dir = Path(f"{args.save_dir}/{date_str}") 250 | save_dir.mkdir(exist_ok=True, parents=True) 251 | 252 | ref_image_path, pose_video_path = args.source_image_path, args.driving_video_path 253 | ref_name = Path(ref_image_path).stem 254 | pose_name = Path(pose_video_path).stem 255 | ref_image_pil = Image.open(ref_image_path).convert("RGB") 256 | ref_image = cv2.imread(ref_image_path) 257 | ref_h, ref_w, c = ref_image.shape 258 | ref_pose, ref_pose_lms = face_image(ref_image) 259 | # To extract landmarks from driving video 260 | pose_frames, pose_lms_list, pose_size = lmks_video_extract(pose_video_path) 261 | pose_lms_list = adjust_pose(pose_lms_list, pose_size, ref_pose_lms, [ref_h, ref_w]) 262 | pose_h, pose_w = int(pose_size[0]), int(pose_size[1]) 263 | pose_len = pose_lms_list.shape[0] 264 | # Truncating the video tail if its frames less than 24 to obtain stable effect. 265 | pose_len = pose_len // 24 * 24 266 | batch_index_list = batch_rearrange(pose_len, args.batch_size) 267 | pose_transform = transforms.Compose( 268 | [transforms.Resize((height, width)), transforms.ToTensor()] 269 | ) 270 | videos = [] 271 | zero_map = np.zeros_like(ref_pose) 272 | zero_map = cv2.resize(zero_map, (pose_w, pose_h)) 273 | for batch_index in batch_index_list: 274 | pose_list, pose_up_list, pose_down_list = [], [], [] 275 | pose_frame_list = [] 276 | pose_tensor_list, pose_up_tensor_list, pose_down_tensor_list = [], [], [] 277 | batch_len = len(batch_index) 278 | for pose_idx in batch_index: 279 | pose_lms = pose_lms_list[pose_idx] 280 | pose_frame = pose_frames[pose_idx][:, :, ::-1] 281 | pose_image, pose_mouth_image, _ = lmks_vis(zero_map, pose_lms) 282 | h, w, c = pose_image.shape 283 | pose_up_image = pose_image.copy() 284 | pose_up_image[int(h * args.mask_ratio):, :, :] = 0. 
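# --- Illustrative note (not part of the original script) ---
# The assignment above keeps only the upper portion of the rendered landmark
# map: rows below h * mask_ratio are zeroed out, so `pose_up_image` carries
# head pose and eye motion while the mouth region comes separately from the
# `pose_mouth_image` returned by lmks_vis(). A minimal sketch of the same
# split, assuming a hypothetical (h, w, 3) landmark rendering `lmk_map` and
# a mask ratio of 0.6:
#
#     cut = int(h * 0.6)
#     upper = lmk_map.copy(); upper[cut:, :, :] = 0   # head + eyes only
#     lower = lmk_map.copy(); lower[:cut, :, :] = 0   # mouth region only
#
# Passing the two maps through the two separate pose guiders is what lets the
# pipeline call further below disentangle head attitude from mouth motion.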
285 | pose_image_pil = Image.fromarray(pose_image) 286 | pose_frame = Image.fromarray(pose_frame) 287 | pose_up_pil = Image.fromarray(pose_up_image) 288 | pose_mouth_pil = Image.fromarray(pose_mouth_image) 289 | pose_list.append(pose_image_pil) 290 | pose_up_list.append(pose_up_pil) 291 | pose_down_list.append(pose_mouth_pil) 292 | pose_tensor_list.append(pose_transform(pose_image_pil)) 293 | pose_up_tensor_list.append(pose_transform(pose_up_pil)) 294 | pose_down_tensor_list.append(pose_transform(pose_mouth_pil)) 295 | pose_frame_list.append(pose_transform(pose_frame)) 296 | pose_tensor = torch.stack(pose_tensor_list, dim=0) # (f, c, h, w) 297 | pose_tensor = pose_tensor.transpose(0, 1) 298 | pose_tensor = pose_tensor.unsqueeze(0) 299 | pose_frames_tensor = torch.stack(pose_frame_list, dim=0) # (f, c, h, w) 300 | pose_frames_tensor = pose_frames_tensor.transpose(0, 1) 301 | pose_frames_tensor = pose_frames_tensor.unsqueeze(0) 302 | ref_image_tensor = pose_transform(ref_image_pil) # (c, h, w) 303 | ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(0) # (1, c, 1, h, w) 304 | ref_image_tensor = repeat( 305 | ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=batch_len 306 | ) 307 | # To disentangle head attitude control (including eyes blink) and mouth motion control 308 | pipeline_output = pipe( 309 | ref_image_pil, 310 | pose_up_list, 311 | pose_down_list, 312 | width, 313 | height, 314 | batch_len, 315 | 20, 316 | 3.5, 317 | generator=generator, 318 | ) 319 | video = pipeline_output.videos 320 | video = torch.cat([ref_image_tensor, pose_frames_tensor, video], dim=0) 321 | videos.append(video) 322 | videos = torch.cat(videos, dim=2) 323 | time_str = datetime.now().strftime("%H%M") 324 | save_video_path = f"{save_dir}/{ref_name}_{pose_name}_{time_str}.mp4" 325 | save_videos_grid( 326 | videos, 327 | save_video_path, 328 | n_rows=3, 329 | fps=args.fps, 330 | ) 331 | print("infer results: {}".format(save_video_path)) 332 | del pipe 333 | torch.cuda.empty_cache() 334 | 335 | 336 | if __name__ == "__main__": 337 | main() 338 | -------------------------------------------------------------------------------- /scripts/pose2vid.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from datetime import datetime 4 | from pathlib import Path 5 | from typing import List 6 | 7 | import av 8 | import numpy as np 9 | import torch 10 | import torchvision 11 | from diffusers import AutoencoderKL, DDIMScheduler 12 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline 13 | from einops import repeat 14 | from omegaconf import OmegaConf 15 | from PIL import Image 16 | from torchvision import transforms 17 | from transformers import CLIPVisionModelWithProjection 18 | 19 | from configs.prompts.test_cases import TestCasesDict 20 | from src.models.pose_guider import PoseGuider 21 | from src.models.unet_2d_condition import UNet2DConditionModel 22 | from src.models.unet_3d import UNet3DConditionModel 23 | from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline 24 | from src.utils.util import get_fps, read_frames, save_videos_grid 25 | 26 | 27 | def parse_args(): 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument("--config") 30 | parser.add_argument("-W", type=int, default=512) 31 | parser.add_argument("-H", type=int, default=784) 32 | parser.add_argument("-L", type=int, default=24) 33 | parser.add_argument("--seed", type=int, default=42) 34 | parser.add_argument("--cfg", type=float, 
default=3.5) 35 | parser.add_argument("--steps", type=int, default=30) 36 | parser.add_argument("--fps", type=int) 37 | args = parser.parse_args() 38 | 39 | return args 40 | 41 | 42 | def main(): 43 | args = parse_args() 44 | 45 | config = OmegaConf.load(args.config) 46 | 47 | if config.weight_dtype == "fp16": 48 | weight_dtype = torch.float16 49 | else: 50 | weight_dtype = torch.float32 51 | 52 | vae = AutoencoderKL.from_pretrained( 53 | config.pretrained_vae_path, 54 | ).to("cuda", dtype=weight_dtype) 55 | 56 | reference_unet = UNet2DConditionModel.from_pretrained( 57 | config.pretrained_base_model_path, 58 | subfolder="unet", 59 | ).to(dtype=weight_dtype, device="cuda") 60 | 61 | inference_config_path = config.inference_config 62 | infer_config = OmegaConf.load(inference_config_path) 63 | denoising_unet = UNet3DConditionModel.from_pretrained_2d( 64 | config.pretrained_base_model_path, 65 | config.motion_module_path, 66 | subfolder="unet", 67 | unet_additional_kwargs=infer_config.unet_additional_kwargs, 68 | ).to(dtype=weight_dtype, device="cuda") 69 | 70 | pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to( 71 | dtype=weight_dtype, device="cuda" 72 | ) 73 | 74 | image_enc = CLIPVisionModelWithProjection.from_pretrained( 75 | config.image_encoder_path 76 | ).to(dtype=weight_dtype, device="cuda") 77 | 78 | sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs) 79 | scheduler = DDIMScheduler(**sched_kwargs) 80 | 81 | generator = torch.manual_seed(args.seed) 82 | 83 | width, height = args.W, args.H 84 | 85 | # load pretrained weights 86 | denoising_unet.load_state_dict( 87 | torch.load(config.denoising_unet_path, map_location="cpu"), 88 | strict=False, 89 | ) 90 | reference_unet.load_state_dict( 91 | torch.load(config.reference_unet_path, map_location="cpu"), 92 | ) 93 | pose_guider.load_state_dict( 94 | torch.load(config.pose_guider_path, map_location="cpu"), 95 | ) 96 | 97 | pipe = Pose2VideoPipeline( 98 | vae=vae, 99 | image_encoder=image_enc, 100 | reference_unet=reference_unet, 101 | denoising_unet=denoising_unet, 102 | pose_guider=pose_guider, 103 | scheduler=scheduler, 104 | ) 105 | pipe = pipe.to("cuda", dtype=weight_dtype) 106 | 107 | date_str = datetime.now().strftime("%Y%m%d") 108 | time_str = datetime.now().strftime("%H%M") 109 | save_dir_name = f"{time_str}--seed_{args.seed}-{args.W}x{args.H}" 110 | 111 | save_dir = Path(f"output/{date_str}/{save_dir_name}") 112 | save_dir.mkdir(exist_ok=True, parents=True) 113 | 114 | for ref_image_path in config["test_cases"].keys(): 115 | # Each ref_image may correspond to multiple actions 116 | for pose_video_path in config["test_cases"][ref_image_path]: 117 | ref_name = Path(ref_image_path).stem 118 | pose_name = Path(pose_video_path).stem.replace("_kps", "") 119 | 120 | ref_image_pil = Image.open(ref_image_path).convert("RGB") 121 | 122 | pose_list = [] 123 | pose_tensor_list = [] 124 | pose_images = read_frames(pose_video_path) 125 | src_fps = get_fps(pose_video_path) 126 | print(f"pose video has {len(pose_images)} frames, with {src_fps} fps") 127 | pose_transform = transforms.Compose( 128 | [transforms.Resize((height, width)), transforms.ToTensor()] 129 | ) 130 | for pose_image_pil in pose_images[: args.L]: 131 | pose_tensor_list.append(pose_transform(pose_image_pil)) 132 | pose_list.append(pose_image_pil) 133 | 134 | ref_image_tensor = pose_transform(ref_image_pil) # (c, h, w) 135 | ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze( 136 | 0 137 | ) # (1, c, 1, h, w) 138 | 
ref_image_tensor = repeat( 139 | ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=args.L 140 | ) 141 | 142 | pose_tensor = torch.stack(pose_tensor_list, dim=0) # (f, c, h, w) 143 | pose_tensor = pose_tensor.transpose(0, 1) 144 | pose_tensor = pose_tensor.unsqueeze(0) 145 | 146 | video = pipe( 147 | ref_image_pil, 148 | pose_list, 149 | width, 150 | height, 151 | args.L, 152 | args.steps, 153 | args.cfg, 154 | generator=generator, 155 | ).videos 156 | 157 | video = torch.cat([ref_image_tensor, pose_tensor, video], dim=0) 158 | save_videos_grid( 159 | video, 160 | f"{save_dir}/{ref_name}_{pose_name}_{args.H}x{args.W}_{int(args.cfg)}_{time_str}.mp4", 161 | n_rows=3, 162 | fps=src_fps if args.fps is None else args.fps, 163 | ) 164 | 165 | 166 | if __name__ == "__main__": 167 | main() 168 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/src/__init__.py -------------------------------------------------------------------------------- /src/dataset/dance_image.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | 4 | import torch 5 | import torchvision.transforms as transforms 6 | from decord import VideoReader 7 | from PIL import Image 8 | from torch.utils.data import Dataset 9 | from transformers import CLIPImageProcessor 10 | 11 | 12 | class HumanDanceDataset(Dataset): 13 | def __init__( 14 | self, 15 | img_size, 16 | img_scale=(1.0, 1.0), 17 | img_ratio=(0.9, 1.0), 18 | drop_ratio=0.1, 19 | data_meta_paths=["./data/fahsion_meta.json"], 20 | sample_margin=30, 21 | ): 22 | super().__init__() 23 | 24 | self.img_size = img_size 25 | self.img_scale = img_scale 26 | self.img_ratio = img_ratio 27 | self.sample_margin = sample_margin 28 | 29 | # ----- 30 | # vid_meta format: 31 | # [{'video_path': , 'kps_path': , 'other':}, 32 | # {'video_path': , 'kps_path': , 'other':}] 33 | # ----- 34 | vid_meta = [] 35 | for data_meta_path in data_meta_paths: 36 | vid_meta.extend(json.load(open(data_meta_path, "r"))) 37 | self.vid_meta = vid_meta 38 | 39 | self.clip_image_processor = CLIPImageProcessor() 40 | 41 | self.transform = transforms.Compose( 42 | [ 43 | transforms.RandomResizedCrop( 44 | self.img_size, 45 | scale=self.img_scale, 46 | ratio=self.img_ratio, 47 | interpolation=transforms.InterpolationMode.BILINEAR, 48 | ), 49 | transforms.ToTensor(), 50 | transforms.Normalize([0.5], [0.5]), 51 | ] 52 | ) 53 | 54 | self.cond_transform = transforms.Compose( 55 | [ 56 | transforms.RandomResizedCrop( 57 | self.img_size, 58 | scale=self.img_scale, 59 | ratio=self.img_ratio, 60 | interpolation=transforms.InterpolationMode.BILINEAR, 61 | ), 62 | transforms.ToTensor(), 63 | ] 64 | ) 65 | 66 | self.drop_ratio = drop_ratio 67 | 68 | def augmentation(self, image, transform, state=None): 69 | if state is not None: 70 | torch.set_rng_state(state) 71 | return transform(image) 72 | 73 | def __getitem__(self, index): 74 | video_meta = self.vid_meta[index] 75 | video_path = video_meta["video_path"] 76 | kps_path = video_meta["kps_path"] 77 | 78 | video_reader = VideoReader(video_path) 79 | kps_reader = VideoReader(kps_path) 80 | 81 | assert len(video_reader) == len( 82 | kps_reader 83 | ), f"{len(video_reader) = } != {len(kps_reader) = } in {video_path}" 84 | 85 | video_length = len(video_reader) 86 | 87 | 
margin = min(self.sample_margin, video_length) 88 | 89 | ref_img_idx = random.randint(0, video_length - 1) 90 | if ref_img_idx + margin < video_length: 91 | tgt_img_idx = random.randint(ref_img_idx + margin, video_length - 1) 92 | elif ref_img_idx - margin > 0: 93 | tgt_img_idx = random.randint(0, ref_img_idx - margin) 94 | else: 95 | tgt_img_idx = random.randint(0, video_length - 1) 96 | 97 | ref_img = video_reader[ref_img_idx] 98 | ref_img_pil = Image.fromarray(ref_img.asnumpy()) 99 | tgt_img = video_reader[tgt_img_idx] 100 | tgt_img_pil = Image.fromarray(tgt_img.asnumpy()) 101 | 102 | tgt_pose = kps_reader[tgt_img_idx] 103 | tgt_pose_pil = Image.fromarray(tgt_pose.asnumpy()) 104 | 105 | state = torch.get_rng_state() 106 | tgt_img = self.augmentation(tgt_img_pil, self.transform, state) 107 | tgt_pose_img = self.augmentation(tgt_pose_pil, self.cond_transform, state) 108 | ref_img_vae = self.augmentation(ref_img_pil, self.transform, state) 109 | clip_image = self.clip_image_processor( 110 | images=ref_img_pil, return_tensors="pt" 111 | ).pixel_values[0] 112 | 113 | sample = dict( 114 | video_dir=video_path, 115 | img=tgt_img, 116 | tgt_pose=tgt_pose_img, 117 | ref_img=ref_img_vae, 118 | clip_images=clip_image, 119 | ) 120 | 121 | return sample 122 | 123 | def __len__(self): 124 | return len(self.vid_meta) 125 | -------------------------------------------------------------------------------- /src/dataset/dance_video.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | from typing import List 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | import torchvision.transforms as transforms 9 | from decord import VideoReader 10 | from PIL import Image 11 | from torch.utils.data import Dataset 12 | from transformers import CLIPImageProcessor 13 | 14 | 15 | class HumanDanceVideoDataset(Dataset): 16 | def __init__( 17 | self, 18 | sample_rate, 19 | n_sample_frames, 20 | width, 21 | height, 22 | img_scale=(1.0, 1.0), 23 | img_ratio=(0.9, 1.0), 24 | drop_ratio=0.1, 25 | data_meta_paths=["./data/fashion_meta.json"], 26 | ): 27 | super().__init__() 28 | self.sample_rate = sample_rate 29 | self.n_sample_frames = n_sample_frames 30 | self.width = width 31 | self.height = height 32 | self.img_scale = img_scale 33 | self.img_ratio = img_ratio 34 | 35 | vid_meta = [] 36 | for data_meta_path in data_meta_paths: 37 | vid_meta.extend(json.load(open(data_meta_path, "r"))) 38 | self.vid_meta = vid_meta 39 | 40 | self.clip_image_processor = CLIPImageProcessor() 41 | 42 | self.pixel_transform = transforms.Compose( 43 | [ 44 | transforms.RandomResizedCrop( 45 | (height, width), 46 | scale=self.img_scale, 47 | ratio=self.img_ratio, 48 | interpolation=transforms.InterpolationMode.BILINEAR, 49 | ), 50 | transforms.ToTensor(), 51 | transforms.Normalize([0.5], [0.5]), 52 | ] 53 | ) 54 | 55 | self.cond_transform = transforms.Compose( 56 | [ 57 | transforms.RandomResizedCrop( 58 | (height, width), 59 | scale=self.img_scale, 60 | ratio=self.img_ratio, 61 | interpolation=transforms.InterpolationMode.BILINEAR, 62 | ), 63 | transforms.ToTensor(), 64 | ] 65 | ) 66 | 67 | self.drop_ratio = drop_ratio 68 | 69 | def augmentation(self, images, transform, state=None): 70 | if state is not None: 71 | torch.set_rng_state(state) 72 | if isinstance(images, List): 73 | transformed_images = [transform(img) for img in images] 74 | ret_tensor = torch.stack(transformed_images, dim=0) # (f, c, h, w) 75 | else: 76 | ret_tensor = transform(images) # (c, 
h, w) 77 | return ret_tensor 78 | 79 | def __getitem__(self, index): 80 | video_meta = self.vid_meta[index] 81 | video_path = video_meta["video_path"] 82 | kps_path = video_meta["kps_path"] 83 | 84 | video_reader = VideoReader(video_path) 85 | kps_reader = VideoReader(kps_path) 86 | 87 | assert len(video_reader) == len( 88 | kps_reader 89 | ), f"{len(video_reader) = } != {len(kps_reader) = } in {video_path}" 90 | 91 | video_length = len(video_reader) 92 | 93 | clip_length = min( 94 | video_length, (self.n_sample_frames - 1) * self.sample_rate + 1 95 | ) 96 | start_idx = random.randint(0, video_length - clip_length) 97 | batch_index = np.linspace( 98 | start_idx, start_idx + clip_length - 1, self.n_sample_frames, dtype=int 99 | ).tolist() 100 | 101 | # read frames and kps 102 | vid_pil_image_list = [] 103 | pose_pil_image_list = [] 104 | for index in batch_index: 105 | img = video_reader[index] 106 | vid_pil_image_list.append(Image.fromarray(img.asnumpy())) 107 | img = kps_reader[index] 108 | pose_pil_image_list.append(Image.fromarray(img.asnumpy())) 109 | 110 | ref_img_idx = random.randint(0, video_length - 1) 111 | ref_img = Image.fromarray(video_reader[ref_img_idx].asnumpy()) 112 | 113 | # transform 114 | state = torch.get_rng_state() 115 | pixel_values_vid = self.augmentation( 116 | vid_pil_image_list, self.pixel_transform, state 117 | ) 118 | pixel_values_pose = self.augmentation( 119 | pose_pil_image_list, self.cond_transform, state 120 | ) 121 | pixel_values_ref_img = self.augmentation(ref_img, self.pixel_transform, state) 122 | clip_ref_img = self.clip_image_processor( 123 | images=ref_img, return_tensors="pt" 124 | ).pixel_values[0] 125 | 126 | sample = dict( 127 | video_dir=video_path, 128 | pixel_values_vid=pixel_values_vid, 129 | pixel_values_pose=pixel_values_pose, 130 | pixel_values_ref_img=pixel_values_ref_img, 131 | clip_ref_img=clip_ref_img, 132 | ) 133 | 134 | return sample 135 | 136 | def __len__(self): 137 | return len(self.vid_meta) 138 | -------------------------------------------------------------------------------- /src/dwpose/__init__.py: -------------------------------------------------------------------------------- 1 | # https://github.com/IDEA-Research/DWPose 2 | # Openpose 3 | # Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose 4 | # 2nd Edited by https://github.com/Hzzone/pytorch-openpose 5 | # 3rd Edited by ControlNet 6 | # 4th Edited by ControlNet (added face and correct hands) 7 | 8 | import copy 9 | import os 10 | 11 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" 12 | import cv2 13 | import numpy as np 14 | import torch 15 | from controlnet_aux.util import HWC3, resize_image 16 | from PIL import Image 17 | 18 | from . 
import util 19 | from .wholebody import Wholebody 20 | 21 | 22 | def draw_pose(pose, H, W): 23 | bodies = pose["bodies"] 24 | faces = pose["faces"] 25 | hands = pose["hands"] 26 | candidate = bodies["candidate"] 27 | subset = bodies["subset"] 28 | canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8) 29 | 30 | canvas = util.draw_bodypose(canvas, candidate, subset) 31 | 32 | canvas = util.draw_handpose(canvas, hands) 33 | 34 | canvas = util.draw_facepose(canvas, faces) 35 | 36 | return canvas 37 | 38 | 39 | class DWposeDetector: 40 | def __init__(self): 41 | pass 42 | 43 | def to(self, device): 44 | self.pose_estimation = Wholebody(device) 45 | return self 46 | 47 | def cal_height(self, input_image): 48 | input_image = cv2.cvtColor( 49 | np.array(input_image, dtype=np.uint8), cv2.COLOR_RGB2BGR 50 | ) 51 | 52 | input_image = HWC3(input_image) 53 | H, W, C = input_image.shape 54 | with torch.no_grad(): 55 | candidate, subset = self.pose_estimation(input_image) 56 | nums, keys, locs = candidate.shape 57 | # candidate[..., 0] /= float(W) 58 | # candidate[..., 1] /= float(H) 59 | body = candidate 60 | return body[0, ..., 1].min(), body[..., 1].max() - body[..., 1].min() 61 | 62 | def __call__( 63 | self, 64 | input_image, 65 | detect_resolution=512, 66 | image_resolution=512, 67 | output_type="pil", 68 | **kwargs, 69 | ): 70 | input_image = cv2.cvtColor( 71 | np.array(input_image, dtype=np.uint8), cv2.COLOR_RGB2BGR 72 | ) 73 | 74 | input_image = HWC3(input_image) 75 | input_image = resize_image(input_image, detect_resolution) 76 | H, W, C = input_image.shape 77 | with torch.no_grad(): 78 | candidate, subset = self.pose_estimation(input_image) 79 | nums, keys, locs = candidate.shape 80 | candidate[..., 0] /= float(W) 81 | candidate[..., 1] /= float(H) 82 | score = subset[:, :18] 83 | max_ind = np.mean(score, axis=-1).argmax(axis=0) 84 | score = score[[max_ind]] 85 | body = candidate[:, :18].copy() 86 | body = body[[max_ind]] 87 | nums = 1 88 | body = body.reshape(nums * 18, locs) 89 | body_score = copy.deepcopy(score) 90 | for i in range(len(score)): 91 | for j in range(len(score[i])): 92 | if score[i][j] > 0.3: 93 | score[i][j] = int(18 * i + j) 94 | else: 95 | score[i][j] = -1 96 | 97 | un_visible = subset < 0.3 98 | candidate[un_visible] = -1 99 | 100 | foot = candidate[:, 18:24] 101 | 102 | faces = candidate[[max_ind], 24:92] 103 | 104 | hands = candidate[[max_ind], 92:113] 105 | hands = np.vstack([hands, candidate[[max_ind], 113:]]) 106 | 107 | bodies = dict(candidate=body, subset=score) 108 | pose = dict(bodies=bodies, hands=hands, faces=faces) 109 | 110 | detected_map = draw_pose(pose, H, W) 111 | detected_map = HWC3(detected_map) 112 | 113 | img = resize_image(input_image, image_resolution) 114 | H, W, C = img.shape 115 | 116 | detected_map = cv2.resize( 117 | detected_map, (W, H), interpolation=cv2.INTER_LINEAR 118 | ) 119 | 120 | if output_type == "pil": 121 | detected_map = Image.fromarray(detected_map) 122 | 123 | return detected_map, body_score 124 | -------------------------------------------------------------------------------- /src/dwpose/onnxdet.py: -------------------------------------------------------------------------------- 1 | # https://github.com/IDEA-Research/DWPose 2 | import cv2 3 | import numpy as np 4 | import onnxruntime 5 | 6 | 7 | def nms(boxes, scores, nms_thr): 8 | """Single class NMS implemented in Numpy.""" 9 | x1 = boxes[:, 0] 10 | y1 = boxes[:, 1] 11 | x2 = boxes[:, 2] 12 | y2 = boxes[:, 3] 13 | 14 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 15 | order = 
scores.argsort()[::-1] 16 | 17 | keep = [] 18 | while order.size > 0: 19 | i = order[0] 20 | keep.append(i) 21 | xx1 = np.maximum(x1[i], x1[order[1:]]) 22 | yy1 = np.maximum(y1[i], y1[order[1:]]) 23 | xx2 = np.minimum(x2[i], x2[order[1:]]) 24 | yy2 = np.minimum(y2[i], y2[order[1:]]) 25 | 26 | w = np.maximum(0.0, xx2 - xx1 + 1) 27 | h = np.maximum(0.0, yy2 - yy1 + 1) 28 | inter = w * h 29 | ovr = inter / (areas[i] + areas[order[1:]] - inter) 30 | 31 | inds = np.where(ovr <= nms_thr)[0] 32 | order = order[inds + 1] 33 | 34 | return keep 35 | 36 | 37 | def multiclass_nms(boxes, scores, nms_thr, score_thr): 38 | """Multiclass NMS implemented in Numpy. Class-aware version.""" 39 | final_dets = [] 40 | num_classes = scores.shape[1] 41 | for cls_ind in range(num_classes): 42 | cls_scores = scores[:, cls_ind] 43 | valid_score_mask = cls_scores > score_thr 44 | if valid_score_mask.sum() == 0: 45 | continue 46 | else: 47 | valid_scores = cls_scores[valid_score_mask] 48 | valid_boxes = boxes[valid_score_mask] 49 | keep = nms(valid_boxes, valid_scores, nms_thr) 50 | if len(keep) > 0: 51 | cls_inds = np.ones((len(keep), 1)) * cls_ind 52 | dets = np.concatenate( 53 | [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1 54 | ) 55 | final_dets.append(dets) 56 | if len(final_dets) == 0: 57 | return None 58 | return np.concatenate(final_dets, 0) 59 | 60 | 61 | def demo_postprocess(outputs, img_size, p6=False): 62 | grids = [] 63 | expanded_strides = [] 64 | strides = [8, 16, 32] if not p6 else [8, 16, 32, 64] 65 | 66 | hsizes = [img_size[0] // stride for stride in strides] 67 | wsizes = [img_size[1] // stride for stride in strides] 68 | 69 | for hsize, wsize, stride in zip(hsizes, wsizes, strides): 70 | xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize)) 71 | grid = np.stack((xv, yv), 2).reshape(1, -1, 2) 72 | grids.append(grid) 73 | shape = grid.shape[:2] 74 | expanded_strides.append(np.full((*shape, 1), stride)) 75 | 76 | grids = np.concatenate(grids, 1) 77 | expanded_strides = np.concatenate(expanded_strides, 1) 78 | outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides 79 | outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides 80 | 81 | return outputs 82 | 83 | 84 | def preprocess(img, input_size, swap=(2, 0, 1)): 85 | if len(img.shape) == 3: 86 | padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114 87 | else: 88 | padded_img = np.ones(input_size, dtype=np.uint8) * 114 89 | 90 | r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) 91 | resized_img = cv2.resize( 92 | img, 93 | (int(img.shape[1] * r), int(img.shape[0] * r)), 94 | interpolation=cv2.INTER_LINEAR, 95 | ).astype(np.uint8) 96 | padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img 97 | 98 | padded_img = padded_img.transpose(swap) 99 | padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) 100 | return padded_img, r 101 | 102 | 103 | def inference_detector(session, oriImg): 104 | input_shape = (640, 640) 105 | img, ratio = preprocess(oriImg, input_shape) 106 | 107 | ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]} 108 | output = session.run(None, ort_inputs) 109 | predictions = demo_postprocess(output[0], input_shape)[0] 110 | 111 | boxes = predictions[:, :4] 112 | scores = predictions[:, 4:5] * predictions[:, 5:] 113 | 114 | boxes_xyxy = np.ones_like(boxes) 115 | boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2.0 116 | boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2.0 117 | boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2.0 
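# Illustrative note (not part of the original file): the assignments around
# this comment convert YOLOX center-format predictions (cx, cy, w, h) into
# corner format (x1, y1, x2, y2); the division by `ratio` just below then
# undoes the resize applied in preprocess(), mapping the boxes back to
# original-image coordinates. Worked example: a prediction with
# (cx, cy, w, h) = (100, 50, 40, 20) becomes (80, 40, 120, 60), and with a
# resize ratio of 0.5 the final box is (160, 80, 240, 120).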
118 | boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2.0 119 | boxes_xyxy /= ratio 120 | dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1) 121 | if dets is not None: 122 | final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5] 123 | isscore = final_scores > 0.3 124 | iscat = final_cls_inds == 0 125 | isbbox = [i and j for (i, j) in zip(isscore, iscat)] 126 | final_boxes = final_boxes[isbbox] 127 | else: 128 | return [] 129 | 130 | return final_boxes 131 | -------------------------------------------------------------------------------- /src/dwpose/onnxpose.py: -------------------------------------------------------------------------------- 1 | # https://github.com/IDEA-Research/DWPose 2 | from typing import List, Tuple 3 | 4 | import cv2 5 | import numpy as np 6 | import onnxruntime as ort 7 | 8 | 9 | def preprocess( 10 | img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256) 11 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 12 | """Do preprocessing for RTMPose model inference. 13 | 14 | Args: 15 | img (np.ndarray): Input image in shape. 16 | input_size (tuple): Input image size in shape (w, h). 17 | 18 | Returns: 19 | tuple: 20 | - resized_img (np.ndarray): Preprocessed image. 21 | - center (np.ndarray): Center of image. 22 | - scale (np.ndarray): Scale of image. 23 | """ 24 | # get shape of image 25 | img_shape = img.shape[:2] 26 | out_img, out_center, out_scale = [], [], [] 27 | if len(out_bbox) == 0: 28 | out_bbox = [[0, 0, img_shape[1], img_shape[0]]] 29 | for i in range(len(out_bbox)): 30 | x0 = out_bbox[i][0] 31 | y0 = out_bbox[i][1] 32 | x1 = out_bbox[i][2] 33 | y1 = out_bbox[i][3] 34 | bbox = np.array([x0, y0, x1, y1]) 35 | 36 | # get center and scale 37 | center, scale = bbox_xyxy2cs(bbox, padding=1.25) 38 | 39 | # do affine transformation 40 | resized_img, scale = top_down_affine(input_size, scale, center, img) 41 | 42 | # normalize image 43 | mean = np.array([123.675, 116.28, 103.53]) 44 | std = np.array([58.395, 57.12, 57.375]) 45 | resized_img = (resized_img - mean) / std 46 | 47 | out_img.append(resized_img) 48 | out_center.append(center) 49 | out_scale.append(scale) 50 | 51 | return out_img, out_center, out_scale 52 | 53 | 54 | def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray: 55 | """Inference RTMPose model. 56 | 57 | Args: 58 | sess (ort.InferenceSession): ONNXRuntime session. 59 | img (np.ndarray): Input image in shape. 60 | 61 | Returns: 62 | outputs (np.ndarray): Output of RTMPose model. 63 | """ 64 | all_out = [] 65 | # build input 66 | for i in range(len(img)): 67 | input = [img[i].transpose(2, 0, 1)] 68 | 69 | # build output 70 | sess_input = {sess.get_inputs()[0].name: input} 71 | sess_output = [] 72 | for out in sess.get_outputs(): 73 | sess_output.append(out.name) 74 | 75 | # run model 76 | outputs = sess.run(sess_output, sess_input) 77 | all_out.append(outputs) 78 | 79 | return all_out 80 | 81 | 82 | def postprocess( 83 | outputs: List[np.ndarray], 84 | model_input_size: Tuple[int, int], 85 | center: Tuple[int, int], 86 | scale: Tuple[int, int], 87 | simcc_split_ratio: float = 2.0, 88 | ) -> Tuple[np.ndarray, np.ndarray]: 89 | """Postprocess for RTMPose model output. 90 | 91 | Args: 92 | outputs (np.ndarray): Output of RTMPose model. 93 | model_input_size (tuple): RTMPose model Input image size. 94 | center (tuple): Center of bbox in shape (x, y). 95 | scale (tuple): Scale of bbox in shape (w, h). 96 | simcc_split_ratio (float): Split ratio of simcc. 
97 | 98 | Returns: 99 | tuple: 100 | - keypoints (np.ndarray): Rescaled keypoints. 101 | - scores (np.ndarray): Model predict scores. 102 | """ 103 | all_key = [] 104 | all_score = [] 105 | for i in range(len(outputs)): 106 | # use simcc to decode 107 | simcc_x, simcc_y = outputs[i] 108 | keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio) 109 | 110 | # rescale keypoints 111 | keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2 112 | all_key.append(keypoints[0]) 113 | all_score.append(scores[0]) 114 | 115 | return np.array(all_key), np.array(all_score) 116 | 117 | 118 | def bbox_xyxy2cs( 119 | bbox: np.ndarray, padding: float = 1.0 120 | ) -> Tuple[np.ndarray, np.ndarray]: 121 | """Transform the bbox format from (x,y,w,h) into (center, scale) 122 | 123 | Args: 124 | bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted 125 | as (left, top, right, bottom) 126 | padding (float): BBox padding factor that will be multilied to scale. 127 | Default: 1.0 128 | 129 | Returns: 130 | tuple: A tuple containing center and scale. 131 | - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or 132 | (n, 2) 133 | - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or 134 | (n, 2) 135 | """ 136 | # convert single bbox from (4, ) to (1, 4) 137 | dim = bbox.ndim 138 | if dim == 1: 139 | bbox = bbox[None, :] 140 | 141 | # get bbox center and scale 142 | x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3]) 143 | center = np.hstack([x1 + x2, y1 + y2]) * 0.5 144 | scale = np.hstack([x2 - x1, y2 - y1]) * padding 145 | 146 | if dim == 1: 147 | center = center[0] 148 | scale = scale[0] 149 | 150 | return center, scale 151 | 152 | 153 | def _fix_aspect_ratio(bbox_scale: np.ndarray, aspect_ratio: float) -> np.ndarray: 154 | """Extend the scale to match the given aspect ratio. 155 | 156 | Args: 157 | scale (np.ndarray): The image scale (w, h) in shape (2, ) 158 | aspect_ratio (float): The ratio of ``w/h`` 159 | 160 | Returns: 161 | np.ndarray: The reshaped image scale in (2, ) 162 | """ 163 | w, h = np.hsplit(bbox_scale, [1]) 164 | bbox_scale = np.where( 165 | w > h * aspect_ratio, 166 | np.hstack([w, w / aspect_ratio]), 167 | np.hstack([h * aspect_ratio, h]), 168 | ) 169 | return bbox_scale 170 | 171 | 172 | def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray: 173 | """Rotate a point by an angle. 174 | 175 | Args: 176 | pt (np.ndarray): 2D point coordinates (x, y) in shape (2, ) 177 | angle_rad (float): rotation angle in radian 178 | 179 | Returns: 180 | np.ndarray: Rotated point in shape (2, ) 181 | """ 182 | sn, cs = np.sin(angle_rad), np.cos(angle_rad) 183 | rot_mat = np.array([[cs, -sn], [sn, cs]]) 184 | return rot_mat @ pt 185 | 186 | 187 | def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray: 188 | """To calculate the affine matrix, three pairs of points are required. This 189 | function is used to get the 3rd point, given 2D points a & b. 190 | 191 | The 3rd point is defined by rotating vector `a - b` by 90 degrees 192 | anticlockwise, using b as the rotation center. 193 | 194 | Args: 195 | a (np.ndarray): The 1st point (x,y) in shape (2, ) 196 | b (np.ndarray): The 2nd point (x,y) in shape (2, ) 197 | 198 | Returns: 199 | np.ndarray: The 3rd point. 
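    Example (illustrative addition, not part of the upstream DWPose
    docstring): with a = (1, 0) and b = (0, 0), the vector a - b = (1, 0)
    is rotated 90 degrees anticlockwise about b, so the returned third
    point is (0, 1), i.e.
    _get_3rd_point(np.array([1., 0.]), np.array([0., 0.])) gives
    array([0., 1.]).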
200 | """ 201 | direction = a - b 202 | c = b + np.r_[-direction[1], direction[0]] 203 | return c 204 | 205 | 206 | def get_warp_matrix( 207 | center: np.ndarray, 208 | scale: np.ndarray, 209 | rot: float, 210 | output_size: Tuple[int, int], 211 | shift: Tuple[float, float] = (0.0, 0.0), 212 | inv: bool = False, 213 | ) -> np.ndarray: 214 | """Calculate the affine transformation matrix that can warp the bbox area 215 | in the input image to the output size. 216 | 217 | Args: 218 | center (np.ndarray[2, ]): Center of the bounding box (x, y). 219 | scale (np.ndarray[2, ]): Scale of the bounding box 220 | wrt [width, height]. 221 | rot (float): Rotation angle (degree). 222 | output_size (np.ndarray[2, ] | list(2,)): Size of the 223 | destination heatmaps. 224 | shift (0-100%): Shift translation ratio wrt the width/height. 225 | Default (0., 0.). 226 | inv (bool): Option to inverse the affine transform direction. 227 | (inv=False: src->dst or inv=True: dst->src) 228 | 229 | Returns: 230 | np.ndarray: A 2x3 transformation matrix 231 | """ 232 | shift = np.array(shift) 233 | src_w = scale[0] 234 | dst_w = output_size[0] 235 | dst_h = output_size[1] 236 | 237 | # compute transformation matrix 238 | rot_rad = np.deg2rad(rot) 239 | src_dir = _rotate_point(np.array([0.0, src_w * -0.5]), rot_rad) 240 | dst_dir = np.array([0.0, dst_w * -0.5]) 241 | 242 | # get four corners of the src rectangle in the original image 243 | src = np.zeros((3, 2), dtype=np.float32) 244 | src[0, :] = center + scale * shift 245 | src[1, :] = center + src_dir + scale * shift 246 | src[2, :] = _get_3rd_point(src[0, :], src[1, :]) 247 | 248 | # get four corners of the dst rectangle in the input image 249 | dst = np.zeros((3, 2), dtype=np.float32) 250 | dst[0, :] = [dst_w * 0.5, dst_h * 0.5] 251 | dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir 252 | dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :]) 253 | 254 | if inv: 255 | warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src)) 256 | else: 257 | warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst)) 258 | 259 | return warp_mat 260 | 261 | 262 | def top_down_affine( 263 | input_size: dict, bbox_scale: dict, bbox_center: dict, img: np.ndarray 264 | ) -> Tuple[np.ndarray, np.ndarray]: 265 | """Get the bbox image as the model input by affine transform. 266 | 267 | Args: 268 | input_size (dict): The input size of the model. 269 | bbox_scale (dict): The bbox scale of the img. 270 | bbox_center (dict): The bbox center of the img. 271 | img (np.ndarray): The original image. 272 | 273 | Returns: 274 | tuple: A tuple containing center and scale. 275 | - np.ndarray[float32]: img after affine transform. 276 | - np.ndarray[float32]: bbox scale after affine transform. 277 | """ 278 | w, h = input_size 279 | warp_size = (int(w), int(h)) 280 | 281 | # reshape bbox to fixed aspect ratio 282 | bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h) 283 | 284 | # get the affine matrix 285 | center = bbox_center 286 | scale = bbox_scale 287 | rot = 0 288 | warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h)) 289 | 290 | # do affine transform 291 | img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR) 292 | 293 | return img, bbox_scale 294 | 295 | 296 | def get_simcc_maximum( 297 | simcc_x: np.ndarray, simcc_y: np.ndarray 298 | ) -> Tuple[np.ndarray, np.ndarray]: 299 | """Get maximum response location and value from simcc representations. 
300 | 301 | Note: 302 | instance number: N 303 | num_keypoints: K 304 | heatmap height: H 305 | heatmap width: W 306 | 307 | Args: 308 | simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx) 309 | simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy) 310 | 311 | Returns: 312 | tuple: 313 | - locs (np.ndarray): locations of maximum heatmap responses in shape 314 | (K, 2) or (N, K, 2) 315 | - vals (np.ndarray): values of maximum heatmap responses in shape 316 | (K,) or (N, K) 317 | """ 318 | N, K, Wx = simcc_x.shape 319 | simcc_x = simcc_x.reshape(N * K, -1) 320 | simcc_y = simcc_y.reshape(N * K, -1) 321 | 322 | # get maximum value locations 323 | x_locs = np.argmax(simcc_x, axis=1) 324 | y_locs = np.argmax(simcc_y, axis=1) 325 | locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32) 326 | max_val_x = np.amax(simcc_x, axis=1) 327 | max_val_y = np.amax(simcc_y, axis=1) 328 | 329 | # get maximum value across x and y axis 330 | mask = max_val_x > max_val_y 331 | max_val_x[mask] = max_val_y[mask] 332 | vals = max_val_x 333 | locs[vals <= 0.0] = -1 334 | 335 | # reshape 336 | locs = locs.reshape(N, K, 2) 337 | vals = vals.reshape(N, K) 338 | 339 | return locs, vals 340 | 341 | 342 | def decode( 343 | simcc_x: np.ndarray, simcc_y: np.ndarray, simcc_split_ratio 344 | ) -> Tuple[np.ndarray, np.ndarray]: 345 | """Modulate simcc distribution with Gaussian. 346 | 347 | Args: 348 | simcc_x (np.ndarray[K, Wx]): model predicted simcc in x. 349 | simcc_y (np.ndarray[K, Wy]): model predicted simcc in y. 350 | simcc_split_ratio (int): The split ratio of simcc. 351 | 352 | Returns: 353 | tuple: A tuple containing center and scale. 354 | - np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2) 355 | - np.ndarray[float32]: scores in shape (K,) or (n, K) 356 | """ 357 | keypoints, scores = get_simcc_maximum(simcc_x, simcc_y) 358 | keypoints /= simcc_split_ratio 359 | 360 | return keypoints, scores 361 | 362 | 363 | def inference_pose(session, out_bbox, oriImg): 364 | h, w = session.get_inputs()[0].shape[2:] 365 | model_input_size = (w, h) 366 | resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size) 367 | outputs = inference(session, resized_img) 368 | keypoints, scores = postprocess(outputs, model_input_size, center, scale) 369 | 370 | return keypoints, scores 371 | -------------------------------------------------------------------------------- /src/dwpose/util.py: -------------------------------------------------------------------------------- 1 | # https://github.com/IDEA-Research/DWPose 2 | import math 3 | import numpy as np 4 | import matplotlib 5 | import cv2 6 | 7 | 8 | eps = 0.01 9 | 10 | 11 | def smart_resize(x, s): 12 | Ht, Wt = s 13 | if x.ndim == 2: 14 | Ho, Wo = x.shape 15 | Co = 1 16 | else: 17 | Ho, Wo, Co = x.shape 18 | if Co == 3 or Co == 1: 19 | k = float(Ht + Wt) / float(Ho + Wo) 20 | return cv2.resize( 21 | x, 22 | (int(Wt), int(Ht)), 23 | interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4, 24 | ) 25 | else: 26 | return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2) 27 | 28 | 29 | def smart_resize_k(x, fx, fy): 30 | if x.ndim == 2: 31 | Ho, Wo = x.shape 32 | Co = 1 33 | else: 34 | Ho, Wo, Co = x.shape 35 | Ht, Wt = Ho * fy, Wo * fx 36 | if Co == 3 or Co == 1: 37 | k = float(Ht + Wt) / float(Ho + Wo) 38 | return cv2.resize( 39 | x, 40 | (int(Wt), int(Ht)), 41 | interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4, 42 | ) 43 | else: 44 | return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in 
range(Co)], axis=2) 45 | 46 | 47 | def padRightDownCorner(img, stride, padValue): 48 | h = img.shape[0] 49 | w = img.shape[1] 50 | 51 | pad = 4 * [None] 52 | pad[0] = 0 # up 53 | pad[1] = 0 # left 54 | pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down 55 | pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right 56 | 57 | img_padded = img 58 | pad_up = np.tile(img_padded[0:1, :, :] * 0 + padValue, (pad[0], 1, 1)) 59 | img_padded = np.concatenate((pad_up, img_padded), axis=0) 60 | pad_left = np.tile(img_padded[:, 0:1, :] * 0 + padValue, (1, pad[1], 1)) 61 | img_padded = np.concatenate((pad_left, img_padded), axis=1) 62 | pad_down = np.tile(img_padded[-2:-1, :, :] * 0 + padValue, (pad[2], 1, 1)) 63 | img_padded = np.concatenate((img_padded, pad_down), axis=0) 64 | pad_right = np.tile(img_padded[:, -2:-1, :] * 0 + padValue, (1, pad[3], 1)) 65 | img_padded = np.concatenate((img_padded, pad_right), axis=1) 66 | 67 | return img_padded, pad 68 | 69 | 70 | def transfer(model, model_weights): 71 | transfered_model_weights = {} 72 | for weights_name in model.state_dict().keys(): 73 | transfered_model_weights[weights_name] = model_weights[ 74 | ".".join(weights_name.split(".")[1:]) 75 | ] 76 | return transfered_model_weights 77 | 78 | 79 | def draw_bodypose(canvas, candidate, subset): 80 | H, W, C = canvas.shape 81 | candidate = np.array(candidate) 82 | subset = np.array(subset) 83 | 84 | stickwidth = 4 85 | 86 | limbSeq = [ 87 | [2, 3], 88 | [2, 6], 89 | [3, 4], 90 | [4, 5], 91 | [6, 7], 92 | [7, 8], 93 | [2, 9], 94 | [9, 10], 95 | [10, 11], 96 | [2, 12], 97 | [12, 13], 98 | [13, 14], 99 | [2, 1], 100 | [1, 15], 101 | [15, 17], 102 | [1, 16], 103 | [16, 18], 104 | [3, 17], 105 | [6, 18], 106 | ] 107 | 108 | colors = [ 109 | [255, 0, 0], 110 | [255, 85, 0], 111 | [255, 170, 0], 112 | [255, 255, 0], 113 | [170, 255, 0], 114 | [85, 255, 0], 115 | [0, 255, 0], 116 | [0, 255, 85], 117 | [0, 255, 170], 118 | [0, 255, 255], 119 | [0, 170, 255], 120 | [0, 85, 255], 121 | [0, 0, 255], 122 | [85, 0, 255], 123 | [170, 0, 255], 124 | [255, 0, 255], 125 | [255, 0, 170], 126 | [255, 0, 85], 127 | ] 128 | 129 | for i in range(17): 130 | for n in range(len(subset)): 131 | index = subset[n][np.array(limbSeq[i]) - 1] 132 | if -1 in index: 133 | continue 134 | Y = candidate[index.astype(int), 0] * float(W) 135 | X = candidate[index.astype(int), 1] * float(H) 136 | mX = np.mean(X) 137 | mY = np.mean(Y) 138 | length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 139 | angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) 140 | polygon = cv2.ellipse2Poly( 141 | (int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1 142 | ) 143 | cv2.fillConvexPoly(canvas, polygon, colors[i]) 144 | 145 | canvas = (canvas * 0.6).astype(np.uint8) 146 | 147 | for i in range(18): 148 | for n in range(len(subset)): 149 | index = int(subset[n][i]) 150 | if index == -1: 151 | continue 152 | x, y = candidate[index][0:2] 153 | x = int(x * W) 154 | y = int(y * H) 155 | cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1) 156 | 157 | return canvas 158 | 159 | 160 | def draw_handpose(canvas, all_hand_peaks): 161 | H, W, C = canvas.shape 162 | 163 | edges = [ 164 | [0, 1], 165 | [1, 2], 166 | [2, 3], 167 | [3, 4], 168 | [0, 5], 169 | [5, 6], 170 | [6, 7], 171 | [7, 8], 172 | [0, 9], 173 | [9, 10], 174 | [10, 11], 175 | [11, 12], 176 | [0, 13], 177 | [13, 14], 178 | [14, 15], 179 | [15, 16], 180 | [0, 17], 181 | [17, 18], 182 | [18, 19], 183 | [19, 20], 184 | ] 185 | 186 | for peaks in 
all_hand_peaks: 187 | peaks = np.array(peaks) 188 | 189 | for ie, e in enumerate(edges): 190 | x1, y1 = peaks[e[0]] 191 | x2, y2 = peaks[e[1]] 192 | x1 = int(x1 * W) 193 | y1 = int(y1 * H) 194 | x2 = int(x2 * W) 195 | y2 = int(y2 * H) 196 | if x1 > eps and y1 > eps and x2 > eps and y2 > eps: 197 | cv2.line( 198 | canvas, 199 | (x1, y1), 200 | (x2, y2), 201 | matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) 202 | * 255, 203 | thickness=2, 204 | ) 205 | 206 | for i, keyponit in enumerate(peaks): 207 | x, y = keyponit 208 | x = int(x * W) 209 | y = int(y * H) 210 | if x > eps and y > eps: 211 | cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1) 212 | return canvas 213 | 214 | 215 | def draw_facepose(canvas, all_lmks): 216 | H, W, C = canvas.shape 217 | for lmks in all_lmks: 218 | lmks = np.array(lmks) 219 | for lmk in lmks: 220 | x, y = lmk 221 | x = int(x * W) 222 | y = int(y * H) 223 | if x > eps and y > eps: 224 | cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1) 225 | return canvas 226 | 227 | 228 | # detect hand according to body pose keypoints 229 | # please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp 230 | def handDetect(candidate, subset, oriImg): 231 | # right hand: wrist 4, elbow 3, shoulder 2 232 | # left hand: wrist 7, elbow 6, shoulder 5 233 | ratioWristElbow = 0.33 234 | detect_result = [] 235 | image_height, image_width = oriImg.shape[0:2] 236 | for person in subset.astype(int): 237 | # if any of three not detected 238 | has_left = np.sum(person[[5, 6, 7]] == -1) == 0 239 | has_right = np.sum(person[[2, 3, 4]] == -1) == 0 240 | if not (has_left or has_right): 241 | continue 242 | hands = [] 243 | # left hand 244 | if has_left: 245 | left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]] 246 | x1, y1 = candidate[left_shoulder_index][:2] 247 | x2, y2 = candidate[left_elbow_index][:2] 248 | x3, y3 = candidate[left_wrist_index][:2] 249 | hands.append([x1, y1, x2, y2, x3, y3, True]) 250 | # right hand 251 | if has_right: 252 | right_shoulder_index, right_elbow_index, right_wrist_index = person[ 253 | [2, 3, 4] 254 | ] 255 | x1, y1 = candidate[right_shoulder_index][:2] 256 | x2, y2 = candidate[right_elbow_index][:2] 257 | x3, y3 = candidate[right_wrist_index][:2] 258 | hands.append([x1, y1, x2, y2, x3, y3, False]) 259 | 260 | for x1, y1, x2, y2, x3, y3, is_left in hands: 261 | # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox 262 | # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]); 263 | # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]); 264 | # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow); 265 | # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder); 266 | # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder); 267 | x = x3 + ratioWristElbow * (x3 - x2) 268 | y = y3 + ratioWristElbow * (y3 - y2) 269 | distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2) 270 | distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2) 271 | width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder) 272 | # x-y refers to the center --> offset to topLeft point 273 | # handRectangle.x -= handRectangle.width / 2.f; 274 | # handRectangle.y -= handRectangle.height / 2.f; 275 | x -= width / 2 276 | y -= width / 
2 # width = height 277 | # overflow the image 278 | if x < 0: 279 | x = 0 280 | if y < 0: 281 | y = 0 282 | width1 = width 283 | width2 = width 284 | if x + width > image_width: 285 | width1 = image_width - x 286 | if y + width > image_height: 287 | width2 = image_height - y 288 | width = min(width1, width2) 289 | # the max hand box value is 20 pixels 290 | if width >= 20: 291 | detect_result.append([int(x), int(y), int(width), is_left]) 292 | 293 | """ 294 | return value: [[x, y, w, True if left hand else False]]. 295 | width=height since the network require squared input. 296 | x, y is the coordinate of top left 297 | """ 298 | return detect_result 299 | 300 | 301 | # Written by Lvmin 302 | def faceDetect(candidate, subset, oriImg): 303 | # left right eye ear 14 15 16 17 304 | detect_result = [] 305 | image_height, image_width = oriImg.shape[0:2] 306 | for person in subset.astype(int): 307 | has_head = person[0] > -1 308 | if not has_head: 309 | continue 310 | 311 | has_left_eye = person[14] > -1 312 | has_right_eye = person[15] > -1 313 | has_left_ear = person[16] > -1 314 | has_right_ear = person[17] > -1 315 | 316 | if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear): 317 | continue 318 | 319 | head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]] 320 | 321 | width = 0.0 322 | x0, y0 = candidate[head][:2] 323 | 324 | if has_left_eye: 325 | x1, y1 = candidate[left_eye][:2] 326 | d = max(abs(x0 - x1), abs(y0 - y1)) 327 | width = max(width, d * 3.0) 328 | 329 | if has_right_eye: 330 | x1, y1 = candidate[right_eye][:2] 331 | d = max(abs(x0 - x1), abs(y0 - y1)) 332 | width = max(width, d * 3.0) 333 | 334 | if has_left_ear: 335 | x1, y1 = candidate[left_ear][:2] 336 | d = max(abs(x0 - x1), abs(y0 - y1)) 337 | width = max(width, d * 1.5) 338 | 339 | if has_right_ear: 340 | x1, y1 = candidate[right_ear][:2] 341 | d = max(abs(x0 - x1), abs(y0 - y1)) 342 | width = max(width, d * 1.5) 343 | 344 | x, y = x0, y0 345 | 346 | x -= width 347 | y -= width 348 | 349 | if x < 0: 350 | x = 0 351 | 352 | if y < 0: 353 | y = 0 354 | 355 | width1 = width * 2 356 | width2 = width * 2 357 | 358 | if x + width > image_width: 359 | width1 = image_width - x 360 | 361 | if y + width > image_height: 362 | width2 = image_height - y 363 | 364 | width = min(width1, width2) 365 | 366 | if width >= 20: 367 | detect_result.append([int(x), int(y), int(width)]) 368 | 369 | return detect_result 370 | 371 | 372 | # get max index of 2d array 373 | def npmax(array): 374 | arrayindex = array.argmax(1) 375 | arrayvalue = array.max(1) 376 | i = arrayvalue.argmax() 377 | j = arrayindex[i] 378 | return i, j 379 | -------------------------------------------------------------------------------- /src/dwpose/wholebody.py: -------------------------------------------------------------------------------- 1 | # https://github.com/IDEA-Research/DWPose 2 | from pathlib import Path 3 | 4 | import cv2 5 | import numpy as np 6 | import onnxruntime as ort 7 | 8 | from .onnxdet import inference_detector 9 | from .onnxpose import inference_pose 10 | 11 | ModelDataPathPrefix = Path("./pretrained_weights") 12 | 13 | 14 | class Wholebody: 15 | def __init__(self, device="cuda:0"): 16 | providers = ( 17 | ["CPUExecutionProvider"] if device == "cpu" else ["CUDAExecutionProvider"] 18 | ) 19 | onnx_det = ModelDataPathPrefix.joinpath("DWPose/yolox_l.onnx") 20 | onnx_pose = ModelDataPathPrefix.joinpath("DWPose/dw-ll_ucoco_384.onnx") 21 | 22 | self.session_det = ort.InferenceSession( 23 | path_or_bytes=onnx_det, 
providers=providers 24 | ) 25 | self.session_pose = ort.InferenceSession( 26 | path_or_bytes=onnx_pose, providers=providers 27 | ) 28 | 29 | def __call__(self, oriImg): 30 | det_result = inference_detector(self.session_det, oriImg) 31 | keypoints, scores = inference_pose(self.session_pose, det_result, oriImg) 32 | 33 | keypoints_info = np.concatenate((keypoints, scores[..., None]), axis=-1) 34 | # compute neck joint 35 | neck = np.mean(keypoints_info[:, [5, 6]], axis=1) 36 | # neck score when visualizing pred 37 | neck[:, 2:4] = np.logical_and( 38 | keypoints_info[:, 5, 2:4] > 0.3, keypoints_info[:, 6, 2:4] > 0.3 39 | ).astype(int) 40 | new_keypoints_info = np.insert(keypoints_info, 17, neck, axis=1) 41 | mmpose_idx = [17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3] 42 | openpose_idx = [1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17] 43 | new_keypoints_info[:, openpose_idx] = new_keypoints_info[:, mmpose_idx] 44 | keypoints_info = new_keypoints_info 45 | 46 | keypoints, scores = keypoints_info[..., :2], keypoints_info[..., 2] 47 | 48 | return keypoints, scores 49 | -------------------------------------------------------------------------------- /src/models/motion_module.py: -------------------------------------------------------------------------------- 1 | # Adapt from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py 2 | import math 3 | from dataclasses import dataclass 4 | from typing import Callable, Optional 5 | 6 | import torch 7 | from diffusers.models.attention import FeedForward 8 | from diffusers.models.attention_processor import Attention, AttnProcessor 9 | from diffusers.utils import BaseOutput 10 | from diffusers.utils.import_utils import is_xformers_available 11 | from einops import rearrange, repeat 12 | from torch import nn 13 | 14 | 15 | def zero_module(module): 16 | # Zero out the parameters of a module and return it. 
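    # Illustrative note (not part of the original source): zero-initialising
    # the temporal transformer's final projection (see `zero_initialize` in
    # VanillaTemporalModule below) makes the newly added motion module start
    # out as an identity mapping -- its residual branch contributes exactly
    # zero -- so the pretrained 2D UNet's behaviour is unchanged at the start
    # of training. A minimal sketch of the effect, assuming a plain linear
    # layer:
    #
    #     proj = zero_module(nn.Linear(320, 320))
    #     x = torch.randn(2, 320)
    #     assert torch.equal(proj(x), torch.zeros(2, 320))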
17 | for p in module.parameters(): 18 | p.detach().zero_() 19 | return module 20 | 21 | 22 | @dataclass 23 | class TemporalTransformer3DModelOutput(BaseOutput): 24 | sample: torch.FloatTensor 25 | 26 | 27 | if is_xformers_available(): 28 | import xformers 29 | import xformers.ops 30 | else: 31 | xformers = None 32 | 33 | 34 | def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict): 35 | if motion_module_type == "Vanilla": 36 | return VanillaTemporalModule( 37 | in_channels=in_channels, 38 | **motion_module_kwargs, 39 | ) 40 | else: 41 | raise ValueError 42 | 43 | 44 | class VanillaTemporalModule(nn.Module): 45 | def __init__( 46 | self, 47 | in_channels, 48 | num_attention_heads=8, 49 | num_transformer_block=2, 50 | attention_block_types=("Temporal_Self", "Temporal_Self"), 51 | cross_frame_attention_mode=None, 52 | temporal_position_encoding=False, 53 | temporal_position_encoding_max_len=24, 54 | temporal_attention_dim_div=1, 55 | zero_initialize=True, 56 | ): 57 | super().__init__() 58 | 59 | self.temporal_transformer = TemporalTransformer3DModel( 60 | in_channels=in_channels, 61 | num_attention_heads=num_attention_heads, 62 | attention_head_dim=in_channels 63 | // num_attention_heads 64 | // temporal_attention_dim_div, 65 | num_layers=num_transformer_block, 66 | attention_block_types=attention_block_types, 67 | cross_frame_attention_mode=cross_frame_attention_mode, 68 | temporal_position_encoding=temporal_position_encoding, 69 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 70 | ) 71 | 72 | if zero_initialize: 73 | self.temporal_transformer.proj_out = zero_module( 74 | self.temporal_transformer.proj_out 75 | ) 76 | 77 | def forward( 78 | self, 79 | input_tensor, 80 | temb, 81 | encoder_hidden_states, 82 | attention_mask=None, 83 | anchor_frame_idx=None, 84 | ): 85 | hidden_states = input_tensor 86 | hidden_states = self.temporal_transformer( 87 | hidden_states, encoder_hidden_states, attention_mask 88 | ) 89 | 90 | output = hidden_states 91 | return output 92 | 93 | 94 | class TemporalTransformer3DModel(nn.Module): 95 | def __init__( 96 | self, 97 | in_channels, 98 | num_attention_heads, 99 | attention_head_dim, 100 | num_layers, 101 | attention_block_types=( 102 | "Temporal_Self", 103 | "Temporal_Self", 104 | ), 105 | dropout=0.0, 106 | norm_num_groups=32, 107 | cross_attention_dim=768, 108 | activation_fn="geglu", 109 | attention_bias=False, 110 | upcast_attention=False, 111 | cross_frame_attention_mode=None, 112 | temporal_position_encoding=False, 113 | temporal_position_encoding_max_len=24, 114 | ): 115 | super().__init__() 116 | 117 | inner_dim = num_attention_heads * attention_head_dim 118 | 119 | self.norm = torch.nn.GroupNorm( 120 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True 121 | ) 122 | self.proj_in = nn.Linear(in_channels, inner_dim) 123 | 124 | self.transformer_blocks = nn.ModuleList( 125 | [ 126 | TemporalTransformerBlock( 127 | dim=inner_dim, 128 | num_attention_heads=num_attention_heads, 129 | attention_head_dim=attention_head_dim, 130 | attention_block_types=attention_block_types, 131 | dropout=dropout, 132 | norm_num_groups=norm_num_groups, 133 | cross_attention_dim=cross_attention_dim, 134 | activation_fn=activation_fn, 135 | attention_bias=attention_bias, 136 | upcast_attention=upcast_attention, 137 | cross_frame_attention_mode=cross_frame_attention_mode, 138 | temporal_position_encoding=temporal_position_encoding, 139 | 
temporal_position_encoding_max_len=temporal_position_encoding_max_len, 140 | ) 141 | for d in range(num_layers) 142 | ] 143 | ) 144 | self.proj_out = nn.Linear(inner_dim, in_channels) 145 | 146 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): 147 | assert ( 148 | hidden_states.dim() == 5 149 | ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 150 | video_length = hidden_states.shape[2] 151 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 152 | 153 | batch, channel, height, weight = hidden_states.shape 154 | residual = hidden_states 155 | 156 | hidden_states = self.norm(hidden_states) 157 | inner_dim = hidden_states.shape[1] 158 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( 159 | batch, height * weight, inner_dim 160 | ) 161 | hidden_states = self.proj_in(hidden_states) 162 | 163 | # Transformer Blocks 164 | for block in self.transformer_blocks: 165 | hidden_states = block( 166 | hidden_states, 167 | encoder_hidden_states=encoder_hidden_states, 168 | video_length=video_length, 169 | ) 170 | 171 | # output 172 | hidden_states = self.proj_out(hidden_states) 173 | hidden_states = ( 174 | hidden_states.reshape(batch, height, weight, inner_dim) 175 | .permute(0, 3, 1, 2) 176 | .contiguous() 177 | ) 178 | 179 | output = hidden_states + residual 180 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 181 | 182 | return output 183 | 184 | 185 | class TemporalTransformerBlock(nn.Module): 186 | def __init__( 187 | self, 188 | dim, 189 | num_attention_heads, 190 | attention_head_dim, 191 | attention_block_types=( 192 | "Temporal_Self", 193 | "Temporal_Self", 194 | ), 195 | dropout=0.0, 196 | norm_num_groups=32, 197 | cross_attention_dim=768, 198 | activation_fn="geglu", 199 | attention_bias=False, 200 | upcast_attention=False, 201 | cross_frame_attention_mode=None, 202 | temporal_position_encoding=False, 203 | temporal_position_encoding_max_len=24, 204 | ): 205 | super().__init__() 206 | 207 | attention_blocks = [] 208 | norms = [] 209 | 210 | for block_name in attention_block_types: 211 | attention_blocks.append( 212 | VersatileAttention( 213 | attention_mode=block_name.split("_")[0], 214 | cross_attention_dim=cross_attention_dim 215 | if block_name.endswith("_Cross") 216 | else None, 217 | query_dim=dim, 218 | heads=num_attention_heads, 219 | dim_head=attention_head_dim, 220 | dropout=dropout, 221 | bias=attention_bias, 222 | upcast_attention=upcast_attention, 223 | cross_frame_attention_mode=cross_frame_attention_mode, 224 | temporal_position_encoding=temporal_position_encoding, 225 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 226 | ) 227 | ) 228 | norms.append(nn.LayerNorm(dim)) 229 | 230 | self.attention_blocks = nn.ModuleList(attention_blocks) 231 | self.norms = nn.ModuleList(norms) 232 | 233 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) 234 | self.ff_norm = nn.LayerNorm(dim) 235 | 236 | def forward( 237 | self, 238 | hidden_states, 239 | encoder_hidden_states=None, 240 | attention_mask=None, 241 | video_length=None, 242 | ): 243 | for attention_block, norm in zip(self.attention_blocks, self.norms): 244 | norm_hidden_states = norm(hidden_states) 245 | hidden_states = ( 246 | attention_block( 247 | norm_hidden_states, 248 | encoder_hidden_states=encoder_hidden_states 249 | if attention_block.is_cross_attention 250 | else None, 251 | video_length=video_length, 252 | ) 253 | + hidden_states 254 | ) 255 | 256 | 
hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states 257 | 258 | output = hidden_states 259 | return output 260 | 261 | 262 | class PositionalEncoding(nn.Module): 263 | def __init__(self, d_model, dropout=0.0, max_len=24): 264 | super().__init__() 265 | self.dropout = nn.Dropout(p=dropout) 266 | position = torch.arange(max_len).unsqueeze(1) 267 | div_term = torch.exp( 268 | torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model) 269 | ) 270 | pe = torch.zeros(1, max_len, d_model) 271 | pe[0, :, 0::2] = torch.sin(position * div_term) 272 | pe[0, :, 1::2] = torch.cos(position * div_term) 273 | self.register_buffer("pe", pe) 274 | 275 | def forward(self, x): 276 | x = x + self.pe[:, : x.size(1)] 277 | return self.dropout(x) 278 | 279 | 280 | class VersatileAttention(Attention): 281 | def __init__( 282 | self, 283 | attention_mode=None, 284 | cross_frame_attention_mode=None, 285 | temporal_position_encoding=False, 286 | temporal_position_encoding_max_len=24, 287 | *args, 288 | **kwargs, 289 | ): 290 | super().__init__(*args, **kwargs) 291 | assert attention_mode == "Temporal" 292 | 293 | self.attention_mode = attention_mode 294 | self.is_cross_attention = kwargs["cross_attention_dim"] is not None 295 | 296 | self.pos_encoder = ( 297 | PositionalEncoding( 298 | kwargs["query_dim"], 299 | dropout=0.0, 300 | max_len=temporal_position_encoding_max_len, 301 | ) 302 | if (temporal_position_encoding and attention_mode == "Temporal") 303 | else None 304 | ) 305 | 306 | def extra_repr(self): 307 | return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}" 308 | 309 | def set_use_memory_efficient_attention_xformers( 310 | self, 311 | use_memory_efficient_attention_xformers: bool, 312 | attention_op: Optional[Callable] = None, 313 | ): 314 | if use_memory_efficient_attention_xformers: 315 | if not is_xformers_available(): 316 | raise ModuleNotFoundError( 317 | ( 318 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install" 319 | " xformers" 320 | ), 321 | name="xformers", 322 | ) 323 | elif not torch.cuda.is_available(): 324 | raise ValueError( 325 | "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" 326 | " only available for GPU " 327 | ) 328 | else: 329 | try: 330 | # Make sure we can run the memory efficient attention 331 | _ = xformers.ops.memory_efficient_attention( 332 | torch.randn((1, 2, 40), device="cuda"), 333 | torch.randn((1, 2, 40), device="cuda"), 334 | torch.randn((1, 2, 40), device="cuda"), 335 | ) 336 | except Exception as e: 337 | raise e 338 | 339 | # XFormersAttnProcessor corrupts video generation and work with Pytorch 1.13. 340 | # Pytorch 2.0.1 AttnProcessor works the same as XFormersAttnProcessor in Pytorch 1.13. 341 | # You don't need XFormersAttnProcessor here. 
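# Note: both branches below end up installing the plain AttnProcessor(); the xformers
# checks above only verify that xformers is importable and that its memory-efficient
# attention actually runs on CUDA.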
342 | # processor = XFormersAttnProcessor( 343 | # attention_op=attention_op, 344 | # ) 345 | processor = AttnProcessor() 346 | else: 347 | processor = AttnProcessor() 348 | 349 | self.set_processor(processor) 350 | 351 | def forward( 352 | self, 353 | hidden_states, 354 | encoder_hidden_states=None, 355 | attention_mask=None, 356 | video_length=None, 357 | **cross_attention_kwargs, 358 | ): 359 | if self.attention_mode == "Temporal": 360 | d = hidden_states.shape[1] # d means HxW 361 | hidden_states = rearrange( 362 | hidden_states, "(b f) d c -> (b d) f c", f=video_length 363 | ) 364 | 365 | if self.pos_encoder is not None: 366 | hidden_states = self.pos_encoder(hidden_states) 367 | 368 | encoder_hidden_states = ( 369 | repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) 370 | if encoder_hidden_states is not None 371 | else encoder_hidden_states 372 | ) 373 | 374 | else: 375 | raise NotImplementedError 376 | 377 | hidden_states = self.processor( 378 | self, 379 | hidden_states, 380 | encoder_hidden_states=encoder_hidden_states, 381 | attention_mask=attention_mask, 382 | **cross_attention_kwargs, 383 | ) 384 | 385 | if self.attention_mode == "Temporal": 386 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) 387 | 388 | return hidden_states 389 | -------------------------------------------------------------------------------- /src/models/mutual_self_attention.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/models/mutual_self_attention.py 2 | from typing import Any, Dict, Optional 3 | 4 | import torch 5 | from einops import rearrange 6 | 7 | from src.models.attention import TemporalBasicTransformerBlock 8 | 9 | from .attention import BasicTransformerBlock 10 | 11 | 12 | def torch_dfs(model: torch.nn.Module): 13 | result = [model] 14 | for child in model.children(): 15 | result += torch_dfs(child) 16 | return result 17 | 18 | 19 | class ReferenceAttentionControl: 20 | def __init__( 21 | self, 22 | unet, 23 | mode="write", 24 | do_classifier_free_guidance=False, 25 | attention_auto_machine_weight=float("inf"), 26 | gn_auto_machine_weight=1.0, 27 | style_fidelity=1.0, 28 | reference_attn=True, 29 | reference_adain=False, 30 | fusion_blocks="midup", 31 | batch_size=1, 32 | ) -> None: 33 | # 10. 
Modify self attention and group norm 34 | self.unet = unet 35 | assert mode in ["read", "write"] 36 | assert fusion_blocks in ["midup", "full"] 37 | self.reference_attn = reference_attn 38 | self.reference_adain = reference_adain 39 | self.fusion_blocks = fusion_blocks 40 | self.register_reference_hooks( 41 | mode, 42 | do_classifier_free_guidance, 43 | attention_auto_machine_weight, 44 | gn_auto_machine_weight, 45 | style_fidelity, 46 | reference_attn, 47 | reference_adain, 48 | fusion_blocks, 49 | batch_size=batch_size, 50 | ) 51 | 52 | def register_reference_hooks( 53 | self, 54 | mode, 55 | do_classifier_free_guidance, 56 | attention_auto_machine_weight, 57 | gn_auto_machine_weight, 58 | style_fidelity, 59 | reference_attn, 60 | reference_adain, 61 | dtype=torch.float16, 62 | batch_size=1, 63 | num_images_per_prompt=1, 64 | device=torch.device("cpu"), 65 | fusion_blocks="midup", 66 | ): 67 | MODE = mode 68 | do_classifier_free_guidance = do_classifier_free_guidance 69 | attention_auto_machine_weight = attention_auto_machine_weight 70 | gn_auto_machine_weight = gn_auto_machine_weight 71 | style_fidelity = style_fidelity 72 | reference_attn = reference_attn 73 | reference_adain = reference_adain 74 | fusion_blocks = fusion_blocks 75 | num_images_per_prompt = num_images_per_prompt 76 | dtype = dtype 77 | if do_classifier_free_guidance: 78 | uc_mask = ( 79 | torch.Tensor( 80 | [1] * batch_size * num_images_per_prompt * 16 81 | + [0] * batch_size * num_images_per_prompt * 16 82 | ) 83 | .to(device) 84 | .bool() 85 | ) 86 | else: 87 | uc_mask = ( 88 | torch.Tensor([0] * batch_size * num_images_per_prompt * 2) 89 | .to(device) 90 | .bool() 91 | ) 92 | 93 | def hacked_basic_transformer_inner_forward( 94 | self, 95 | hidden_states: torch.FloatTensor, 96 | attention_mask: Optional[torch.FloatTensor] = None, 97 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 98 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 99 | timestep: Optional[torch.LongTensor] = None, 100 | cross_attention_kwargs: Dict[str, Any] = None, 101 | class_labels: Optional[torch.LongTensor] = None, 102 | video_length=None, 103 | self_attention_additional_feats=None, 104 | mode=None, 105 | ): 106 | if self.use_ada_layer_norm: # False 107 | norm_hidden_states = self.norm1(hidden_states, timestep) 108 | elif self.use_ada_layer_norm_zero: 109 | ( 110 | norm_hidden_states, 111 | gate_msa, 112 | shift_mlp, 113 | scale_mlp, 114 | gate_mlp, 115 | ) = self.norm1( 116 | hidden_states, 117 | timestep, 118 | class_labels, 119 | hidden_dtype=hidden_states.dtype, 120 | ) 121 | else: 122 | norm_hidden_states = self.norm1(hidden_states) 123 | 124 | # 1. 
Self-Attention 125 | # self.only_cross_attention = False 126 | cross_attention_kwargs = ( 127 | cross_attention_kwargs if cross_attention_kwargs is not None else {} 128 | ) 129 | if self.only_cross_attention: 130 | attn_output = self.attn1( 131 | norm_hidden_states, 132 | encoder_hidden_states=encoder_hidden_states 133 | if self.only_cross_attention 134 | else None, 135 | attention_mask=attention_mask, 136 | **cross_attention_kwargs, 137 | ) 138 | else: 139 | if MODE == "write": 140 | self.bank.append(norm_hidden_states.clone()) 141 | attn_output = self.attn1( 142 | norm_hidden_states, 143 | encoder_hidden_states=encoder_hidden_states 144 | if self.only_cross_attention 145 | else None, 146 | attention_mask=attention_mask, 147 | **cross_attention_kwargs, 148 | ) 149 | if MODE == "read": 150 | bank_fea = [ 151 | rearrange( 152 | d.unsqueeze(1).repeat(1, video_length, 1, 1), 153 | "b t l c -> (b t) l c", 154 | ) 155 | for d in self.bank 156 | ] 157 | modify_norm_hidden_states = torch.cat( 158 | [norm_hidden_states] + bank_fea, dim=1 159 | ) 160 | hidden_states_uc = ( 161 | self.attn1( 162 | norm_hidden_states, 163 | encoder_hidden_states=modify_norm_hidden_states, 164 | attention_mask=attention_mask, 165 | ) 166 | + hidden_states 167 | ) 168 | if do_classifier_free_guidance: 169 | hidden_states_c = hidden_states_uc.clone() 170 | _uc_mask = uc_mask.clone() 171 | if hidden_states.shape[0] != _uc_mask.shape[0]: 172 | _uc_mask = ( 173 | torch.Tensor( 174 | [1] * (hidden_states.shape[0] // 2) 175 | + [0] * (hidden_states.shape[0] // 2) 176 | ) 177 | .to(device) 178 | .bool() 179 | ) 180 | hidden_states_c[_uc_mask] = ( 181 | self.attn1( 182 | norm_hidden_states[_uc_mask], 183 | encoder_hidden_states=norm_hidden_states[_uc_mask], 184 | attention_mask=attention_mask, 185 | ) 186 | + hidden_states[_uc_mask] 187 | ) 188 | hidden_states = hidden_states_c.clone() 189 | else: 190 | hidden_states = hidden_states_uc 191 | 192 | # self.bank.clear() 193 | if self.attn2 is not None: 194 | # Cross-Attention 195 | norm_hidden_states = ( 196 | self.norm2(hidden_states, timestep) 197 | if self.use_ada_layer_norm 198 | else self.norm2(hidden_states) 199 | ) 200 | hidden_states = ( 201 | self.attn2( 202 | norm_hidden_states, 203 | encoder_hidden_states=encoder_hidden_states, 204 | attention_mask=attention_mask, 205 | ) 206 | + hidden_states 207 | ) 208 | 209 | # Feed-forward 210 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states 211 | 212 | # Temporal-Attention 213 | if self.unet_use_temporal_attention: 214 | d = hidden_states.shape[1] 215 | hidden_states = rearrange( 216 | hidden_states, "(b f) d c -> (b d) f c", f=video_length 217 | ) 218 | norm_hidden_states = ( 219 | self.norm_temp(hidden_states, timestep) 220 | if self.use_ada_layer_norm 221 | else self.norm_temp(hidden_states) 222 | ) 223 | hidden_states = ( 224 | self.attn_temp(norm_hidden_states) + hidden_states 225 | ) 226 | hidden_states = rearrange( 227 | hidden_states, "(b d) f c -> (b f) d c", d=d 228 | ) 229 | 230 | return hidden_states 231 | 232 | if self.use_ada_layer_norm_zero: 233 | attn_output = gate_msa.unsqueeze(1) * attn_output 234 | hidden_states = attn_output + hidden_states 235 | 236 | if self.attn2 is not None: 237 | norm_hidden_states = ( 238 | self.norm2(hidden_states, timestep) 239 | if self.use_ada_layer_norm 240 | else self.norm2(hidden_states) 241 | ) 242 | 243 | # 2. 
Cross-Attention 244 | attn_output = self.attn2( 245 | norm_hidden_states, 246 | encoder_hidden_states=encoder_hidden_states, 247 | attention_mask=encoder_attention_mask, 248 | **cross_attention_kwargs, 249 | ) 250 | hidden_states = attn_output + hidden_states 251 | 252 | # 3. Feed-forward 253 | norm_hidden_states = self.norm3(hidden_states) 254 | 255 | if self.use_ada_layer_norm_zero: 256 | norm_hidden_states = ( 257 | norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 258 | ) 259 | 260 | ff_output = self.ff(norm_hidden_states) 261 | 262 | if self.use_ada_layer_norm_zero: 263 | ff_output = gate_mlp.unsqueeze(1) * ff_output 264 | 265 | hidden_states = ff_output + hidden_states 266 | 267 | return hidden_states 268 | 269 | if self.reference_attn: 270 | if self.fusion_blocks == "midup": 271 | attn_modules = [ 272 | module 273 | for module in ( 274 | torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks) 275 | ) 276 | if isinstance(module, BasicTransformerBlock) 277 | or isinstance(module, TemporalBasicTransformerBlock) 278 | ] 279 | elif self.fusion_blocks == "full": 280 | attn_modules = [ 281 | module 282 | for module in torch_dfs(self.unet) 283 | if isinstance(module, BasicTransformerBlock) 284 | or isinstance(module, TemporalBasicTransformerBlock) 285 | ] 286 | attn_modules = sorted( 287 | attn_modules, key=lambda x: -x.norm1.normalized_shape[0] 288 | ) 289 | 290 | for i, module in enumerate(attn_modules): 291 | module._original_inner_forward = module.forward 292 | if isinstance(module, BasicTransformerBlock): 293 | module.forward = hacked_basic_transformer_inner_forward.__get__( 294 | module, BasicTransformerBlock 295 | ) 296 | if isinstance(module, TemporalBasicTransformerBlock): 297 | module.forward = hacked_basic_transformer_inner_forward.__get__( 298 | module, TemporalBasicTransformerBlock 299 | ) 300 | 301 | module.bank = [] 302 | module.attn_weight = float(i) / float(len(attn_modules)) 303 | 304 | def update(self, writer, dtype=torch.float16): 305 | if self.reference_attn: 306 | if self.fusion_blocks == "midup": 307 | reader_attn_modules = [ 308 | module 309 | for module in ( 310 | torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks) 311 | ) 312 | if isinstance(module, TemporalBasicTransformerBlock) 313 | ] 314 | writer_attn_modules = [ 315 | module 316 | for module in ( 317 | torch_dfs(writer.unet.mid_block) 318 | + torch_dfs(writer.unet.up_blocks) 319 | ) 320 | if isinstance(module, BasicTransformerBlock) 321 | ] 322 | elif self.fusion_blocks == "full": 323 | reader_attn_modules = [ 324 | module 325 | for module in torch_dfs(self.unet) 326 | if isinstance(module, TemporalBasicTransformerBlock) 327 | ] 328 | writer_attn_modules = [ 329 | module 330 | for module in torch_dfs(writer.unet) 331 | if isinstance(module, BasicTransformerBlock) 332 | ] 333 | reader_attn_modules = sorted( 334 | reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0] 335 | ) 336 | writer_attn_modules = sorted( 337 | writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0] 338 | ) 339 | for r, w in zip(reader_attn_modules, writer_attn_modules): 340 | r.bank = [v.clone().to(dtype) for v in w.bank] 341 | # w.bank.clear() 342 | 343 | def clear(self): 344 | if self.reference_attn: 345 | if self.fusion_blocks == "midup": 346 | reader_attn_modules = [ 347 | module 348 | for module in ( 349 | torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks) 350 | ) 351 | if isinstance(module, BasicTransformerBlock) 352 | or isinstance(module, 
TemporalBasicTransformerBlock) 353 | ] 354 | elif self.fusion_blocks == "full": 355 | reader_attn_modules = [ 356 | module 357 | for module in torch_dfs(self.unet) 358 | if isinstance(module, BasicTransformerBlock) 359 | or isinstance(module, TemporalBasicTransformerBlock) 360 | ] 361 | reader_attn_modules = sorted( 362 | reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0] 363 | ) 364 | for r in reader_attn_modules: 365 | r.bank.clear() 366 | -------------------------------------------------------------------------------- /src/models/pose_guider.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.nn.init as init 6 | from diffusers.models.modeling_utils import ModelMixin 7 | 8 | from src.models.motion_module import zero_module 9 | from src.models.resnet import InflatedConv3d 10 | 11 | 12 | class PoseGuider(ModelMixin): 13 | def __init__( 14 | self, 15 | conditioning_embedding_channels: int, 16 | conditioning_channels: int = 3, 17 | block_out_channels: Tuple[int] = (16, 32, 64, 128), 18 | ): 19 | super().__init__() 20 | self.conv_in = InflatedConv3d( 21 | conditioning_channels, block_out_channels[0], kernel_size=3, padding=1 22 | ) 23 | 24 | self.blocks = nn.ModuleList([]) 25 | 26 | for i in range(len(block_out_channels) - 1): 27 | channel_in = block_out_channels[i] 28 | channel_out = block_out_channels[i + 1] 29 | self.blocks.append( 30 | InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1) 31 | ) 32 | self.blocks.append( 33 | InflatedConv3d( 34 | channel_in, channel_out, kernel_size=3, padding=1, stride=2 35 | ) 36 | ) 37 | 38 | self.conv_out = zero_module( 39 | InflatedConv3d( 40 | block_out_channels[-1], 41 | conditioning_embedding_channels, 42 | kernel_size=3, 43 | padding=1, 44 | ) 45 | ) 46 | 47 | def forward(self, conditioning): 48 | embedding = self.conv_in(conditioning) 49 | embedding = F.silu(embedding) 50 | 51 | for block in self.blocks: 52 | embedding = block(embedding) 53 | embedding = F.silu(embedding) 54 | 55 | embedding = self.conv_out(embedding) 56 | 57 | return embedding 58 | -------------------------------------------------------------------------------- /src/models/resnet.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from einops import rearrange 7 | 8 | 9 | class InflatedConv3d(nn.Conv2d): 10 | def forward(self, x): 11 | video_length = x.shape[2] 12 | 13 | x = rearrange(x, "b c f h w -> (b f) c h w") 14 | x = super().forward(x) 15 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 16 | 17 | return x 18 | 19 | 20 | class InflatedGroupNorm(nn.GroupNorm): 21 | def forward(self, x): 22 | video_length = x.shape[2] 23 | 24 | x = rearrange(x, "b c f h w -> (b f) c h w") 25 | x = super().forward(x) 26 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 27 | 28 | return x 29 | 30 | 31 | class Upsample3D(nn.Module): 32 | def __init__( 33 | self, 34 | channels, 35 | use_conv=False, 36 | use_conv_transpose=False, 37 | out_channels=None, 38 | name="conv", 39 | ): 40 | super().__init__() 41 | self.channels = channels 42 | self.out_channels = out_channels or channels 43 | self.use_conv = use_conv 44 | self.use_conv_transpose = use_conv_transpose 45 | self.name = name 46 | 47 
| conv = None 48 | if use_conv_transpose: 49 | raise NotImplementedError 50 | elif use_conv: 51 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) 52 | 53 | def forward(self, hidden_states, output_size=None): 54 | assert hidden_states.shape[1] == self.channels 55 | 56 | if self.use_conv_transpose: 57 | raise NotImplementedError 58 | 59 | # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 60 | dtype = hidden_states.dtype 61 | if dtype == torch.bfloat16: 62 | hidden_states = hidden_states.to(torch.float32) 63 | 64 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 65 | if hidden_states.shape[0] >= 64: 66 | hidden_states = hidden_states.contiguous() 67 | 68 | # if `output_size` is passed we force the interpolation output 69 | # size and do not make use of `scale_factor=2` 70 | if output_size is None: 71 | hidden_states = F.interpolate( 72 | hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest" 73 | ) 74 | else: 75 | hidden_states = F.interpolate( 76 | hidden_states, size=output_size, mode="nearest" 77 | ) 78 | 79 | # If the input is bfloat16, we cast back to bfloat16 80 | if dtype == torch.bfloat16: 81 | hidden_states = hidden_states.to(dtype) 82 | 83 | # if self.use_conv: 84 | # if self.name == "conv": 85 | # hidden_states = self.conv(hidden_states) 86 | # else: 87 | # hidden_states = self.Conv2d_0(hidden_states) 88 | hidden_states = self.conv(hidden_states) 89 | 90 | return hidden_states 91 | 92 | 93 | class Downsample3D(nn.Module): 94 | def __init__( 95 | self, channels, use_conv=False, out_channels=None, padding=1, name="conv" 96 | ): 97 | super().__init__() 98 | self.channels = channels 99 | self.out_channels = out_channels or channels 100 | self.use_conv = use_conv 101 | self.padding = padding 102 | stride = 2 103 | self.name = name 104 | 105 | if use_conv: 106 | self.conv = InflatedConv3d( 107 | self.channels, self.out_channels, 3, stride=stride, padding=padding 108 | ) 109 | else: 110 | raise NotImplementedError 111 | 112 | def forward(self, hidden_states): 113 | assert hidden_states.shape[1] == self.channels 114 | if self.use_conv and self.padding == 0: 115 | raise NotImplementedError 116 | 117 | assert hidden_states.shape[1] == self.channels 118 | hidden_states = self.conv(hidden_states) 119 | 120 | return hidden_states 121 | 122 | 123 | class ResnetBlock3D(nn.Module): 124 | def __init__( 125 | self, 126 | *, 127 | in_channels, 128 | out_channels=None, 129 | conv_shortcut=False, 130 | dropout=0.0, 131 | temb_channels=512, 132 | groups=32, 133 | groups_out=None, 134 | pre_norm=True, 135 | eps=1e-6, 136 | non_linearity="swish", 137 | time_embedding_norm="default", 138 | output_scale_factor=1.0, 139 | use_in_shortcut=None, 140 | use_inflated_groupnorm=None, 141 | ): 142 | super().__init__() 143 | self.pre_norm = pre_norm 144 | self.pre_norm = True 145 | self.in_channels = in_channels 146 | out_channels = in_channels if out_channels is None else out_channels 147 | self.out_channels = out_channels 148 | self.use_conv_shortcut = conv_shortcut 149 | self.time_embedding_norm = time_embedding_norm 150 | self.output_scale_factor = output_scale_factor 151 | 152 | if groups_out is None: 153 | groups_out = groups 154 | 155 | assert use_inflated_groupnorm != None 156 | if use_inflated_groupnorm: 157 | self.norm1 = InflatedGroupNorm( 158 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True 159 | ) 160 | else: 161 | self.norm1 = torch.nn.GroupNorm( 162 | 
num_groups=groups, num_channels=in_channels, eps=eps, affine=True 163 | ) 164 | 165 | self.conv1 = InflatedConv3d( 166 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 167 | ) 168 | 169 | if temb_channels is not None: 170 | if self.time_embedding_norm == "default": 171 | time_emb_proj_out_channels = out_channels 172 | elif self.time_embedding_norm == "scale_shift": 173 | time_emb_proj_out_channels = out_channels * 2 174 | else: 175 | raise ValueError( 176 | f"unknown time_embedding_norm : {self.time_embedding_norm} " 177 | ) 178 | 179 | self.time_emb_proj = torch.nn.Linear( 180 | temb_channels, time_emb_proj_out_channels 181 | ) 182 | else: 183 | self.time_emb_proj = None 184 | 185 | if use_inflated_groupnorm: 186 | self.norm2 = InflatedGroupNorm( 187 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True 188 | ) 189 | else: 190 | self.norm2 = torch.nn.GroupNorm( 191 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True 192 | ) 193 | self.dropout = torch.nn.Dropout(dropout) 194 | self.conv2 = InflatedConv3d( 195 | out_channels, out_channels, kernel_size=3, stride=1, padding=1 196 | ) 197 | 198 | if non_linearity == "swish": 199 | self.nonlinearity = lambda x: F.silu(x) 200 | elif non_linearity == "mish": 201 | self.nonlinearity = Mish() 202 | elif non_linearity == "silu": 203 | self.nonlinearity = nn.SiLU() 204 | 205 | self.use_in_shortcut = ( 206 | self.in_channels != self.out_channels 207 | if use_in_shortcut is None 208 | else use_in_shortcut 209 | ) 210 | 211 | self.conv_shortcut = None 212 | if self.use_in_shortcut: 213 | self.conv_shortcut = InflatedConv3d( 214 | in_channels, out_channels, kernel_size=1, stride=1, padding=0 215 | ) 216 | 217 | def forward(self, input_tensor, temb): 218 | hidden_states = input_tensor 219 | 220 | hidden_states = self.norm1(hidden_states) 221 | hidden_states = self.nonlinearity(hidden_states) 222 | 223 | hidden_states = self.conv1(hidden_states) 224 | 225 | if temb is not None: 226 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] 227 | 228 | if temb is not None and self.time_embedding_norm == "default": 229 | hidden_states = hidden_states + temb 230 | 231 | hidden_states = self.norm2(hidden_states) 232 | 233 | if temb is not None and self.time_embedding_norm == "scale_shift": 234 | scale, shift = torch.chunk(temb, 2, dim=1) 235 | hidden_states = hidden_states * (1 + scale) + shift 236 | 237 | hidden_states = self.nonlinearity(hidden_states) 238 | 239 | hidden_states = self.dropout(hidden_states) 240 | hidden_states = self.conv2(hidden_states) 241 | 242 | if self.conv_shortcut is not None: 243 | input_tensor = self.conv_shortcut(input_tensor) 244 | 245 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 246 | 247 | return output_tensor 248 | 249 | 250 | class Mish(torch.nn.Module): 251 | def forward(self, hidden_states): 252 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) 253 | -------------------------------------------------------------------------------- /src/models/transformer_3d.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | 4 | import torch 5 | from diffusers.configuration_utils import ConfigMixin, register_to_config 6 | from diffusers.models import ModelMixin 7 | from diffusers.utils import BaseOutput 8 | from diffusers.utils.import_utils import is_xformers_available 9 | from einops import rearrange, 
repeat 10 | from torch import nn 11 | 12 | from .attention import TemporalBasicTransformerBlock 13 | 14 | 15 | @dataclass 16 | class Transformer3DModelOutput(BaseOutput): 17 | sample: torch.FloatTensor 18 | 19 | 20 | if is_xformers_available(): 21 | import xformers 22 | import xformers.ops 23 | else: 24 | xformers = None 25 | 26 | 27 | class Transformer3DModel(ModelMixin, ConfigMixin): 28 | _supports_gradient_checkpointing = True 29 | 30 | @register_to_config 31 | def __init__( 32 | self, 33 | num_attention_heads: int = 16, 34 | attention_head_dim: int = 88, 35 | in_channels: Optional[int] = None, 36 | num_layers: int = 1, 37 | dropout: float = 0.0, 38 | norm_num_groups: int = 32, 39 | cross_attention_dim: Optional[int] = None, 40 | attention_bias: bool = False, 41 | activation_fn: str = "geglu", 42 | num_embeds_ada_norm: Optional[int] = None, 43 | use_linear_projection: bool = False, 44 | only_cross_attention: bool = False, 45 | upcast_attention: bool = False, 46 | unet_use_cross_frame_attention=None, 47 | unet_use_temporal_attention=None, 48 | name=None, 49 | ): 50 | super().__init__() 51 | self.use_linear_projection = use_linear_projection 52 | self.num_attention_heads = num_attention_heads 53 | self.attention_head_dim = attention_head_dim 54 | inner_dim = num_attention_heads * attention_head_dim 55 | 56 | # Define input layers 57 | self.in_channels = in_channels 58 | self.name=name 59 | 60 | self.norm = torch.nn.GroupNorm( 61 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True 62 | ) 63 | if use_linear_projection: 64 | self.proj_in = nn.Linear(in_channels, inner_dim) 65 | else: 66 | self.proj_in = nn.Conv2d( 67 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0 68 | ) 69 | 70 | # Define transformers blocks 71 | self.transformer_blocks = nn.ModuleList( 72 | [ 73 | TemporalBasicTransformerBlock( 74 | inner_dim, 75 | num_attention_heads, 76 | attention_head_dim, 77 | dropout=dropout, 78 | cross_attention_dim=cross_attention_dim, 79 | activation_fn=activation_fn, 80 | num_embeds_ada_norm=num_embeds_ada_norm, 81 | attention_bias=attention_bias, 82 | only_cross_attention=only_cross_attention, 83 | upcast_attention=upcast_attention, 84 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 85 | unet_use_temporal_attention=unet_use_temporal_attention, 86 | name=f"{self.name}_{d}_TransformerBlock" if self.name else None, 87 | ) 88 | for d in range(num_layers) 89 | ] 90 | ) 91 | 92 | # 4. Define output layers 93 | if use_linear_projection: 94 | self.proj_out = nn.Linear(in_channels, inner_dim) 95 | else: 96 | self.proj_out = nn.Conv2d( 97 | inner_dim, in_channels, kernel_size=1, stride=1, padding=0 98 | ) 99 | 100 | self.gradient_checkpointing = False 101 | 102 | def _set_gradient_checkpointing(self, module, value=False): 103 | if hasattr(module, "gradient_checkpointing"): 104 | module.gradient_checkpointing = value 105 | 106 | def forward( 107 | self, 108 | hidden_states, 109 | encoder_hidden_states=None, 110 | self_attention_additional_feats=None, 111 | mode=None, 112 | timestep=None, 113 | return_dict: bool = True, 114 | ): 115 | # Input 116 | assert ( 117 | hidden_states.dim() == 5 118 | ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 
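# hidden_states arrives as (batch, channels, frames, height, width); the spatial
# transformer blocks below operate per frame, so the frame axis is folded into the
# batch axis here and unfolded back to (b, c, f, h, w) just before returning. The
# per-video encoder_hidden_states are repeated per frame to match the folded batch.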
119 | video_length = hidden_states.shape[2] 120 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 121 | if encoder_hidden_states.shape[0] != hidden_states.shape[0]: 122 | encoder_hidden_states = repeat( 123 | encoder_hidden_states, "b n c -> (b f) n c", f=video_length 124 | ) 125 | 126 | batch, channel, height, weight = hidden_states.shape 127 | residual = hidden_states 128 | 129 | hidden_states = self.norm(hidden_states) 130 | if not self.use_linear_projection: 131 | hidden_states = self.proj_in(hidden_states) 132 | inner_dim = hidden_states.shape[1] 133 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( 134 | batch, height * weight, inner_dim 135 | ) 136 | else: 137 | inner_dim = hidden_states.shape[1] 138 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( 139 | batch, height * weight, inner_dim 140 | ) 141 | hidden_states = self.proj_in(hidden_states) 142 | 143 | # Blocks 144 | for i, block in enumerate(self.transformer_blocks): 145 | 146 | if self.training and self.gradient_checkpointing: 147 | 148 | def create_custom_forward(module, return_dict=None): 149 | def custom_forward(*inputs): 150 | if return_dict is not None: 151 | return module(*inputs, return_dict=return_dict) 152 | else: 153 | return module(*inputs) 154 | 155 | return custom_forward 156 | 157 | # if hasattr(self.block, 'bank') and len(self.block.bank) > 0: 158 | # hidden_states 159 | hidden_states = torch.utils.checkpoint.checkpoint( 160 | create_custom_forward(block), 161 | hidden_states, 162 | encoder_hidden_states=encoder_hidden_states, 163 | timestep=timestep, 164 | attention_mask=None, 165 | video_length=video_length, 166 | self_attention_additional_feats=self_attention_additional_feats, 167 | mode=mode, 168 | ) 169 | else: 170 | 171 | hidden_states = block( 172 | hidden_states, 173 | encoder_hidden_states=encoder_hidden_states, 174 | timestep=timestep, 175 | self_attention_additional_feats=self_attention_additional_feats, 176 | mode=mode, 177 | video_length=video_length, 178 | ) 179 | 180 | # Output 181 | if not self.use_linear_projection: 182 | hidden_states = ( 183 | hidden_states.reshape(batch, height, weight, inner_dim) 184 | .permute(0, 3, 1, 2) 185 | .contiguous() 186 | ) 187 | hidden_states = self.proj_out(hidden_states) 188 | else: 189 | hidden_states = self.proj_out(hidden_states) 190 | hidden_states = ( 191 | hidden_states.reshape(batch, height, weight, inner_dim) 192 | .permute(0, 3, 1, 2) 193 | .contiguous() 194 | ) 195 | 196 | output = hidden_states + residual 197 | 198 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 199 | if not return_dict: 200 | return (output,) 201 | 202 | return Transformer3DModelOutput(sample=output) 203 | -------------------------------------------------------------------------------- /src/pipelines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/src/pipelines/__init__.py -------------------------------------------------------------------------------- /src/pipelines/context.py: -------------------------------------------------------------------------------- 1 | # TODO: Adapted from cli 2 | from typing import Callable, List, Optional 3 | 4 | import numpy as np 5 | 6 | 7 | def ordered_halving(val): 8 | bin_str = f"{val:064b}" 9 | bin_flip = bin_str[::-1] 10 | as_int = int(bin_flip, 2) 11 | 12 | return as_int / (1 << 64) 13 | 14 | 15 | def uniform( 16 | step: int = 
..., 17 | num_steps: Optional[int] = None, 18 | num_frames: int = ..., 19 | context_size: Optional[int] = None, 20 | context_stride: int = 3, 21 | context_overlap: int = 4, 22 | closed_loop: bool = True, 23 | ): 24 | if num_frames <= context_size: 25 | yield list(range(num_frames)) 26 | return 27 | 28 | context_stride = min( 29 | context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1 30 | ) 31 | 32 | for context_step in 1 << np.arange(context_stride): 33 | pad = int(round(num_frames * ordered_halving(step))) 34 | for j in range( 35 | int(ordered_halving(step) * context_step) + pad, 36 | num_frames + pad + (0 if closed_loop else -context_overlap), 37 | (context_size * context_step - context_overlap), 38 | ): 39 | yield [ 40 | e % num_frames 41 | for e in range(j, j + context_size * context_step, context_step) 42 | ] 43 | 44 | 45 | def get_context_scheduler(name: str) -> Callable: 46 | if name == "uniform": 47 | return uniform 48 | else: 49 | raise ValueError(f"Unknown context_overlap policy {name}") 50 | 51 | 52 | def get_total_steps( 53 | scheduler, 54 | timesteps: List[int], 55 | num_steps: Optional[int] = None, 56 | num_frames: int = ..., 57 | context_size: Optional[int] = None, 58 | context_stride: int = 3, 59 | context_overlap: int = 4, 60 | closed_loop: bool = True, 61 | ): 62 | return sum( 63 | len( 64 | list( 65 | scheduler( 66 | i, 67 | num_steps, 68 | num_frames, 69 | context_size, 70 | context_stride, 71 | context_overlap, 72 | ) 73 | ) 74 | ) 75 | for i in range(len(timesteps)) 76 | ) 77 | -------------------------------------------------------------------------------- /src/pipelines/pipeline_pose2img.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from dataclasses import dataclass 3 | from typing import Callable, List, Optional, Union 4 | 5 | import numpy as np 6 | import torch 7 | from diffusers import DiffusionPipeline 8 | from diffusers.image_processor import VaeImageProcessor 9 | from diffusers.schedulers import ( 10 | DDIMScheduler, 11 | DPMSolverMultistepScheduler, 12 | EulerAncestralDiscreteScheduler, 13 | EulerDiscreteScheduler, 14 | LMSDiscreteScheduler, 15 | PNDMScheduler, 16 | ) 17 | from diffusers.utils import BaseOutput, is_accelerate_available 18 | from diffusers.utils.torch_utils import randn_tensor 19 | from einops import rearrange 20 | from tqdm import tqdm 21 | from transformers import CLIPImageProcessor 22 | 23 | from src.models.mutual_self_attention import ReferenceAttentionControl 24 | 25 | 26 | @dataclass 27 | class Pose2ImagePipelineOutput(BaseOutput): 28 | images: Union[torch.Tensor, np.ndarray] 29 | 30 | 31 | class Pose2ImagePipeline(DiffusionPipeline): 32 | _optional_components = [] 33 | 34 | def __init__( 35 | self, 36 | vae, 37 | image_encoder, 38 | reference_unet, 39 | denoising_unet, 40 | pose_guider, 41 | scheduler: Union[ 42 | DDIMScheduler, 43 | PNDMScheduler, 44 | LMSDiscreteScheduler, 45 | EulerDiscreteScheduler, 46 | EulerAncestralDiscreteScheduler, 47 | DPMSolverMultistepScheduler, 48 | ], 49 | ): 50 | super().__init__() 51 | 52 | self.register_modules( 53 | vae=vae, 54 | image_encoder=image_encoder, 55 | reference_unet=reference_unet, 56 | denoising_unet=denoising_unet, 57 | pose_guider=pose_guider, 58 | scheduler=scheduler, 59 | ) 60 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 61 | self.clip_image_processor = CLIPImageProcessor() 62 | self.ref_image_processor = VaeImageProcessor( 63 | vae_scale_factor=self.vae_scale_factor, 
do_convert_rgb=True 64 | ) 65 | self.cond_image_processor = VaeImageProcessor( 66 | vae_scale_factor=self.vae_scale_factor, 67 | do_convert_rgb=True, 68 | do_normalize=False, 69 | ) 70 | 71 | def enable_vae_slicing(self): 72 | self.vae.enable_slicing() 73 | 74 | def disable_vae_slicing(self): 75 | self.vae.disable_slicing() 76 | 77 | def enable_sequential_cpu_offload(self, gpu_id=0): 78 | if is_accelerate_available(): 79 | from accelerate import cpu_offload 80 | else: 81 | raise ImportError("Please install accelerate via `pip install accelerate`") 82 | 83 | device = torch.device(f"cuda:{gpu_id}") 84 | 85 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: 86 | if cpu_offloaded_model is not None: 87 | cpu_offload(cpu_offloaded_model, device) 88 | 89 | @property 90 | def _execution_device(self): 91 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): 92 | return self.device 93 | for module in self.unet.modules(): 94 | if ( 95 | hasattr(module, "_hf_hook") 96 | and hasattr(module._hf_hook, "execution_device") 97 | and module._hf_hook.execution_device is not None 98 | ): 99 | return torch.device(module._hf_hook.execution_device) 100 | return self.device 101 | 102 | def decode_latents(self, latents): 103 | video_length = latents.shape[2] 104 | latents = 1 / 0.18215 * latents 105 | latents = rearrange(latents, "b c f h w -> (b f) c h w") 106 | # video = self.vae.decode(latents).sample 107 | video = [] 108 | for frame_idx in tqdm(range(latents.shape[0])): 109 | video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample) 110 | video = torch.cat(video) 111 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) 112 | video = (video / 2 + 0.5).clamp(0, 1) 113 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 114 | video = video.cpu().float().numpy() 115 | return video 116 | 117 | def prepare_extra_step_kwargs(self, generator, eta): 118 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 119 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 120 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 121 | # and should be between [0, 1] 122 | 123 | accepts_eta = "eta" in set( 124 | inspect.signature(self.scheduler.step).parameters.keys() 125 | ) 126 | extra_step_kwargs = {} 127 | if accepts_eta: 128 | extra_step_kwargs["eta"] = eta 129 | 130 | # check if the scheduler accepts generator 131 | accepts_generator = "generator" in set( 132 | inspect.signature(self.scheduler.step).parameters.keys() 133 | ) 134 | if accepts_generator: 135 | extra_step_kwargs["generator"] = generator 136 | return extra_step_kwargs 137 | 138 | def prepare_latents( 139 | self, 140 | batch_size, 141 | num_channels_latents, 142 | width, 143 | height, 144 | dtype, 145 | device, 146 | generator, 147 | latents=None, 148 | ): 149 | shape = ( 150 | batch_size, 151 | num_channels_latents, 152 | height // self.vae_scale_factor, 153 | width // self.vae_scale_factor, 154 | ) 155 | if isinstance(generator, list) and len(generator) != batch_size: 156 | raise ValueError( 157 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 158 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
159 | ) 160 | 161 | if latents is None: 162 | latents = randn_tensor( 163 | shape, generator=generator, device=device, dtype=dtype 164 | ) 165 | else: 166 | latents = latents.to(device) 167 | 168 | # scale the initial noise by the standard deviation required by the scheduler 169 | latents = latents * self.scheduler.init_noise_sigma 170 | return latents 171 | 172 | def prepare_condition( 173 | self, 174 | cond_image, 175 | width, 176 | height, 177 | device, 178 | dtype, 179 | do_classififer_free_guidance=False, 180 | ): 181 | image = self.cond_image_processor.preprocess( 182 | cond_image, height=height, width=width 183 | ).to(dtype=torch.float32) 184 | 185 | image = image.to(device=device, dtype=dtype) 186 | 187 | if do_classififer_free_guidance: 188 | image = torch.cat([image] * 2) 189 | 190 | return image 191 | 192 | @torch.no_grad() 193 | def __call__( 194 | self, 195 | ref_image, 196 | pose_image, 197 | width, 198 | height, 199 | num_inference_steps, 200 | guidance_scale, 201 | num_images_per_prompt=1, 202 | eta: float = 0.0, 203 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 204 | output_type: Optional[str] = "tensor", 205 | return_dict: bool = True, 206 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 207 | callback_steps: Optional[int] = 1, 208 | **kwargs, 209 | ): 210 | # Default height and width to unet 211 | height = height or self.unet.config.sample_size * self.vae_scale_factor 212 | width = width or self.unet.config.sample_size * self.vae_scale_factor 213 | 214 | device = self._execution_device 215 | 216 | do_classifier_free_guidance = guidance_scale > 1.0 217 | 218 | # Prepare timesteps 219 | self.scheduler.set_timesteps(num_inference_steps, device=device) 220 | timesteps = self.scheduler.timesteps 221 | 222 | batch_size = 1 223 | 224 | # Prepare clip image embeds 225 | clip_image = self.clip_image_processor.preprocess( 226 | ref_image.resize((224, 224)), return_tensors="pt" 227 | ).pixel_values 228 | clip_image_embeds = self.image_encoder( 229 | clip_image.to(device, dtype=self.image_encoder.dtype) 230 | ).image_embeds 231 | image_prompt_embeds = clip_image_embeds.unsqueeze(1) 232 | uncond_image_prompt_embeds = torch.zeros_like(image_prompt_embeds) 233 | 234 | if do_classifier_free_guidance: 235 | image_prompt_embeds = torch.cat( 236 | [uncond_image_prompt_embeds, image_prompt_embeds], dim=0 237 | ) 238 | 239 | reference_control_writer = ReferenceAttentionControl( 240 | self.reference_unet, 241 | do_classifier_free_guidance=do_classifier_free_guidance, 242 | mode="write", 243 | batch_size=batch_size, 244 | fusion_blocks="full", 245 | ) 246 | reference_control_reader = ReferenceAttentionControl( 247 | self.denoising_unet, 248 | do_classifier_free_guidance=do_classifier_free_guidance, 249 | mode="read", 250 | batch_size=batch_size, 251 | fusion_blocks="full", 252 | ) 253 | 254 | num_channels_latents = self.denoising_unet.in_channels 255 | latents = self.prepare_latents( 256 | batch_size * num_images_per_prompt, 257 | num_channels_latents, 258 | width, 259 | height, 260 | clip_image_embeds.dtype, 261 | device, 262 | generator, 263 | ) 264 | latents = latents.unsqueeze(2) # (bs, c, 1, h', w') 265 | latents_dtype = latents.dtype 266 | 267 | # Prepare extra step kwargs. 
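# (prepare_extra_step_kwargs inspects the scheduler.step signature and forwards eta and
# generator only when the scheduler accepts them; eta is only meaningful for DDIM.)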
268 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 269 | 270 | # Prepare ref image latents 271 | ref_image_tensor = self.ref_image_processor.preprocess( 272 | ref_image, height=height, width=width 273 | ) # (bs, c, width, height) 274 | ref_image_tensor = ref_image_tensor.to( 275 | dtype=self.vae.dtype, device=self.vae.device 276 | ) 277 | ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean 278 | ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w) 279 | 280 | # Prepare pose condition image 281 | pose_cond_tensor = self.cond_image_processor.preprocess( 282 | pose_image, height=height, width=width 283 | ) 284 | pose_cond_tensor = pose_cond_tensor.unsqueeze(2) # (bs, c, 1, h, w) 285 | pose_cond_tensor = pose_cond_tensor.to( 286 | device=device, dtype=self.pose_guider.dtype 287 | ) 288 | pose_fea = self.pose_guider(pose_cond_tensor) 289 | pose_fea = ( 290 | torch.cat([pose_fea] * 2) if do_classifier_free_guidance else pose_fea 291 | ) 292 | 293 | # denoising loop 294 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 295 | with self.progress_bar(total=num_inference_steps) as progress_bar: 296 | for i, t in enumerate(timesteps): 297 | # 1. Forward reference image 298 | if i == 0: 299 | self.reference_unet( 300 | ref_image_latents.repeat( 301 | (2 if do_classifier_free_guidance else 1), 1, 1, 1 302 | ), 303 | torch.zeros_like(t), 304 | encoder_hidden_states=image_prompt_embeds, 305 | return_dict=False, 306 | ) 307 | 308 | # 2. Update reference unet feature into denosing net 309 | reference_control_reader.update(reference_control_writer) 310 | 311 | # 3.1 expand the latents if we are doing classifier free guidance 312 | latent_model_input = ( 313 | torch.cat([latents] * 2) if do_classifier_free_guidance else latents 314 | ) 315 | latent_model_input = self.scheduler.scale_model_input( 316 | latent_model_input, t 317 | ) 318 | 319 | noise_pred = self.denoising_unet( 320 | latent_model_input, 321 | t, 322 | encoder_hidden_states=image_prompt_embeds, 323 | pose_cond_fea=pose_fea, 324 | return_dict=False, 325 | )[0] 326 | 327 | # perform guidance 328 | if do_classifier_free_guidance: 329 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 330 | noise_pred = noise_pred_uncond + guidance_scale * ( 331 | noise_pred_text - noise_pred_uncond 332 | ) 333 | 334 | # compute the previous noisy sample x_t -> x_t-1 335 | latents = self.scheduler.step( 336 | noise_pred, t, latents, **extra_step_kwargs, return_dict=False 337 | )[0] 338 | 339 | # call the callback, if provided 340 | if i == len(timesteps) - 1 or ( 341 | (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 342 | ): 343 | progress_bar.update() 344 | if callback is not None and i % callback_steps == 0: 345 | step_idx = i // getattr(self.scheduler, "order", 1) 346 | callback(step_idx, t, latents) 347 | reference_control_reader.clear() 348 | reference_control_writer.clear() 349 | 350 | # Post-processing 351 | image = self.decode_latents(latents) # (b, c, 1, h, w) 352 | 353 | # Convert to tensor 354 | if output_type == "tensor": 355 | image = torch.from_numpy(image) 356 | 357 | if not return_dict: 358 | return image 359 | 360 | return Pose2ImagePipelineOutput(images=image) 361 | -------------------------------------------------------------------------------- /src/pipelines/pipeline_pose2vid.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from dataclasses import dataclass 3 | from typing import 
Callable, List, Optional, Union 4 | 5 | import numpy as np 6 | import torch 7 | from diffusers import DiffusionPipeline 8 | from diffusers.image_processor import VaeImageProcessor 9 | from diffusers.schedulers import (DDIMScheduler, DPMSolverMultistepScheduler, 10 | EulerAncestralDiscreteScheduler, 11 | EulerDiscreteScheduler, LMSDiscreteScheduler, 12 | PNDMScheduler) 13 | from diffusers.utils import BaseOutput, is_accelerate_available 14 | from diffusers.utils.torch_utils import randn_tensor 15 | from einops import rearrange 16 | from tqdm import tqdm 17 | from transformers import CLIPImageProcessor 18 | 19 | from src.models.mutual_self_attention import ReferenceAttentionControl 20 | 21 | 22 | @dataclass 23 | class Pose2VideoPipelineOutput(BaseOutput): 24 | videos: Union[torch.Tensor, np.ndarray] 25 | 26 | 27 | class Pose2VideoPipeline(DiffusionPipeline): 28 | _optional_components = [] 29 | 30 | def __init__( 31 | self, 32 | vae, 33 | image_encoder, 34 | reference_unet, 35 | denoising_unet, 36 | pose_guider, 37 | scheduler: Union[ 38 | DDIMScheduler, 39 | PNDMScheduler, 40 | LMSDiscreteScheduler, 41 | EulerDiscreteScheduler, 42 | EulerAncestralDiscreteScheduler, 43 | DPMSolverMultistepScheduler, 44 | ], 45 | image_proj_model=None, 46 | tokenizer=None, 47 | text_encoder=None, 48 | ): 49 | super().__init__() 50 | 51 | self.register_modules( 52 | vae=vae, 53 | image_encoder=image_encoder, 54 | reference_unet=reference_unet, 55 | denoising_unet=denoising_unet, 56 | pose_guider=pose_guider, 57 | scheduler=scheduler, 58 | image_proj_model=image_proj_model, 59 | tokenizer=tokenizer, 60 | text_encoder=text_encoder, 61 | ) 62 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 63 | self.clip_image_processor = CLIPImageProcessor() 64 | self.ref_image_processor = VaeImageProcessor( 65 | vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True 66 | ) 67 | self.cond_image_processor = VaeImageProcessor( 68 | vae_scale_factor=self.vae_scale_factor, 69 | do_convert_rgb=True, 70 | do_normalize=False, 71 | ) 72 | 73 | def enable_vae_slicing(self): 74 | self.vae.enable_slicing() 75 | 76 | def disable_vae_slicing(self): 77 | self.vae.disable_slicing() 78 | 79 | def enable_sequential_cpu_offload(self, gpu_id=0): 80 | if is_accelerate_available(): 81 | from accelerate import cpu_offload 82 | else: 83 | raise ImportError("Please install accelerate via `pip install accelerate`") 84 | 85 | device = torch.device(f"cuda:{gpu_id}") 86 | 87 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: 88 | if cpu_offloaded_model is not None: 89 | cpu_offload(cpu_offloaded_model, device) 90 | 91 | @property 92 | def _execution_device(self): 93 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): 94 | return self.device 95 | for module in self.unet.modules(): 96 | if ( 97 | hasattr(module, "_hf_hook") 98 | and hasattr(module._hf_hook, "execution_device") 99 | and module._hf_hook.execution_device is not None 100 | ): 101 | return torch.device(module._hf_hook.execution_device) 102 | return self.device 103 | 104 | def decode_latents(self, latents): 105 | video_length = latents.shape[2] 106 | latents = 1 / 0.18215 * latents 107 | latents = rearrange(latents, "b c f h w -> (b f) c h w") 108 | # video = self.vae.decode(latents).sample 109 | video = [] 110 | for frame_idx in tqdm(range(latents.shape[0])): 111 | video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample) 112 | video = torch.cat(video) 113 | video = rearrange(video, "(b f) c h w -> b c 
f h w", f=video_length) 114 | video = (video / 2 + 0.5).clamp(0, 1) 115 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 116 | video = video.cpu().float().numpy() 117 | return video 118 | 119 | def prepare_extra_step_kwargs(self, generator, eta): 120 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 121 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 122 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 123 | # and should be between [0, 1] 124 | 125 | accepts_eta = "eta" in set( 126 | inspect.signature(self.scheduler.step).parameters.keys() 127 | ) 128 | extra_step_kwargs = {} 129 | if accepts_eta: 130 | extra_step_kwargs["eta"] = eta 131 | 132 | # check if the scheduler accepts generator 133 | accepts_generator = "generator" in set( 134 | inspect.signature(self.scheduler.step).parameters.keys() 135 | ) 136 | if accepts_generator: 137 | extra_step_kwargs["generator"] = generator 138 | return extra_step_kwargs 139 | 140 | def prepare_latents( 141 | self, 142 | batch_size, 143 | num_channels_latents, 144 | width, 145 | height, 146 | video_length, 147 | dtype, 148 | device, 149 | generator, 150 | latents=None, 151 | ): 152 | shape = ( 153 | batch_size, 154 | num_channels_latents, 155 | video_length, 156 | height // self.vae_scale_factor, 157 | width // self.vae_scale_factor, 158 | ) 159 | if isinstance(generator, list) and len(generator) != batch_size: 160 | raise ValueError( 161 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 162 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 163 | ) 164 | 165 | if latents is None: 166 | latents = randn_tensor( 167 | shape, generator=generator, device=device, dtype=dtype 168 | ) 169 | else: 170 | latents = latents.to(device) 171 | 172 | # scale the initial noise by the standard deviation required by the scheduler 173 | latents = latents * self.scheduler.init_noise_sigma 174 | return latents 175 | 176 | def _encode_prompt( 177 | self, 178 | prompt, 179 | device, 180 | num_videos_per_prompt, 181 | do_classifier_free_guidance, 182 | negative_prompt, 183 | ): 184 | batch_size = len(prompt) if isinstance(prompt, list) else 1 185 | 186 | text_inputs = self.tokenizer( 187 | prompt, 188 | padding="max_length", 189 | max_length=self.tokenizer.model_max_length, 190 | truncation=True, 191 | return_tensors="pt", 192 | ) 193 | text_input_ids = text_inputs.input_ids 194 | untruncated_ids = self.tokenizer( 195 | prompt, padding="longest", return_tensors="pt" 196 | ).input_ids 197 | 198 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( 199 | text_input_ids, untruncated_ids 200 | ): 201 | removed_text = self.tokenizer.batch_decode( 202 | untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] 203 | ) 204 | 205 | if ( 206 | hasattr(self.text_encoder.config, "use_attention_mask") 207 | and self.text_encoder.config.use_attention_mask 208 | ): 209 | attention_mask = text_inputs.attention_mask.to(device) 210 | else: 211 | attention_mask = None 212 | 213 | text_embeddings = self.text_encoder( 214 | text_input_ids.to(device), 215 | attention_mask=attention_mask, 216 | ) 217 | text_embeddings = text_embeddings[0] 218 | 219 | # duplicate text embeddings for each generation per prompt, using mps friendly method 220 | bs_embed, seq_len, _ = text_embeddings.shape 221 | text_embeddings = 
text_embeddings.repeat(1, num_videos_per_prompt, 1) 222 | text_embeddings = text_embeddings.view( 223 | bs_embed * num_videos_per_prompt, seq_len, -1 224 | ) 225 | 226 | # get unconditional embeddings for classifier free guidance 227 | if do_classifier_free_guidance: 228 | uncond_tokens: List[str] 229 | if negative_prompt is None: 230 | uncond_tokens = [""] * batch_size 231 | elif type(prompt) is not type(negative_prompt): 232 | raise TypeError( 233 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 234 | f" {type(prompt)}." 235 | ) 236 | elif isinstance(negative_prompt, str): 237 | uncond_tokens = [negative_prompt] 238 | elif batch_size != len(negative_prompt): 239 | raise ValueError( 240 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 241 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 242 | " the batch size of `prompt`." 243 | ) 244 | else: 245 | uncond_tokens = negative_prompt 246 | 247 | max_length = text_input_ids.shape[-1] 248 | uncond_input = self.tokenizer( 249 | uncond_tokens, 250 | padding="max_length", 251 | max_length=max_length, 252 | truncation=True, 253 | return_tensors="pt", 254 | ) 255 | 256 | if ( 257 | hasattr(self.text_encoder.config, "use_attention_mask") 258 | and self.text_encoder.config.use_attention_mask 259 | ): 260 | attention_mask = uncond_input.attention_mask.to(device) 261 | else: 262 | attention_mask = None 263 | 264 | uncond_embeddings = self.text_encoder( 265 | uncond_input.input_ids.to(device), 266 | attention_mask=attention_mask, 267 | ) 268 | uncond_embeddings = uncond_embeddings[0] 269 | 270 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 271 | seq_len = uncond_embeddings.shape[1] 272 | uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1) 273 | uncond_embeddings = uncond_embeddings.view( 274 | batch_size * num_videos_per_prompt, seq_len, -1 275 | ) 276 | 277 | # For classifier free guidance, we need to do two forward passes. 
278 | # Here we concatenate the unconditional and text embeddings into a single batch 279 | # to avoid doing two forward passes 280 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 281 | 282 | return text_embeddings 283 | 284 | @torch.no_grad() 285 | def __call__( 286 | self, 287 | ref_image, 288 | pose_images, 289 | width, 290 | height, 291 | video_length, 292 | num_inference_steps, 293 | guidance_scale, 294 | num_images_per_prompt=1, 295 | eta: float = 0.0, 296 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 297 | output_type: Optional[str] = "tensor", 298 | return_dict: bool = True, 299 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 300 | callback_steps: Optional[int] = 1, 301 | **kwargs, 302 | ): 303 | # Default height and width to unet 304 | height = height or self.unet.config.sample_size * self.vae_scale_factor 305 | width = width or self.unet.config.sample_size * self.vae_scale_factor 306 | 307 | device = self._execution_device 308 | 309 | do_classifier_free_guidance = guidance_scale > 1.0 310 | 311 | # Prepare timesteps 312 | self.scheduler.set_timesteps(num_inference_steps, device=device) 313 | timesteps = self.scheduler.timesteps 314 | 315 | batch_size = 1 316 | 317 | # Prepare clip image embeds 318 | clip_image = self.clip_image_processor.preprocess( 319 | ref_image, return_tensors="pt" 320 | ).pixel_values 321 | clip_image_embeds = self.image_encoder( 322 | clip_image.to(device, dtype=self.image_encoder.dtype) 323 | ).image_embeds 324 | encoder_hidden_states = clip_image_embeds.unsqueeze(1) 325 | uncond_encoder_hidden_states = torch.zeros_like(encoder_hidden_states) 326 | 327 | if do_classifier_free_guidance: 328 | encoder_hidden_states = torch.cat( 329 | [uncond_encoder_hidden_states, encoder_hidden_states], dim=0 330 | ) 331 | reference_control_writer = ReferenceAttentionControl( 332 | self.reference_unet, 333 | do_classifier_free_guidance=do_classifier_free_guidance, 334 | mode="write", 335 | batch_size=batch_size, 336 | fusion_blocks="full", 337 | ) 338 | reference_control_reader = ReferenceAttentionControl( 339 | self.denoising_unet, 340 | do_classifier_free_guidance=do_classifier_free_guidance, 341 | mode="read", 342 | batch_size=batch_size, 343 | fusion_blocks="full", 344 | ) 345 | 346 | num_channels_latents = self.denoising_unet.in_channels 347 | latents = self.prepare_latents( 348 | batch_size * num_images_per_prompt, 349 | num_channels_latents, 350 | width, 351 | height, 352 | video_length, 353 | clip_image_embeds.dtype, 354 | device, 355 | generator, 356 | ) 357 | 358 | # Prepare extra step kwargs. 
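# (The blocks below VAE-encode the reference image once, scaling its latents by 0.18215,
# and stack the per-frame pose images along the time axis into a (b, c, t, h, w) tensor
# for the pose guider.)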
359 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 360 | 361 | # Prepare ref image latents 362 | ref_image_tensor = self.ref_image_processor.preprocess( 363 | ref_image, height=height, width=width 364 | ) # (bs, c, width, height) 365 | ref_image_tensor = ref_image_tensor.to( 366 | dtype=self.vae.dtype, device=self.vae.device 367 | ) 368 | ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean 369 | ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w) 370 | 371 | # Prepare a list of pose condition images 372 | pose_cond_tensor_list = [] 373 | for pose_image in pose_images: 374 | pose_cond_tensor = ( 375 | torch.from_numpy(np.array(pose_image.resize((width, height)))) / 255.0 376 | ) 377 | pose_cond_tensor = pose_cond_tensor.permute(2, 0, 1).unsqueeze( 378 | 1 379 | ) # (c, 1, h, w) 380 | pose_cond_tensor_list.append(pose_cond_tensor) 381 | pose_cond_tensor = torch.cat(pose_cond_tensor_list, dim=1) # (c, t, h, w) 382 | pose_cond_tensor = pose_cond_tensor.unsqueeze(0) 383 | pose_cond_tensor = pose_cond_tensor.to( 384 | device=device, dtype=self.pose_guider.dtype 385 | ) 386 | pose_fea = self.pose_guider(pose_cond_tensor) 387 | pose_fea = ( 388 | torch.cat([pose_fea] * 2) if do_classifier_free_guidance else pose_fea 389 | ) 390 | 391 | # denoising loop 392 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 393 | with self.progress_bar(total=num_inference_steps) as progress_bar: 394 | for i, t in enumerate(timesteps): 395 | # 1. Forward reference image 396 | if i == 0: 397 | self.reference_unet( 398 | ref_image_latents.repeat( 399 | (2 if do_classifier_free_guidance else 1), 1, 1, 1 400 | ), 401 | torch.zeros_like(t), 402 | # t, 403 | encoder_hidden_states=encoder_hidden_states, 404 | return_dict=False, 405 | ) 406 | reference_control_reader.update(reference_control_writer) 407 | 408 | # 3.1 expand the latents if we are doing classifier free guidance 409 | latent_model_input = ( 410 | torch.cat([latents] * 2) if do_classifier_free_guidance else latents 411 | ) 412 | latent_model_input = self.scheduler.scale_model_input( 413 | latent_model_input, t 414 | ) 415 | 416 | noise_pred = self.denoising_unet( 417 | latent_model_input, 418 | t, 419 | encoder_hidden_states=encoder_hidden_states, 420 | pose_cond_fea=pose_fea, 421 | return_dict=False, 422 | )[0] 423 | 424 | # perform guidance 425 | if do_classifier_free_guidance: 426 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 427 | noise_pred = noise_pred_uncond + guidance_scale * ( 428 | noise_pred_text - noise_pred_uncond 429 | ) 430 | 431 | # compute the previous noisy sample x_t -> x_t-1 432 | latents = self.scheduler.step( 433 | noise_pred, t, latents, **extra_step_kwargs, return_dict=False 434 | )[0] 435 | 436 | # call the callback, if provided 437 | if i == len(timesteps) - 1 or ( 438 | (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 439 | ): 440 | progress_bar.update() 441 | if callback is not None and i % callback_steps == 0: 442 | step_idx = i // getattr(self.scheduler, "order", 1) 443 | callback(step_idx, t, latents) 444 | 445 | reference_control_reader.clear() 446 | reference_control_writer.clear() 447 | 448 | # Post-processing 449 | images = self.decode_latents(latents) # (b, c, f, h, w) 450 | 451 | # Convert to tensor 452 | if output_type == "tensor": 453 | images = torch.from_numpy(images) 454 | 455 | if not return_dict: 456 | return images 457 | 458 | return Pose2VideoPipelineOutput(videos=images) 459 | 
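A rough usage sketch for the pipeline above (not part of the repository: it assumes a Pose2VideoPipeline instance named `pipe` has already been assembled from the VAE, image encoder, reference/denoising UNets, pose guider and scheduler, e.g. as scripts/pose2vid.py does), with example assets taken from configs/inference:

import torch
from PIL import Image

# Hypothetical inputs; any RGB reference image and a list of pose frames work.
ref_image = Image.open("configs/inference/ref_images/anyone-1.png").convert("RGB")
pose_images = [
    Image.open("configs/inference/pose_images/pose-1.png").convert("RGB")
] * 24  # one pose frame per generated video frame

generator = torch.manual_seed(42)
output = pipe(  # `pipe` is a pre-built Pose2VideoPipeline (assumption)
    ref_image,
    pose_images,
    width=512,
    height=784,
    video_length=len(pose_images),
    num_inference_steps=25,
    guidance_scale=3.5,
    generator=generator,
)
video = output.videos  # (b, c, f, h, w) tensor

Note that a guidance_scale of 1.0 or lower disables the classifier-free guidance branch entirely, since do_classifier_free_guidance is defined as guidance_scale > 1.0.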
-------------------------------------------------------------------------------- /src/pipelines/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | tensor_interpolation = None 4 | 5 | 6 | def get_tensor_interpolation_method(): 7 | return tensor_interpolation 8 | 9 | 10 | def set_tensor_interpolation_method(is_slerp): 11 | global tensor_interpolation 12 | tensor_interpolation = slerp if is_slerp else linear 13 | 14 | 15 | def linear(v1, v2, t): 16 | return (1.0 - t) * v1 + t * v2 17 | 18 | 19 | def slerp( 20 | v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995 21 | ) -> torch.Tensor: 22 | u0 = v0 / v0.norm() 23 | u1 = v1 / v1.norm() 24 | dot = (u0 * u1).sum() 25 | if dot.abs() > DOT_THRESHOLD: 26 | # v0 and v1 are close to parallel, so fall back to linear interpolation 27 | return (1.0 - t) * v0 + t * v1 28 | omega = dot.acos() 29 | return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin() 30 | -------------------------------------------------------------------------------- /src/utils/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | import os.path as osp 4 | import shutil 5 | import sys 6 | from pathlib import Path 7 | 8 | import av 9 | import numpy as np 10 | import torch 11 | import torchvision 12 | from einops import rearrange 13 | from PIL import Image 14 | 15 | 16 | def seed_everything(seed): 17 | import random 18 | 19 | import numpy as np 20 | 21 | torch.manual_seed(seed) 22 | torch.cuda.manual_seed_all(seed) 23 | np.random.seed(seed % (2**32)) 24 | random.seed(seed) 25 | 26 | 27 | def import_filename(filename): 28 | spec = importlib.util.spec_from_file_location("mymodule", filename) 29 | module = importlib.util.module_from_spec(spec) 30 | sys.modules[spec.name] = module 31 | spec.loader.exec_module(module) 32 | return module 33 | 34 | 35 | def delete_additional_ckpt(base_path, num_keep): 36 | dirs = [] 37 | for d in os.listdir(base_path): 38 | if d.startswith("checkpoint-"): 39 | dirs.append(d) 40 | num_tot = len(dirs) 41 | if num_tot <= num_keep: 42 | return 43 | # ensure checkpoints are sorted by step and delete the earliest ones 44 | del_dirs = sorted(dirs, key=lambda x: int(x.split("-")[-1]))[: num_tot - num_keep] 45 | for d in del_dirs: 46 | path_to_dir = osp.join(base_path, d) 47 | if osp.exists(path_to_dir): 48 | shutil.rmtree(path_to_dir) 49 | 50 | 51 | def save_videos_from_pil(pil_images, path, fps=8): 52 | import av 53 | 54 | save_fmt = Path(path).suffix 55 | os.makedirs(os.path.dirname(path), exist_ok=True) 56 | width, height = pil_images[0].size 57 | 58 | if save_fmt == ".mp4": 59 | codec = "libx264" 60 | container = av.open(path, "w") 61 | stream = container.add_stream(codec, rate=fps) 62 | 63 | stream.width = width 64 | stream.height = height 65 | 66 | for pil_image in pil_images: 67 | # pil_image = Image.fromarray(image_arr).convert("RGB") 68 | av_frame = av.VideoFrame.from_image(pil_image) 69 | container.mux(stream.encode(av_frame)) 70 | container.mux(stream.encode()) 71 | container.close() 72 | 73 | elif save_fmt == ".gif": 74 | pil_images[0].save( 75 | fp=path, 76 | format="GIF", 77 | append_images=pil_images[1:], 78 | save_all=True, 79 | duration=(1 / fps * 1000), 80 | loop=0, 81 | ) 82 | else: 83 | raise ValueError("Unsupported file type. 
Use .mp4 or .gif.") 84 | 85 | 86 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8): 87 | videos = rearrange(videos, "b c t h w -> t b c h w") 88 | height, width = videos.shape[-2:] 89 | outputs = [] 90 | 91 | for x in videos: 92 | x = torchvision.utils.make_grid(x, nrow=n_rows) # (c h w) 93 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) # (h w c) 94 | if rescale: 95 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 96 | x = (x * 255).numpy().astype(np.uint8) 97 | x = Image.fromarray(x) 98 | 99 | outputs.append(x) 100 | 101 | os.makedirs(os.path.dirname(path), exist_ok=True) 102 | 103 | save_videos_from_pil(outputs, path, fps) 104 | 105 | 106 | def read_frames(video_path): 107 | container = av.open(video_path) 108 | 109 | video_stream = next(s for s in container.streams if s.type == "video") 110 | frames = [] 111 | for packet in container.demux(video_stream): 112 | for frame in packet.decode(): 113 | image = Image.frombytes( 114 | "RGB", 115 | (frame.width, frame.height), 116 | frame.to_rgb().to_ndarray(), 117 | ) 118 | frames.append(image) 119 | 120 | return frames 121 | 122 | 123 | def get_fps(video_path): 124 | container = av.open(video_path) 125 | video_stream = next(s for s in container.streams if s.type == "video") 126 | fps = video_stream.average_rate 127 | container.close() 128 | return fps 129 | -------------------------------------------------------------------------------- /tools/download_weights.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path, PurePosixPath 3 | 4 | from huggingface_hub import hf_hub_download 5 | 6 | 7 | def prepare_base_model(): 8 | print(f'Preparing base stable-diffusion-v1-5 weights...') 9 | local_dir = "./pretrained_weights/stable-diffusion-v1-5" 10 | os.makedirs(local_dir, exist_ok=True) 11 | for hub_file in ["unet/config.json", "unet/diffusion_pytorch_model.bin"]: 12 | path = Path(hub_file) 13 | saved_path = local_dir / path 14 | if os.path.exists(saved_path): 15 | continue 16 | hf_hub_download( 17 | repo_id="runwayml/stable-diffusion-v1-5", 18 | subfolder=PurePosixPath(path.parent), 19 | filename=PurePosixPath(path.name), 20 | local_dir=local_dir, 21 | ) 22 | 23 | 24 | def prepare_image_encoder(): 25 | print(f"Preparing image encoder weights...") 26 | local_dir = "./pretrained_weights" 27 | os.makedirs(local_dir, exist_ok=True) 28 | for hub_file in ["image_encoder/config.json", "image_encoder/pytorch_model.bin"]: 29 | path = Path(hub_file) 30 | saved_path = local_dir / path 31 | if os.path.exists(saved_path): 32 | continue 33 | hf_hub_download( 34 | repo_id="lambdalabs/sd-image-variations-diffusers", 35 | subfolder=PurePosixPath(path.parent), 36 | filename=PurePosixPath(path.name), 37 | local_dir=local_dir, 38 | ) 39 | 40 | 41 | def prepare_dwpose(): 42 | print(f"Preparing DWPose weights...") 43 | local_dir = "./pretrained_weights/DWPose" 44 | os.makedirs(local_dir, exist_ok=True) 45 | for hub_file in [ 46 | "dw-ll_ucoco_384.onnx", 47 | "yolox_l.onnx", 48 | ]: 49 | path = Path(hub_file) 50 | saved_path = local_dir / path 51 | if os.path.exists(saved_path): 52 | continue 53 | 54 | hf_hub_download( 55 | repo_id="yzd-v/DWPose", 56 | subfolder=PurePosixPath(path.parent), 57 | filename=PurePosixPath(path.name), 58 | local_dir=local_dir, 59 | ) 60 | 61 | 62 | def prepare_vae(): 63 | print(f"Preparing vae weights...") 64 | local_dir = "./pretrained_weights/sd-vae-ft-mse" 65 | os.makedirs(local_dir, exist_ok=True) 66 | for hub_file in [ 67 | "config.json", 
68 | "diffusion_pytorch_model.bin", 69 | ]: 70 | path = Path(hub_file) 71 | saved_path = local_dir / path 72 | if os.path.exists(saved_path): 73 | continue 74 | 75 | hf_hub_download( 76 | repo_id="stabilityai/sd-vae-ft-mse", 77 | subfolder=PurePosixPath(path.parent), 78 | filename=PurePosixPath(path.name), 79 | local_dir=local_dir, 80 | ) 81 | 82 | 83 | def prepare_anyone(): 84 | print(f"Preparing AnimateAnyone weights...") 85 | local_dir = "./pretrained_weights" 86 | os.makedirs(local_dir, exist_ok=True) 87 | for hub_file in [ 88 | "denoising_unet.pth", 89 | "motion_module.pth", 90 | "pose_guider.pth", 91 | "reference_unet.pth", 92 | ]: 93 | path = Path(hub_file) 94 | saved_path = local_dir / path 95 | if os.path.exists(saved_path): 96 | continue 97 | 98 | hf_hub_download( 99 | repo_id="patrolli/AnimateAnyone", 100 | subfolder=PurePosixPath(path.parent), 101 | filename=PurePosixPath(path.name), 102 | local_dir=local_dir, 103 | ) 104 | 105 | if __name__ == '__main__': 106 | prepare_base_model() 107 | prepare_image_encoder() 108 | prepare_dwpose() 109 | prepare_vae() 110 | prepare_anyone() 111 | -------------------------------------------------------------------------------- /tools/extract_dwpose_from_vid.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures 2 | import os 3 | import random 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | 8 | from src.dwpose import DWposeDetector 9 | from src.utils.util import get_fps, read_frames, save_videos_from_pil 10 | 11 | # Extract dwpose mp4 videos from raw videos 12 | # /path/to/video_dataset/*/*.mp4 -> /path/to/video_dataset_dwpose/*/*.mp4 13 | 14 | 15 | def process_single_video(video_path, detector, root_dir, save_dir): 16 | relative_path = os.path.relpath(video_path, root_dir) 17 | print(relative_path, video_path, root_dir) 18 | out_path = os.path.join(save_dir, relative_path) 19 | if os.path.exists(out_path): 20 | return 21 | 22 | output_dir = Path(os.path.dirname(os.path.join(save_dir, relative_path))) 23 | if not output_dir.exists(): 24 | output_dir.mkdir(parents=True, exist_ok=True) 25 | 26 | fps = get_fps(video_path) 27 | frames = read_frames(video_path) 28 | kps_results = [] 29 | for i, frame_pil in enumerate(frames): 30 | result, score = detector(frame_pil) 31 | score = np.mean(score, axis=-1) 32 | 33 | kps_results.append(result) 34 | 35 | save_videos_from_pil(kps_results, out_path, fps=fps) 36 | 37 | 38 | def process_batch_videos(video_list, detector, root_dir, save_dir): 39 | for i, video_path in enumerate(video_list): 40 | print(f"Process {i}/{len(video_list)} video") 41 | process_single_video(video_path, detector, root_dir, save_dir) 42 | 43 | 44 | if __name__ == "__main__": 45 | # ----- 46 | # NOTE: 47 | # python tools/extract_dwpose_from_vid.py --video_root /path/to/video_dir 48 | # ----- 49 | import argparse 50 | 51 | parser = argparse.ArgumentParser() 52 | parser.add_argument("--video_root", type=str) 53 | parser.add_argument( 54 | "--save_dir", type=str, help="Path to save extracted pose videos" 55 | ) 56 | parser.add_argument("-j", type=int, default=4, help="Num workers") 57 | args = parser.parse_args() 58 | num_workers = args.j 59 | if args.save_dir is None: 60 | save_dir = args.video_root + "_dwpose" 61 | else: 62 | save_dir = args.save_dir 63 | if not os.path.exists(save_dir): 64 | os.makedirs(save_dir) 65 | cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0") 66 | gpu_ids = [int(id) for id in range(len(cuda_visible_devices.split(",")))] 67 | 
print(f"available gpu ids: {gpu_ids}") 68 | 69 | # collect all video_folder paths 70 | video_mp4_paths = set() 71 | for root, dirs, files in os.walk(args.video_root): 72 | for name in files: 73 | if name.endswith(".mp4"): 74 | video_mp4_paths.add(os.path.join(root, name)) 75 | video_mp4_paths = list(video_mp4_paths) 76 | random.shuffle(video_mp4_paths) 77 | 78 | # split into one chunk per worker 79 | batch_size = (len(video_mp4_paths) + num_workers - 1) // num_workers 80 | print(f"Num videos: {len(video_mp4_paths)} {batch_size = }") 81 | video_chunks = [ 82 | video_mp4_paths[i : i + batch_size] 83 | for i in range(0, len(video_mp4_paths), batch_size) 84 | ] 85 | 86 | with concurrent.futures.ThreadPoolExecutor() as executor: 87 | futures = [] 88 | for i, chunk in enumerate(video_chunks): 89 | # init detector 90 | gpu_id = gpu_ids[i % len(gpu_ids)] 91 | detector = DWposeDetector() 92 | # torch.cuda.set_device(gpu_id) 93 | detector = detector.to(f"cuda:{gpu_id}") 94 | 95 | futures.append( 96 | executor.submit( 97 | process_batch_videos, chunk, detector, args.video_root, save_dir 98 | ) 99 | ) 100 | for future in concurrent.futures.as_completed(futures): 101 | future.result() 102 | -------------------------------------------------------------------------------- /tools/extract_meta_info.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | # ----- 6 | # [{'vid': , 'kps': , 'other':}, 7 | # {'vid': , 'kps': , 'other':}] 8 | # ----- 9 | # python tools/extract_meta_info.py --root_path /path/to/video_dir --dataset_name fashion 10 | # ----- 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--root_path", type=str) 13 | parser.add_argument("--dataset_name", type=str) 14 | parser.add_argument("--meta_info_name", type=str) 15 | 16 | args = parser.parse_args() 17 | 18 | if args.meta_info_name is None: 19 | args.meta_info_name = args.dataset_name 20 | 21 | pose_dir = args.root_path + "_dwpose" 22 | 23 | # collect all video_folder paths 24 | video_mp4_paths = set() 25 | for root, dirs, files in os.walk(args.root_path): 26 | for name in files: 27 | if name.endswith(".mp4"): 28 | video_mp4_paths.add(os.path.join(root, name)) 29 | video_mp4_paths = list(video_mp4_paths) 30 | 31 | meta_infos = [] 32 | for video_mp4_path in video_mp4_paths: 33 | relative_video_name = os.path.relpath(video_mp4_path, args.root_path) 34 | kps_path = os.path.join(pose_dir, relative_video_name) 35 | meta_infos.append({"video_path": video_mp4_path, "kps_path": kps_path}) 36 | 37 | json.dump(meta_infos, open(f"./data/{args.meta_info_name}_meta.json", "w")) 38 | -------------------------------------------------------------------------------- /tools/facetracker_api.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os, sys 3 | import math 4 | import numpy as np 5 | import cv2 6 | sys.path.append("OpenSeeFace/") 7 | from tracker import Tracker, get_model_base_path 8 | 9 | features = ["eye_l", "eye_r", "eyebrow_steepness_l", "eyebrow_updown_l", "eyebrow_quirk_l", "eyebrow_steepness_r", "eyebrow_updown_r", "eyebrow_quirk_r", "mouth_corner_updown_l", "mouth_corner_inout_l", "mouth_corner_updown_r", "mouth_corner_inout_r", "mouth_open", "mouth_wide"] 10 | 11 | 12 | def face_image(frame, save_path=None): 13 | height, width, c = frame.shape 14 | tracker = Tracker(width, height, threshold=None, max_threads=1, max_faces=1, discard_after=10, scan_every=3, silent=False, model_type=3, 
model_dir=None, 15 | no_gaze=False, detection_threshold=0.4, use_retinaface=0, max_feature_updates=900, static_model=True, try_hard=False) 16 | faces = tracker.predict(frame) 17 | frame = np.zeros_like(frame) 18 | detected = False 19 | face_lms = None 20 | for face_num, f in enumerate(faces): 21 | f = copy.copy(f) 22 | if f.eye_blink is None: 23 | f.eye_blink = [1, 1] 24 | right_state = "O" if f.eye_blink[0] > 0.30 else "-" 25 | left_state = "O" if f.eye_blink[1] > 0.30 else "-" 26 | detected = True 27 | if not f.success: 28 | pts_3d = np.zeros((70, 3), np.float32) 29 | if face_num == 0: 30 | face_lms = f.lms 31 | for pt_num, (x,y,c) in enumerate(f.lms): 32 | if pt_num == 66 and (f.eye_blink[0] < 0.30 or c < 0.20): 33 | continue 34 | if pt_num == 67 and (f.eye_blink[1] < 0.30 or c < 0.20): 35 | continue 36 | x = int(x + 0.5) 37 | y = int(y + 0.5) 38 | 39 | color = (0, 255, 0) 40 | if pt_num >= 66: 41 | color = (255, 255, 0) 42 | if not (x < 0 or y < 0 or x >= height or y >= width): 43 | cv2.circle(frame, (y, x), 1, color, -1) 44 | if f.rotation is not None: 45 | projected = cv2.projectPoints(f.contour, f.rotation, f.translation, tracker.camera, tracker.dist_coeffs) 46 | for [(x,y)] in projected[0]: 47 | x = int(x + 0.5) 48 | y = int(y + 0.5) 49 | if not (x < 0 or y < 0 or x >= height or y >= width): 50 | frame[int(x), int(y)] = (0, 255, 255) 51 | x += 1 52 | if not (x < 0 or y < 0 or x >= height or y >= width): 53 | frame[int(x), int(y)] = (0, 255, 255) 54 | y += 1 55 | if not (x < 0 or y < 0 or x >= height or y >= width): 56 | frame[int(x), int(y)] = (0, 255, 255) 57 | x -= 1 58 | if not (x < 0 or y < 0 or x >= height or y >= width): 59 | frame[int(x), int(y)] = (0, 255, 255) 60 | if save_path is not None: 61 | cv2.imwrite(save_path, frame) 62 | return frame, face_lms 63 | -------------------------------------------------------------------------------- /tools/vid2pose.py: -------------------------------------------------------------------------------- 1 | from src.dwpose import DWposeDetector 2 | import os 3 | from pathlib import Path 4 | 5 | from src.utils.util import get_fps, read_frames, save_videos_from_pil 6 | import numpy as np 7 | 8 | 9 | if __name__ == "__main__": 10 | import argparse 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--video_path", type=str) 14 | args = parser.parse_args() 15 | 16 | if not os.path.exists(args.video_path): 17 | raise ValueError(f"Path: {args.video_path} does not exist") 18 | 19 | dir_path, video_name = ( 20 | os.path.dirname(args.video_path), 21 | os.path.splitext(os.path.basename(args.video_path))[0], 22 | ) 23 | out_path = os.path.join(dir_path, video_name + "_kps.mp4") 24 | 25 | detector = DWposeDetector() 26 | detector = detector.to("cuda") 27 | 28 | fps = get_fps(args.video_path) 29 | frames = read_frames(args.video_path) 30 | kps_results = [] 31 | for i, frame_pil in enumerate(frames): 32 | result, score = detector(frame_pil) 33 | score = np.mean(score, axis=-1) 34 | 35 | kps_results.append(result) 36 | 37 | print(out_path) 38 | save_videos_from_pil(kps_results, out_path, fps=fps) 39 | --------------------------------------------------------------------------------
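For training-data preparation the tools above are typically chained: tools/download_weights.py fetches the DWPose ONNX weights into ./pretrained_weights/DWPose, tools/extract_dwpose_from_vid.py writes a `<video_root>_dwpose` mirror of every clip, and tools/extract_meta_info.py then pairs each video with its pose video in a meta-info JSON under ./data. A minimal single-frame sanity check of the same DWPose utilities (a sketch only; the clip path is a placeholder and the weights are assumed to be downloaded already):

import numpy as np

from src.dwpose import DWposeDetector
from src.utils.util import get_fps, read_frames

detector = DWposeDetector()
detector = detector.to("cuda")

video_path = "some_clip.mp4"  # placeholder path (assumption)
frames = read_frames(video_path)  # list of PIL.Image frames
fps = get_fps(video_path)

result, score = detector(frames[0])  # keypoint rendering (PIL image) + per-joint scores
print(f"{len(frames)} frames at {fps} fps, mean keypoint score {np.mean(score):.3f}")
result.save("some_clip_frame0_kps.png")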