├── .gitignore
├── LICENSE
├── NOTICE
├── README.md
├── app.py
├── assets
│   └── mini_program_maliang.png
├── configs
│   ├── inference
│   │   ├── inference_v1.yaml
│   │   ├── inference_v2.yaml
│   │   ├── pose_images
│   │   │   └── pose-1.png
│   │   ├── pose_videos
│   │   │   ├── anyone-video-1_kps.mp4
│   │   │   ├── anyone-video-2_kps.mp4
│   │   │   ├── anyone-video-4_kps.mp4
│   │   │   └── anyone-video-5_kps.mp4
│   │   ├── ref_images
│   │   │   ├── anyone-1.png
│   │   │   ├── anyone-10.png
│   │   │   ├── anyone-11.png
│   │   │   ├── anyone-2.png
│   │   │   ├── anyone-3.png
│   │   │   └── anyone-5.png
│   │   ├── talkinghead_images
│   │   │   ├── 1.png
│   │   │   ├── 2.png
│   │   │   ├── 3.png
│   │   │   ├── 4.png
│   │   │   └── 5.png
│   │   └── talkinghead_videos
│   │       ├── 1.mp4
│   │       ├── 2.mp4
│   │       ├── 3.mp4
│   │       └── 4.mp4
│   ├── prompts
│   │   ├── animation.yaml
│   │   ├── inference_reenact.yaml
│   │   └── test_cases.py
│   └── train
│       ├── stage1.yaml
│       └── stage2.yaml
├── requirements.txt
├── scripts
│   ├── lmks2vid.py
│   └── pose2vid.py
├── src
│   ├── __init__.py
│   ├── dataset
│   │   ├── dance_image.py
│   │   └── dance_video.py
│   ├── dwpose
│   │   ├── __init__.py
│   │   ├── onnxdet.py
│   │   ├── onnxpose.py
│   │   ├── util.py
│   │   └── wholebody.py
│   ├── models
│   │   ├── attention.py
│   │   ├── motion_module.py
│   │   ├── mutual_self_attention.py
│   │   ├── pose_guider.py
│   │   ├── resnet.py
│   │   ├── transformer_2d.py
│   │   ├── transformer_3d.py
│   │   ├── unet_2d_blocks.py
│   │   ├── unet_2d_condition.py
│   │   ├── unet_3d.py
│   │   └── unet_3d_blocks.py
│   ├── pipelines
│   │   ├── __init__.py
│   │   ├── context.py
│   │   ├── pipeline_lmks2vid_long.py
│   │   ├── pipeline_pose2img.py
│   │   ├── pipeline_pose2vid.py
│   │   ├── pipeline_pose2vid_long.py
│   │   └── utils.py
│   └── utils
│       └── util.py
├── tools
│   ├── download_weights.py
│   ├── extract_dwpose_from_vid.py
│   ├── extract_meta_info.py
│   ├── facetracker_api.py
│   └── vid2pose.py
├── train_stage_1.py
└── train_stage_2.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | pretrained_weights/
3 | output/
4 | .venv/
5 | mlruns/
6 | data/
7 |
8 | *.pth
9 | *.pt
10 | *.pkl
11 | *.bin
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright @2023-2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
2 |
3 | Apache License
4 | Version 2.0, January 2004
5 | http://www.apache.org/licenses/
6 |
7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
8 |
9 | 1. Definitions.
10 |
11 | "License" shall mean the terms and conditions for use, reproduction,
12 | and distribution as defined by Sections 1 through 9 of this document.
13 |
14 | "Licensor" shall mean the copyright owner or entity authorized by
15 | the copyright owner that is granting the License.
16 |
17 | "Legal Entity" shall mean the union of the acting entity and all
18 | other entities that control, are controlled by, or are under common
19 | control with that entity. For the purposes of this definition,
20 | "control" means (i) the power, direct or indirect, to cause the
21 | direction or management of such entity, whether by contract or
22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
23 | outstanding shares, or (iii) beneficial ownership of such entity.
24 |
25 | "You" (or "Your") shall mean an individual or Legal Entity
26 | exercising permissions granted by this License.
27 |
28 | "Source" form shall mean the preferred form for making modifications,
29 | including but not limited to software source code, documentation
30 | source, and configuration files.
31 |
32 | "Object" form shall mean any form resulting from mechanical
33 | transformation or translation of a Source form, including but
34 | not limited to compiled object code, generated documentation,
35 | and conversions to other media types.
36 |
37 | "Work" shall mean the work of authorship, whether in Source or
38 | Object form, made available under the License, as indicated by a
39 | copyright notice that is included in or attached to the work
40 | (an example is provided in the Appendix below).
41 |
42 | "Derivative Works" shall mean any work, whether in Source or Object
43 | form, that is based on (or derived from) the Work and for which the
44 | editorial revisions, annotations, elaborations, or other modifications
45 | represent, as a whole, an original work of authorship. For the purposes
46 | of this License, Derivative Works shall not include works that remain
47 | separable from, or merely link (or bind by name) to the interfaces of,
48 | the Work and Derivative Works thereof.
49 |
50 | "Contribution" shall mean any work of authorship, including
51 | the original version of the Work and any modifications or additions
52 | to that Work or Derivative Works thereof, that is intentionally
53 | submitted to Licensor for inclusion in the Work by the copyright owner
54 | or by an individual or Legal Entity authorized to submit on behalf of
55 | the copyright owner. For the purposes of this definition, "submitted"
56 | means any form of electronic, verbal, or written communication sent
57 | to the Licensor or its representatives, including but not limited to
58 | communication on electronic mailing lists, source code control systems,
59 | and issue tracking systems that are managed by, or on behalf of, the
60 | Licensor for the purpose of discussing and improving the Work, but
61 | excluding communication that is conspicuously marked or otherwise
62 | designated in writing by the copyright owner as "Not a Contribution."
63 |
64 | "Contributor" shall mean Licensor and any individual or Legal Entity
65 | on behalf of whom a Contribution has been received by Licensor and
66 | subsequently incorporated within the Work.
67 |
68 | 2. Grant of Copyright License. Subject to the terms and conditions of
69 | this License, each Contributor hereby grants to You a perpetual,
70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
71 | copyright license to reproduce, prepare Derivative Works of,
72 | publicly display, publicly perform, sublicense, and distribute the
73 | Work and such Derivative Works in Source or Object form.
74 |
75 | 3. Grant of Patent License. Subject to the terms and conditions of
76 | this License, each Contributor hereby grants to You a perpetual,
77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
78 | (except as stated in this section) patent license to make, have made,
79 | use, offer to sell, sell, import, and otherwise transfer the Work,
80 | where such license applies only to those patent claims licensable
81 | by such Contributor that are necessarily infringed by their
82 | Contribution(s) alone or by combination of their Contribution(s)
83 | with the Work to which such Contribution(s) was submitted. If You
84 | institute patent litigation against any entity (including a
85 | cross-claim or counterclaim in a lawsuit) alleging that the Work
86 | or a Contribution incorporated within the Work constitutes direct
87 | or contributory patent infringement, then any patent licenses
88 | granted to You under this License for that Work shall terminate
89 | as of the date such litigation is filed.
90 |
91 | 4. Redistribution. You may reproduce and distribute copies of the
92 | Work or Derivative Works thereof in any medium, with or without
93 | modifications, and in Source or Object form, provided that You
94 | meet the following conditions:
95 |
96 | (a) You must give any other recipients of the Work or
97 | Derivative Works a copy of this License; and
98 |
99 | (b) You must cause any modified files to carry prominent notices
100 | stating that You changed the files; and
101 |
102 | (c) You must retain, in the Source form of any Derivative Works
103 | that You distribute, all copyright, patent, trademark, and
104 | attribution notices from the Source form of the Work,
105 | excluding those notices that do not pertain to any part of
106 | the Derivative Works; and
107 |
108 | (d) If the Work includes a "NOTICE" text file as part of its
109 | distribution, then any Derivative Works that You distribute must
110 | include a readable copy of the attribution notices contained
111 | within such NOTICE file, excluding those notices that do not
112 | pertain to any part of the Derivative Works, in at least one
113 | of the following places: within a NOTICE text file distributed
114 | as part of the Derivative Works; within the Source form or
115 | documentation, if provided along with the Derivative Works; or,
116 | within a display generated by the Derivative Works, if and
117 | wherever such third-party notices normally appear. The contents
118 | of the NOTICE file are for informational purposes only and
119 | do not modify the License. You may add Your own attribution
120 | notices within Derivative Works that You distribute, alongside
121 | or as an addendum to the NOTICE text from the Work, provided
122 | that such additional attribution notices cannot be construed
123 | as modifying the License.
124 |
125 | You may add Your own copyright statement to Your modifications and
126 | may provide additional or different license terms and conditions
127 | for use, reproduction, or distribution of Your modifications, or
128 | for any such Derivative Works as a whole, provided Your use,
129 | reproduction, and distribution of the Work otherwise complies with
130 | the conditions stated in this License.
131 |
132 | 5. Submission of Contributions. Unless You explicitly state otherwise,
133 | any Contribution intentionally submitted for inclusion in the Work
134 | by You to the Licensor shall be under the terms and conditions of
135 | this License, without any additional terms or conditions.
136 | Notwithstanding the above, nothing herein shall supersede or modify
137 | the terms of any separate license agreement you may have executed
138 | with Licensor regarding such Contributions.
139 |
140 | 6. Trademarks. This License does not grant permission to use the trade
141 | names, trademarks, service marks, or product names of the Licensor,
142 | except as required for reasonable and customary use in describing the
143 | origin of the Work and reproducing the content of the NOTICE file.
144 |
145 | 7. Disclaimer of Warranty. Unless required by applicable law or
146 | agreed to in writing, Licensor provides the Work (and each
147 | Contributor provides its Contributions) on an "AS IS" BASIS,
148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
149 | implied, including, without limitation, any warranties or conditions
150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
151 | PARTICULAR PURPOSE. You are solely responsible for determining the
152 | appropriateness of using or redistributing the Work and assume any
153 | risks associated with Your exercise of permissions under this License.
154 |
155 | 8. Limitation of Liability. In no event and under no legal theory,
156 | whether in tort (including negligence), contract, or otherwise,
157 | unless required by applicable law (such as deliberate and grossly
158 | negligent acts) or agreed to in writing, shall any Contributor be
159 | liable to You for damages, including any direct, indirect, special,
160 | incidental, or consequential damages of any character arising as a
161 | result of this License or out of the use or inability to use the
162 | Work (including but not limited to damages for loss of goodwill,
163 | work stoppage, computer failure or malfunction, or any and all
164 | other commercial damages or losses), even if such Contributor
165 | has been advised of the possibility of such damages.
166 |
167 | 9. Accepting Warranty or Additional Liability. While redistributing
168 | the Work or Derivative Works thereof, You may choose to offer,
169 | and charge a fee for, acceptance of support, warranty, indemnity,
170 | or other liability obligations and/or rights consistent with this
171 | License. However, in accepting such obligations, You may act only
172 | on Your own behalf and on Your sole responsibility, not on behalf
173 | of any other Contributor, and only if You agree to indemnify,
174 | defend, and hold each Contributor harmless for any liability
175 | incurred by, or claims asserted against, such Contributor by reason
176 | of your accepting any such warranty or additional liability.
177 |
178 | END OF TERMS AND CONDITIONS
179 |
180 | APPENDIX: How to apply the Apache License to your work.
181 |
182 | To apply the Apache License to your work, attach the following
183 | boilerplate notice, with the fields enclosed by brackets "[]"
184 | replaced with your own identifying information. (Don't include
185 | the brackets!) The text should be enclosed in the appropriate
186 | comment syntax for the file format. We also recommend that a
187 | file or class name and description of purpose be included on the
188 | same "printed page" as the copyright notice for easier
189 | identification within third-party archives.
190 |
191 | Copyright [yyyy] [name of copyright owner]
192 |
193 | Licensed under the Apache License, Version 2.0 (the "License");
194 | you may not use this file except in compliance with the License.
195 | You may obtain a copy of the License at
196 |
197 | http://www.apache.org/licenses/LICENSE-2.0
198 |
199 | Unless required by applicable law or agreed to in writing, software
200 | distributed under the License is distributed on an "AS IS" BASIS,
201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
202 | See the License for the specific language governing permissions and
203 | limitations under the License.
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | ==============================================================
2 | This repo also contains various third-party components and some code modified from other repos under other open source licenses. The following sections contain licensing information for such third-party libraries.
3 |
4 | -----------------------------
5 | majic-animate
6 | BSD 3-Clause License
7 | Copyright (c) Bytedance Inc.
8 |
9 | -----------------------------
10 | animatediff
11 | Apache License, Version 2.0
12 |
13 | -----------------------------
14 | Dwpose
15 | Apache License, Version 2.0
16 |
17 | -----------------------------
18 | inference pipeline for animatediff-cli-prompt-travel
19 | animatediff-cli-prompt-travel
20 | Apache License, Version 2.0
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🤗 Introduction
2 | **update** 🔥🔥🔥 We propose a face reenactment method based on our AnimateAnyone pipeline: the facial landmarks of a driving video are used to control the pose of a given source image while preserving its identity. Specifically, we disentangle head pose (including eye blinks) and mouth motion from the driving video's landmarks, so the expression and movements of the source face can be controlled precisely. We release our inference code and pretrained models for face reenactment!!
3 |
4 |
5 | **update** 🏋️🏋️🏋️ We release our training code!! Now you can train your own AnimateAnyone models. See [here](#train) for more details. Have fun!
6 |
7 | **update** 🔥🔥🔥 We have launched a HuggingFace Spaces demo of Moore-AnimateAnyone [here](https://huggingface.co/spaces/xunsong/Moore-AnimateAnyone)!!
8 |
9 | This repository reproduces [AnimateAnyone](https://github.com/HumanAIGC/AnimateAnyone). To align with the results demonstrated in the original paper, we adopt various approaches and tricks, which may differ somewhat from the paper and from another [implementation](https://github.com/guoqincode/Open-AnimateAnyone).
10 |
11 | It's worth noting that this is a very preliminary version, aiming to approximate the performance (roughly 80% in our tests) shown in [AnimateAnyone](https://github.com/HumanAIGC/AnimateAnyone).
12 |
13 | We will continue to develop it, and we welcome feedback and ideas from the community. The enhanced version will also be launched on our [MoBi MaLiang](https://maliang.mthreads.com/) AIGC platform, running on our own full-featured GPU S4000 cloud computing platform.
14 |
15 | # 📝 Release Plans
16 |
17 | - [x] Inference codes and pretrained weights of AnimateAnyone
18 | - [x] Training scripts of AnimateAnyone
19 | - [x] Inference codes and pretrained weights of face reenactment
20 | - [ ] Training scripts of face reenactment
21 | - [ ] Inference scripts of audio driven portrait video generation
22 | - [ ] Training scripts of audio driven portrait video generation
23 | # 🎞️ Examples
24 |
25 | ## AnimateAnyone
26 |
27 | Here are some AnimateAnyone results we generated, at a resolution of 512x768.
28 |
29 | https://github.com/MooreThreads/Moore-AnimateAnyone/assets/138439222/f0454f30-6726-4ad4-80a7-5b7a15619057
30 |
31 | https://github.com/MooreThreads/Moore-AnimateAnyone/assets/138439222/337ff231-68a3-4760-a9f9-5113654acf48
32 |
54 | **Limitation**: We observe the following shortcomings in the current version:
55 | 1. Artifacts may appear in the background when the reference image has a clean background.
56 | 2. Suboptimal results may arise when there is a scale mismatch between the reference image and keypoints. We have yet to implement preprocessing techniques as mentioned in the [paper](https://arxiv.org/pdf/2311.17117.pdf).
57 | 3. Some flickering and jittering may occur when the motion sequence is subtle or the scene is static.
58 |
59 |
60 |
61 | These issues will be addressed and improved in the near future. We appreciate your patience!
62 |
63 | ## Face Reenactment
64 |
65 | Here are some results we generated, at a resolution of 512x512.
66 |
89 | # ⚒️ Installation
90 |
91 | ## Build Environment
92 |
93 | We recommend Python `>=3.10` and CUDA `11.7`. Then build the environment as follows:
94 |
95 | ```shell
96 | # [Optional] Create a virtual env
97 | python -m venv .venv
98 | source .venv/bin/activate
99 | # Install with pip:
100 | pip install -r requirements.txt
101 | # For face landmark extraction
102 | git clone https://github.com/emilianavt/OpenSeeFace.git
103 | ```
104 |
105 | ## Download weights
106 |
107 | **Automatic downloading**: You can run the following command to download the weights automatically:
108 |
109 | ```shell
110 | python tools/download_weights.py
111 | ```
112 |
113 | Weights will be placed under the `./pretrained_weights` directory. The whole downloading process may take a long time.
114 |
115 | **Manual downloading**: You can also download the weights manually, in the following steps:
116 |
117 | 1. Download our AnimateAnyone trained [weights](https://huggingface.co/patrolli/AnimateAnyone/tree/main), which include four parts: `denoising_unet.pth`, `reference_unet.pth`, `pose_guider.pth` and `motion_module.pth`.
118 |
119 | 2. Download our trained [weights](https://pan.baidu.com/s/1lS5CynyNfYlDbjowKKfG8g?pwd=crci) of face reenactment, and place these weights under `pretrained_weights`.
120 |
121 | 3. Download the pretrained weights of the base models and other components (a script sketch for this step follows the list):
122 | - [StableDiffusion V1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5)
123 | - [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse)
124 | - [image_encoder](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/tree/main/image_encoder)
125 |
126 | 4. Download the DWPose weights (`dw-ll_ucoco_384.onnx`, `yolox_l.onnx`) following [this](https://github.com/IDEA-Research/DWPose?tab=readme-ov-file#-dwpose-for-controlnet).
127 |
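As an alternative to clicking through the pages in step 3, the base components can be fetched with `huggingface_hub`. This is a minimal sketch, not the repository's `tools/download_weights.py`; the repo IDs are the ones linked above, and the target folders follow the layout shown below.

```python
# Sketch: download the step-3 base components into ./pretrained_weights.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="runwayml/stable-diffusion-v1-5",
    local_dir="./pretrained_weights/stable-diffusion-v1-5",
)
snapshot_download(
    repo_id="stabilityai/sd-vae-ft-mse",
    local_dir="./pretrained_weights/sd-vae-ft-mse",
)
snapshot_download(
    repo_id="lambdalabs/sd-image-variations-diffusers",
    allow_patterns=["image_encoder/*"],  # only the CLIP image encoder is needed
    local_dir="./pretrained_weights",
)
```
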
128 | Finally, these weights should be organized as follows:
129 |
130 | ```text
131 | ./pretrained_weights/
132 | |-- DWPose
133 | | |-- dw-ll_ucoco_384.onnx
134 | | `-- yolox_l.onnx
135 | |-- image_encoder
136 | | |-- config.json
137 | | `-- pytorch_model.bin
138 | |-- denoising_unet.pth
139 | |-- motion_module.pth
140 | |-- pose_guider.pth
141 | |-- reference_unet.pth
142 | |-- sd-vae-ft-mse
143 | | |-- config.json
144 | | |-- diffusion_pytorch_model.bin
145 | | `-- diffusion_pytorch_model.safetensors
146 | |-- reenact
147 | | |-- denoising_unet.pth
148 | | |-- reference_unet.pth
149 | | |-- pose_guider1.pth
150 | | |-- pose_guider2.pth
151 | `-- stable-diffusion-v1-5
152 | |-- feature_extractor
153 | | `-- preprocessor_config.json
154 | |-- model_index.json
155 | |-- unet
156 | | |-- config.json
157 | | `-- diffusion_pytorch_model.bin
158 | `-- v1-inference.yaml
159 | ```
160 |
161 | Note: If you have already installed some of the pretrained models, such as `StableDiffusion V1.5`, you can specify their paths in the config file (e.g. `./configs/prompts/animation.yaml`).
162 |
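For example, to reuse an existing local copy of Stable Diffusion V1.5, you could override the corresponding keys in `animation.yaml` (the keys are the ones used by that config; the first path below is a hypothetical local location):

```yaml
pretrained_base_model_path: "/path/to/your/stable-diffusion-v1-5"
pretrained_vae_path: "./pretrained_weights/sd-vae-ft-mse"
image_encoder_path: "./pretrained_weights/image_encoder"
```
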
163 | # 🚀 Training and Inference
164 |
165 | ## Inference of AnimateAnyone
166 |
167 | Here is the cli command for running inference scripts:
168 |
169 | ```shell
170 | python -m scripts.pose2vid --config ./configs/prompts/animation.yaml -W 512 -H 784 -L 64
171 | ```
172 |
173 | You can refer to the format of `animation.yaml` to add your own reference images or pose videos (an example entry is shown below). To convert a raw video into a pose video (keypoint sequence), you can run the following command:
174 |
175 | ```shell
176 | python tools/vid2pose.py --video_path /path/to/your/video.mp4
177 | ```
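
Each entry under `test_cases` in `animation.yaml` maps a reference image to a list of pose videos. A sketch of such an entry, reusing the bundled examples alongside hypothetical paths of your own:

```yaml
test_cases:
  "./configs/inference/ref_images/anyone-5.png":
    - "./configs/inference/pose_videos/anyone-video-2_kps.mp4"
  "/path/to/your/reference.png":
    - "/path/to/your/pose_video_kps.mp4"
```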
178 |
179 | ## Inference of Face Reenactment
180 | Here is the cli command for running inference scripts:
181 |
182 | ```shell
183 | python -m scripts.lmks2vid --config ./configs/prompts/inference_reenact.yaml --driving_video_path YOUR_OWN_DRIVING_VIDEO_PATH --source_image_path YOUR_OWN_SOURCE_IMAGE_PATH
184 | ```
185 | We provide some face images in `./configs/inference/talkinghead_images`, and some face videos in `./configs/inference/talkinghead_videos` for inference.
186 |
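For example, with the bundled assets and the script's optional flags (the values below match the argparse defaults in `scripts/lmks2vid.py`; this is just an illustrative invocation):

```shell
python -m scripts.lmks2vid --config ./configs/prompts/inference_reenact.yaml \
  --driving_video_path ./configs/inference/talkinghead_videos/1.mp4 \
  --source_image_path ./configs/inference/talkinghead_images/1.png \
  -W 512 -H 512 -L 24 --steps 30 --cfg 3.5
```
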
187 | ## Training of AnimateAnyone
188 |
189 | Note: package dependencies have been updated; you may need to upgrade your environment via `pip install -r requirements.txt` before training.
190 |
191 | ### Data Preparation
192 |
193 | Extract keypoints from raw videos:
194 |
195 | ```shell
196 | python tools/extract_dwpose_from_vid.py --video_root /path/to/your/video_dir
197 | ```
198 |
199 | Extract the meta info of the dataset:
200 |
201 | ```shell
202 | python tools/extract_meta_info.py --root_path /path/to/your/video_dir --dataset_name anyone
203 | ```
204 |
205 | Update the `meta_paths` entry in the training config file to point to the generated meta file:
206 |
207 | ```yaml
208 | data:
209 | meta_paths:
210 | - "./data/anyone_meta.json"
211 | ```
212 |
213 | ### Stage1
214 |
215 | Put the [openpose controlnet weights](https://huggingface.co/lllyasviel/control_v11p_sd15_openpose/tree/main) under `./pretrained_weights`; they are used to initialize the pose_guider.
216 |
217 | Put [sd-image-variations](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/tree/main) under `./pretrained_weights`; it is used to initialize the UNet weights.
218 |
219 | Run command:
220 |
221 | ```shell
222 | accelerate launch train_stage_1.py --config configs/train/stage1.yaml
223 | ```
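
For multi-GPU training, `accelerate` can be configured interactively via `accelerate config`, or the usual launch flags can be passed directly (a sketch; adjust the process count to your hardware):

```shell
accelerate launch --num_processes 4 --mixed_precision fp16 train_stage_1.py --config configs/train/stage1.yaml
```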
224 |
225 | ### Stage2
226 |
227 | Put the pretrained motion module weights `mm_sd_v15_v2.ckpt` ([download link](https://huggingface.co/guoyww/animatediff/blob/main/mm_sd_v15_v2.ckpt)) under `./pretrained_weights`.
228 |
229 | Specify the stage1 training weights in the config file `stage2.yaml`, for example:
230 |
231 | ```yaml
232 | stage1_ckpt_dir: './exp_output/stage1'
233 | stage1_ckpt_step: 30000
234 | ```
235 |
236 | Run command:
237 |
238 | ```shell
239 | accelerate launch train_stage_2.py --config configs/train/stage2.yaml
240 | ```
241 |
242 | # 🎨 Gradio Demo
243 |
244 | **HuggingFace Demo**: We have launched a quick preview demo of Moore-AnimateAnyone on [HuggingFace Spaces](https://huggingface.co/spaces/xunsong/Moore-AnimateAnyone)!!
245 | We appreciate the assistance provided by the HuggingFace team in setting up this demo.
246 |
247 | To reduce waiting time, we limit the size (width, height, and length) and inference steps when generating videos.
248 |
249 | If you have your own GPU resources (>= 16GB VRAM), you can run a local Gradio app with the following command:
250 |
251 | `python app.py`
252 |
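On a multi-GPU machine you can pin the demo to a single device with the standard CUDA environment variable (not specific to this repo):

```shell
CUDA_VISIBLE_DEVICES=0 python app.py
```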
253 | # Community Contributions
254 |
255 | - Installation for Windows users: [Moore-AnimateAnyone-for-windows](https://github.com/sdbds/Moore-AnimateAnyone-for-windows)
256 |
257 | # 🖌️ Try on Mobi MaLiang
258 |
259 | We will launch this model on our [MoBi MaLiang](https://maliang.mthreads.com/) AIGC platform, running on our own full-featured GPU S4000 cloud computing platform. Mobi MaLiang has now integrated various AIGC applications and functionalities (e.g. text-to-image, controllable generation...). You can experience it by [clicking this link](https://maliang.mthreads.com/) or scanning the QR code below via WeChat!
260 |
261 |
262 |
264 |
265 |
266 | # ⚖️ Disclaimer
267 |
268 | This project is intended for academic research, and we explicitly disclaim any responsibility for user-generated content. Users are solely liable for their actions while using the generative model. The project contributors have no legal affiliation with, nor accountability for, users' behaviors. It is imperative to use the generative model responsibly, adhering to both ethical and legal standards.
269 |
270 | # 🙏🏻 Acknowledgements
271 |
272 | We first thank the authors of [AnimateAnyone](). Additionally, we would like to thank the contributors to the [magic-animate](https://github.com/magic-research/magic-animate), [animatediff](https://github.com/guoyww/AnimateDiff) and [Open-AnimateAnyone](https://github.com/guoqincode/Open-AnimateAnyone) repositories, for their open research and exploration. Furthermore, our repo incorporates some code from [dwpose](https://github.com/IDEA-Research/DWPose) and [animatediff-cli-prompt-travel](https://github.com/s9roll7/animatediff-cli-prompt-travel/), and we extend our thanks to them as well.
273 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from datetime import datetime
4 |
5 | import gradio as gr
6 | import numpy as np
7 | import torch
8 | from diffusers import AutoencoderKL, DDIMScheduler
9 | from einops import repeat
10 | from omegaconf import OmegaConf
11 | from PIL import Image
12 | from torchvision import transforms
13 | from transformers import CLIPVisionModelWithProjection
14 |
15 | from src.models.pose_guider import PoseGuider
16 | from src.models.unet_2d_condition import UNet2DConditionModel
17 | from src.models.unet_3d import UNet3DConditionModel
18 | from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
19 | from src.utils.util import get_fps, read_frames, save_videos_grid
20 |
21 |
22 | class AnimateController:
23 | def __init__(
24 | self,
25 | config_path="./configs/prompts/animation.yaml",
26 | weight_dtype=torch.float16,
27 | ):
28 | # Read pretrained weights path from config
29 | self.config = OmegaConf.load(config_path)
30 | self.pipeline = None
31 | self.weight_dtype = weight_dtype
32 |
33 | def animate(
34 | self,
35 | ref_image,
36 | pose_video_path,
37 | width=512,
38 | height=768,
39 | length=24,
40 | num_inference_steps=25,
41 | cfg=3.5,
42 | seed=123,
43 | ):
44 | generator = torch.manual_seed(seed)
45 | if isinstance(ref_image, np.ndarray):
46 | ref_image = Image.fromarray(ref_image)
47 | if self.pipeline is None:
48 | vae = AutoencoderKL.from_pretrained(
49 | self.config.pretrained_vae_path,
50 | ).to("cuda", dtype=self.weight_dtype)
51 |
52 | reference_unet = UNet2DConditionModel.from_pretrained(
53 | self.config.pretrained_base_model_path,
54 | subfolder="unet",
55 | ).to(dtype=self.weight_dtype, device="cuda")
56 |
57 | inference_config_path = self.config.inference_config
58 | infer_config = OmegaConf.load(inference_config_path)
59 | denoising_unet = UNet3DConditionModel.from_pretrained_2d(
60 | self.config.pretrained_base_model_path,
61 | self.config.motion_module_path,
62 | subfolder="unet",
63 | unet_additional_kwargs=infer_config.unet_additional_kwargs,
64 | ).to(dtype=self.weight_dtype, device="cuda")
65 |
66 | pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(
67 | dtype=self.weight_dtype, device="cuda"
68 | )
69 |
70 | image_enc = CLIPVisionModelWithProjection.from_pretrained(
71 | self.config.image_encoder_path
72 | ).to(dtype=self.weight_dtype, device="cuda")
73 | sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
74 | scheduler = DDIMScheduler(**sched_kwargs)
75 |
76 | # load pretrained weights
77 | denoising_unet.load_state_dict(
78 | torch.load(self.config.denoising_unet_path, map_location="cpu"),
79 | strict=False,
80 | )
81 | reference_unet.load_state_dict(
82 | torch.load(self.config.reference_unet_path, map_location="cpu"),
83 | )
84 | pose_guider.load_state_dict(
85 | torch.load(self.config.pose_guider_path, map_location="cpu"),
86 | )
87 |
88 | pipe = Pose2VideoPipeline(
89 | vae=vae,
90 | image_encoder=image_enc,
91 | reference_unet=reference_unet,
92 | denoising_unet=denoising_unet,
93 | pose_guider=pose_guider,
94 | scheduler=scheduler,
95 | )
96 | pipe = pipe.to("cuda", dtype=self.weight_dtype)
97 | self.pipeline = pipe
98 |
99 | pose_images = read_frames(pose_video_path)
100 | src_fps = get_fps(pose_video_path)
101 |
102 | pose_list = []
103 | pose_tensor_list = []
104 | pose_transform = transforms.Compose(
105 | [transforms.Resize((height, width)), transforms.ToTensor()]
106 | )
107 | for pose_image_pil in pose_images[:length]:
108 | pose_list.append(pose_image_pil)
109 | pose_tensor_list.append(pose_transform(pose_image_pil))
110 |
111 | video = self.pipeline(
112 | ref_image,
113 | pose_list,
114 | width=width,
115 | height=height,
116 | video_length=length,
117 | num_inference_steps=num_inference_steps,
118 | guidance_scale=cfg,
119 | generator=generator,
120 | ).videos
121 |
122 | ref_image_tensor = pose_transform(ref_image) # (c, h, w)
123 | ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(0) # (1, c, 1, h, w)
124 | ref_image_tensor = repeat(
125 | ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=length
126 | )
127 | pose_tensor = torch.stack(pose_tensor_list, dim=0) # (f, c, h, w)
128 | pose_tensor = pose_tensor.transpose(0, 1)
129 | pose_tensor = pose_tensor.unsqueeze(0)
130 | video = torch.cat([ref_image_tensor, pose_tensor, video], dim=0)
131 |
132 | save_dir = f"./output/gradio"
133 | if not os.path.exists(save_dir):
134 | os.makedirs(save_dir, exist_ok=True)
135 | date_str = datetime.now().strftime("%Y%m%d")
136 | time_str = datetime.now().strftime("%H%M")
137 | out_path = os.path.join(save_dir, f"{date_str}T{time_str}.mp4")
138 | save_videos_grid(
139 | video,
140 | out_path,
141 | n_rows=3,
142 | fps=src_fps,
143 | )
144 |
145 | torch.cuda.empty_cache()
146 |
147 | return out_path
148 |
149 |
150 | controller = AnimateController()
151 |
152 |
153 | def ui():
154 | with gr.Blocks() as demo:
155 | gr.Markdown(
156 | """
157 | # Moore-AnimateAnyone Demo
158 | """
159 | )
160 | animation = gr.Video(
161 | format="mp4",
162 | label="Animation Results",
163 | height=448,
164 | autoplay=True,
165 | )
166 |
167 | with gr.Row():
168 | reference_image = gr.Image(label="Reference Image")
169 | motion_sequence = gr.Video(
170 | format="mp4", label="Motion Sequence", height=512
171 | )
172 |
173 | with gr.Column():
174 | width_slider = gr.Slider(
175 | label="Width", minimum=448, maximum=768, value=512, step=64
176 | )
177 | height_slider = gr.Slider(
178 | label="Height", minimum=512, maximum=1024, value=768, step=64
179 | )
180 | length_slider = gr.Slider(
181 | label="Video Length", minimum=24, maximum=128, value=24, step=24
182 | )
183 | with gr.Row():
184 | seed_textbox = gr.Textbox(label="Seed", value=-1)
185 | seed_button = gr.Button(
186 | value="\U0001F3B2", elem_classes="toolbutton"
187 | )
188 | seed_button.click(
189 | fn=lambda: gr.Textbox.update(value=random.randint(1, int(1e8))),
190 | inputs=[],
191 | outputs=[seed_textbox],
192 | )
193 | with gr.Row():
194 | sampling_steps = gr.Slider(
195 | label="Sampling steps",
196 | value=25,
197 | info="default: 25",
198 | step=5,
199 | maximum=30,
200 | minimum=10,
201 | )
202 | guidance_scale = gr.Slider(
203 | label="Guidance scale",
204 | value=3.5,
205 | info="default: 3.5",
206 | step=0.5,
207 | maximum=10,
208 | minimum=2.0,
209 | )
210 | submit = gr.Button("Animate")
211 |
212 | def read_video(video):
213 | return video
214 |
215 | def read_image(image):
216 | return Image.fromarray(image)
217 |
218 | # when user uploads a new video
219 | motion_sequence.upload(read_video, motion_sequence, motion_sequence)
220 | # when `first_frame` is updated
221 | reference_image.upload(read_image, reference_image, reference_image)
222 | # when the `submit` button is clicked
223 | submit.click(
224 | controller.animate,
225 | [
226 | reference_image,
227 | motion_sequence,
228 | width_slider,
229 | height_slider,
230 | length_slider,
231 | sampling_steps,
232 | guidance_scale,
233 | seed_textbox,
234 | ],
235 | animation,
236 | )
237 |
238 | # Examples
239 | gr.Markdown("## Examples")
240 | gr.Examples(
241 | examples=[
242 | [
243 | "./configs/inference/ref_images/anyone-5.png",
244 | "./configs/inference/pose_videos/anyone-video-2_kps.mp4",
245 | ],
246 | [
247 | "./configs/inference/ref_images/anyone-10.png",
248 | "./configs/inference/pose_videos/anyone-video-1_kps.mp4",
249 | ],
250 | [
251 | "./configs/inference/ref_images/anyone-2.png",
252 | "./configs/inference/pose_videos/anyone-video-5_kps.mp4",
253 | ],
254 | ],
255 | inputs=[reference_image, motion_sequence],
256 | outputs=animation,
257 | )
258 |
259 | return demo
260 |
261 |
262 | demo = ui()
263 | demo.launch(share=True)
264 |
--------------------------------------------------------------------------------
/assets/mini_program_maliang.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/assets/mini_program_maliang.png
--------------------------------------------------------------------------------
/configs/inference/inference_v1.yaml:
--------------------------------------------------------------------------------
1 | unet_additional_kwargs:
2 | unet_use_cross_frame_attention: false
3 | unet_use_temporal_attention: false
4 | use_motion_module: true
5 | motion_module_resolutions: [1,2,4,8]
6 | motion_module_mid_block: false
7 | motion_module_decoder_only: false
8 | motion_module_type: "Vanilla"
9 |
10 | motion_module_kwargs:
11 | num_attention_heads: 8
12 | num_transformer_block: 1
13 | attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
14 | temporal_position_encoding: true
15 | temporal_position_encoding_max_len: 24
16 | temporal_attention_dim_div: 1
17 |
18 | noise_scheduler_kwargs:
19 | beta_start: 0.00085
20 | beta_end: 0.012
21 | beta_schedule: "linear"
22 | steps_offset: 1
23 | clip_sample: False
--------------------------------------------------------------------------------
/configs/inference/inference_v2.yaml:
--------------------------------------------------------------------------------
1 | unet_additional_kwargs:
2 | use_inflated_groupnorm: true
3 | unet_use_cross_frame_attention: false
4 | unet_use_temporal_attention: false
5 | use_motion_module: true
6 | motion_module_resolutions:
7 | - 1
8 | - 2
9 | - 4
10 | - 8
11 | motion_module_mid_block: true
12 | motion_module_decoder_only: false
13 | motion_module_type: Vanilla
14 | motion_module_kwargs:
15 | num_attention_heads: 8
16 | num_transformer_block: 1
17 | attention_block_types:
18 | - Temporal_Self
19 | - Temporal_Self
20 | temporal_position_encoding: true
21 | temporal_position_encoding_max_len: 32
22 | temporal_attention_dim_div: 1
23 |
24 | noise_scheduler_kwargs:
25 | beta_start: 0.00085
26 | beta_end: 0.012
27 | beta_schedule: "linear"
28 | clip_sample: false
29 | steps_offset: 1
30 | ### Zero-SNR params
31 | prediction_type: "v_prediction"
32 | rescale_betas_zero_snr: True
33 | timestep_spacing: "trailing"
34 |
35 | sampler: DDIM
--------------------------------------------------------------------------------
/configs/inference/pose_images/pose-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/pose_images/pose-1.png
--------------------------------------------------------------------------------
/configs/inference/pose_videos/anyone-video-1_kps.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/pose_videos/anyone-video-1_kps.mp4
--------------------------------------------------------------------------------
/configs/inference/pose_videos/anyone-video-2_kps.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/pose_videos/anyone-video-2_kps.mp4
--------------------------------------------------------------------------------
/configs/inference/pose_videos/anyone-video-4_kps.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/pose_videos/anyone-video-4_kps.mp4
--------------------------------------------------------------------------------
/configs/inference/pose_videos/anyone-video-5_kps.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/pose_videos/anyone-video-5_kps.mp4
--------------------------------------------------------------------------------
/configs/inference/ref_images/anyone-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/ref_images/anyone-1.png
--------------------------------------------------------------------------------
/configs/inference/ref_images/anyone-10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/ref_images/anyone-10.png
--------------------------------------------------------------------------------
/configs/inference/ref_images/anyone-11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/ref_images/anyone-11.png
--------------------------------------------------------------------------------
/configs/inference/ref_images/anyone-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/ref_images/anyone-2.png
--------------------------------------------------------------------------------
/configs/inference/ref_images/anyone-3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/ref_images/anyone-3.png
--------------------------------------------------------------------------------
/configs/inference/ref_images/anyone-5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/ref_images/anyone-5.png
--------------------------------------------------------------------------------
/configs/inference/talkinghead_images/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/talkinghead_images/1.png
--------------------------------------------------------------------------------
/configs/inference/talkinghead_images/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/talkinghead_images/2.png
--------------------------------------------------------------------------------
/configs/inference/talkinghead_images/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/talkinghead_images/3.png
--------------------------------------------------------------------------------
/configs/inference/talkinghead_images/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/talkinghead_images/4.png
--------------------------------------------------------------------------------
/configs/inference/talkinghead_images/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/talkinghead_images/5.png
--------------------------------------------------------------------------------
/configs/inference/talkinghead_videos/1.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/talkinghead_videos/1.mp4
--------------------------------------------------------------------------------
/configs/inference/talkinghead_videos/2.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/talkinghead_videos/2.mp4
--------------------------------------------------------------------------------
/configs/inference/talkinghead_videos/3.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/talkinghead_videos/3.mp4
--------------------------------------------------------------------------------
/configs/inference/talkinghead_videos/4.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/configs/inference/talkinghead_videos/4.mp4
--------------------------------------------------------------------------------
/configs/prompts/animation.yaml:
--------------------------------------------------------------------------------
1 | pretrained_base_model_path: "./pretrained_weights/stable-diffusion-v1-5/"
2 | pretrained_vae_path: "./pretrained_weights/sd-vae-ft-mse"
3 | image_encoder_path: "./pretrained_weights/image_encoder"
4 | denoising_unet_path: "./pretrained_weights/denoising_unet.pth"
5 | reference_unet_path: "./pretrained_weights/reference_unet.pth"
6 | pose_guider_path: "./pretrained_weights/pose_guider.pth"
7 | motion_module_path: "./pretrained_weights/motion_module.pth"
8 |
9 | inference_config: "./configs/inference/inference_v2.yaml"
10 | weight_dtype: 'fp16'
11 |
12 | test_cases:
13 | "./configs/inference/ref_images/anyone-2.png":
14 | - "./configs/inference/pose_videos/anyone-video-2_kps.mp4"
15 | - "./configs/inference/pose_videos/anyone-video-5_kps.mp4"
16 | "./configs/inference/ref_images/anyone-10.png":
17 | - "./configs/inference/pose_videos/anyone-video-1_kps.mp4"
18 | - "./configs/inference/pose_videos/anyone-video-2_kps.mp4"
19 | "./configs/inference/ref_images/anyone-11.png":
20 | - "./configs/inference/pose_videos/anyone-video-1_kps.mp4"
21 | - "./configs/inference/pose_videos/anyone-video-2_kps.mp4"
22 | "./configs/inference/ref_images/anyone-3.png":
23 | - "./configs/inference/pose_videos/anyone-video-2_kps.mp4"
24 | - "./configs/inference/pose_videos/anyone-video-5_kps.mp4"
25 | "./configs/inference/ref_images/anyone-5.png":
26 | - "./configs/inference/pose_videos/anyone-video-2_kps.mp4"
27 |
--------------------------------------------------------------------------------
/configs/prompts/inference_reenact.yaml:
--------------------------------------------------------------------------------
1 | pretrained_base_model_path: "./pretrained_weights/stable-diffusion-v1-5/"
2 | pretrained_vae_path: "./pretrained_weights/sd-vae-ft-mse"
3 | image_encoder_path: "./pretrained_weights/image_encoder"
4 | denoising_unet_path: "./pretrained_weights/reenact/denoising_unet.pth"
5 | reference_unet_path: "./pretrained_weights/reenact/reference_unet.pth"
6 | pose_guider1_path: "./pretrained_weights/reenact/pose_guider1.pth"
7 | pose_guider2_path: "./pretrained_weights/reenact/pose_guider2.pth"
8 | unet_additional_kwargs:
9 | task_type: "reenact"
10 | mode: "read" # "read"
11 | use_inflated_groupnorm: true
12 | unet_use_cross_frame_attention: false
13 | unet_use_temporal_attention: false
14 | use_motion_module: true
15 | motion_module_resolutions:
16 | - 1
17 | - 2
18 | - 4
19 | - 8
20 | motion_module_mid_block: true
21 | motion_module_decoder_only: false
22 | motion_module_type: Vanilla
23 | motion_module_kwargs:
24 | num_attention_heads: 8
25 | num_transformer_block: 1
26 | attention_block_types:
27 | - Temporal_Self
28 | - Temporal_Self
29 | temporal_position_encoding: true
30 | temporal_position_encoding_max_len: 32
31 | temporal_attention_dim_div: 1
32 |
33 | noise_scheduler_kwargs:
34 | beta_start: 0.00085
35 | beta_end: 0.012
36 | beta_schedule: "linear"
37 | # beta_schedule: "scaled_linear"
38 | clip_sample: false
39 | # set_alpha_to_one: False
40 | # skip_prk_steps: true
41 | steps_offset: 1
42 | ### Zero-SNR params
43 | # prediction_type: "v_prediction"
44 | # rescale_betas_zero_snr: True
45 | # timestep_spacing: "trailing"
46 |
47 | weight_dtype: float16
48 | sampler: DDIM
49 |
--------------------------------------------------------------------------------
/configs/prompts/test_cases.py:
--------------------------------------------------------------------------------
1 | TestCasesDict = {
2 | 0: [
3 | {
4 | "./configs/inference/ref_images/anyone-2.png": [
5 | "./configs/inference/pose_videos/anyone-video-2_kps.mp4",
6 | "./configs/inference/pose_videos/anyone-video-5_kps.mp4",
7 | ]
8 | },
9 | {
10 | "./configs/inference/ref_images/anyone-10.png": [
11 | "./configs/inference/pose_videos/anyone-video-1_kps.mp4",
12 | "./configs/inference/pose_videos/anyone-video-2_kps.mp4",
13 | ]
14 | },
15 | {
16 | "./configs/inference/ref_images/anyone-11.png": [
17 | "./configs/inference/pose_videos/anyone-video-1_kps.mp4",
18 | "./configs/inference/pose_videos/anyone-video-2_kps.mp4",
19 | ]
20 | },
21 | {
22 | "./configs/inference/anyone-ref-3.png": [
23 | "./configs/inference/pose_videos/anyone-video-2_kps.mp4",
24 | "./configs/inference/pose_videos/anyone-video-5_kps.mp4",
25 | ]
26 | },
27 | {
28 | "./configs/inference/ref_images/anyone-5.png": [
29 | "./configs/inference/pose_videos/anyone-video-2_kps.mp4"
30 | ]
31 | },
32 | ],
33 | }
34 |
--------------------------------------------------------------------------------
/configs/train/stage1.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | train_bs: 4
3 | train_width: 768
4 | train_height: 768
5 | meta_paths:
6 | - "./data/fashion_meta.json"
7 | # Margin of frame indexes between ref and tgt images
8 | sample_margin: 30
9 |
10 | solver:
11 | gradient_accumulation_steps: 1
12 | mixed_precision: 'fp16'
13 | enable_xformers_memory_efficient_attention: True
14 | gradient_checkpointing: False
15 | max_train_steps: 30000
16 | max_grad_norm: 1.0
17 | # lr
18 | learning_rate: 1.0e-5
19 | scale_lr: False
20 | lr_warmup_steps: 1
21 | lr_scheduler: 'constant'
22 |
23 | # optimizer
24 | use_8bit_adam: False
25 | adam_beta1: 0.9
26 | adam_beta2: 0.999
27 | adam_weight_decay: 1.0e-2
28 | adam_epsilon: 1.0e-8
29 |
30 | val:
31 | validation_steps: 200
32 |
33 |
34 | noise_scheduler_kwargs:
35 | num_train_timesteps: 1000
36 | beta_start: 0.00085
37 | beta_end: 0.012
38 | beta_schedule: "scaled_linear"
39 | steps_offset: 1
40 | clip_sample: false
41 |
42 | base_model_path: './pretrained_weights/sd-image-variations-diffusers'
43 | vae_model_path: './pretrained_weights/sd-vae-ft-mse'
44 | image_encoder_path: './pretrained_weights/sd-image-variations-diffusers/image_encoder'
45 | controlnet_openpose_path: './pretrained_weights/control_v11p_sd15_openpose/diffusion_pytorch_model.bin'
46 |
47 | weight_dtype: 'fp16' # [fp16, fp32]
48 | uncond_ratio: 0.1
49 | noise_offset: 0.05
50 | snr_gamma: 5.0
51 | enable_zero_snr: True
52 | pose_guider_pretrain: True
53 |
54 | seed: 12580
55 | resume_from_checkpoint: ''
56 | checkpointing_steps: 2000
57 | save_model_epoch_interval: 5
58 | exp_name: 'stage1'
59 | output_dir: './exp_output'
--------------------------------------------------------------------------------
/configs/train/stage2.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | train_bs: 1
3 | train_width: 512
4 | train_height: 512
5 | meta_paths:
6 | - "./data/fashion_meta.json"
7 | sample_rate: 4
8 | n_sample_frames: 24
9 |
10 | solver:
11 | gradient_accumulation_steps: 1
12 | mixed_precision: 'fp16'
13 | enable_xformers_memory_efficient_attention: True
14 | gradient_checkpointing: True
15 | max_train_steps: 10000
16 | max_grad_norm: 1.0
17 | # lr
18 | learning_rate: 1e-5
19 | scale_lr: False
20 | lr_warmup_steps: 1
21 | lr_scheduler: 'constant'
22 |
23 | # optimizer
24 | use_8bit_adam: True
25 | adam_beta1: 0.9
26 | adam_beta2: 0.999
27 | adam_weight_decay: 1.0e-2
28 | adam_epsilon: 1.0e-8
29 |
30 | val:
31 | validation_steps: 20
32 |
33 |
34 | noise_scheduler_kwargs:
35 | num_train_timesteps: 1000
36 | beta_start: 0.00085
37 | beta_end: 0.012
38 | beta_schedule: "linear"
39 | steps_offset: 1
40 | clip_sample: false
41 |
42 | base_model_path: './pretrained_weights/stable-diffusion-v1-5'
43 | vae_model_path: './pretrained_weights/sd-vae-ft-mse'
44 | image_encoder_path: './pretrained_weights/sd-image-variations-diffusers/image_encoder'
45 | mm_path: './pretrained_weights/mm_sd_v15_v2.ckpt'
46 |
47 | weight_dtype: 'fp16' # [fp16, fp32]
48 | uncond_ratio: 0.1
49 | noise_offset: 0.05
50 | snr_gamma: 5.0
51 | enable_zero_snr: True
52 | stage1_ckpt_dir: './exp_output/stage1'
53 | stage1_ckpt_step: 980
54 |
55 | seed: 12580
56 | resume_from_checkpoint: ''
57 | checkpointing_steps: 2000
58 | exp_name: 'stage2'
59 | output_dir: './exp_output'
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.21.0
2 | av==11.0.0
3 | clip @ https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip#sha256=b5842c25da441d6c581b53a5c60e0c2127ebafe0f746f8e15561a006c6c3be6a
4 | decord==0.6.0
5 | diffusers==0.24.0
6 | einops==0.4.1
7 | gradio==3.41.2
8 | gradio_client==0.5.0
9 | imageio==2.33.0
10 | imageio-ffmpeg==0.4.9
11 | numpy==1.23.5
12 | omegaconf==2.2.3
13 | onnxruntime-gpu==1.16.3
14 | open-clip-torch==2.20.0
15 | opencv-contrib-python==4.8.1.78
16 | opencv-python==4.8.1.78
17 | Pillow==9.5.0
18 | scikit-image==0.21.0
19 | scikit-learn==1.3.2
20 | scipy==1.11.4
21 | torch==2.0.1
22 | torchdiffeq==0.2.3
23 | torchmetrics==1.2.1
24 | torchsde==0.2.5
25 | torchvision==0.15.2
26 | tqdm==4.66.1
27 | transformers==4.30.2
28 | mlflow==2.9.2
29 | xformers==0.0.22
30 | controlnet-aux==0.0.7
--------------------------------------------------------------------------------
/scripts/lmks2vid.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | from datetime import datetime
5 | from pathlib import Path
6 | from typing import List
7 |
8 | import av
9 | import cv2
10 | import numpy as np
11 | import torch
12 |
13 | # Initialize models
14 | import torchvision
15 | from diffusers import AutoencoderKL, DDIMScheduler
16 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
17 | from einops import rearrange, repeat
18 | from omegaconf import OmegaConf
19 | from PIL import Image
20 | from torchvision import transforms
21 | from transformers import (
22 | CLIPImageProcessor,
23 | CLIPTextModel,
24 | CLIPTokenizer,
25 | CLIPVisionModel,
26 | CLIPVisionModelWithProjection,
27 | )
28 |
29 | import sys
30 | from src.models.unet_3d import UNet3DConditionModel
31 | from src.pipelines.pipeline_lmks2vid_long import Pose2VideoPipeline
32 | from src.models.pose_guider import PoseGuider
33 | from src.utils.util import get_fps, read_frames, save_videos_grid
34 | from tools.facetracker_api import face_image
35 |
36 |
37 | def parse_args():
38 | parser = argparse.ArgumentParser()
39 | parser.add_argument(
40 | "--config", type=str, help="Path of inference configs",
41 | default="./configs/prompts/inference_reenact.yaml"
42 | )
43 | parser.add_argument(
44 | "--save_dir", type=str, help="Path of save results",
45 | default="./output/stage2_infer"
46 | )
47 | parser.add_argument(
48 | "--source_image_path", type=str, help="Path of source image",
49 | default="",
50 | )
51 | parser.add_argument(
52 | "--driving_video_path", type=str, help="Path of driving video",
53 | default="",
54 | )
55 | parser.add_argument(
56 | "--batch_size",
57 | type=int,
58 | default=320,
59 | help="Checkpoint step of pretrained model",
60 | )
61 | parser.add_argument("--mask_ratio", type=float, default=0.55) # 0.55~0.6
62 | parser.add_argument("-W", type=int, default=512)
63 | parser.add_argument("-H", type=int, default=512)
64 | parser.add_argument("-L", type=int, default=24)
65 | parser.add_argument("--seed", type=int, default=42)
66 | parser.add_argument("--cfg", type=float, default=3.5)
67 | parser.add_argument("--steps", type=int, default=30)
68 | parser.add_argument("--fps", type=int, default=25)
69 | args = parser.parse_args()
70 |
71 | return args
72 |
73 |
74 | def lmks_vis(img, lms):
75 | # Visualize the mouth, nose, and entire face based on landmarks
76 | h, w, c = img.shape
77 | lms = lms[:, :2]
78 | mouth = lms[48:66]
79 | nose = lms[27:36]
80 | color = (0, 255, 0)
81 | # Center mouth and nose
82 | x_c, y_c = np.mean(lms[:, 0]), np.mean(lms[:, 1])
83 | h_c, w_c = h // 2, w // 2
84 | img_face, img_mouth, img_nose = img.copy(), img.copy(), img.copy()
85 | for pt_num, (x, y) in enumerate(mouth):
86 | x = x - (x_c - w_c)
87 | y = y - (y_c - h_c)
88 | x = int(x + 0.5)
89 | y = int(y + 0.5)
90 | cv2.circle(img_mouth, (y, x), 1, color, -1)
91 | for pt_num, (x, y) in enumerate(nose):
92 | x = x - (x_c - w_c)
93 | y = y - (y_c - h_c)
94 | x = int(x + 0.5)
95 | y = int(y + 0.5)
96 | cv2.circle(img_nose, (y, x), 1, color, -1)
97 | for pt_num, (x, y) in enumerate(lms):
98 | x = int(x + 0.5)
99 | y = int(y + 0.5)
100 | if pt_num >= 66:
101 | color = (255, 255, 0)
102 | else:
103 | color = (0, 255, 0)
104 | cv2.circle(img_face, (y, x), 1, color, -1)
105 | return img_face, img_mouth, img_nose
106 |
107 |
108 | def batch_rearrange(pose_len, batch_size=24):
109 | # To rearrange the pose sequence based on batch size
110 | batch_ind_list = []
111 | for i in range(0, pose_len, batch_size):
112 | if i + batch_size < pose_len:
113 | batch_ind_list.append(list(range(i, i + batch_size)))
114 | else:
115 | batch_ind_list.append(list(range(i, min(i + batch_size, pose_len))))
116 | return batch_ind_list
117 |
118 |
119 | def lmks_video_extract(video_path):
120 | # To extract the landmark sequence of video (single face video)
121 | video_stream = cv2.VideoCapture(video_path)
122 | lmks_list, frames = [], []
123 | while 1:
124 | still_reading, frame = video_stream.read()
125 | if not still_reading:
126 | video_stream.release()
127 | break
128 | h, w, c = frame.shape
129 | lmk_img, lmks = face_image(frame)
130 | if lmks is not None:
131 | lmks_list.append(lmks)
132 | frames.append(frame)
133 | return frames, np.array(lmks_list), [h, w]
134 |
135 |
136 | def adjust_pose(src_lms_list, src_size, ref_lms, ref_size):
137 |     # Align the center of the source landmarks with the center of the reference landmarks
138 | new_src_lms_list = []
139 |     ref_lms = ref_lms[:, :2].copy()
140 |     src_lms = src_lms_list[0][:, :2].copy()  # copy so the first frame is not normalized in place
141 | ref_lms[:, 0] = ref_lms[:, 0] / ref_size[1]
142 | ref_lms[:, 1] = ref_lms[:, 1] / ref_size[0]
143 | src_lms[:, 0] = src_lms[:, 0] / src_size[1]
144 | src_lms[:, 1] = src_lms[:, 1] / src_size[0]
145 | ref_cx, ref_cy = np.mean(ref_lms[:, 0]), np.mean(ref_lms[:, 1])
146 | src_cx, src_cy = np.mean(src_lms[:, 0]), np.mean(src_lms[:, 1])
147 | for item in src_lms_list:
148 | item = item[:, :2]
149 |         item[:, 0] = item[:, 0] - int((src_cx - ref_cx) * src_size[1])  # offset in pixels
150 |         item[:, 1] = item[:, 1] - int((src_cy - ref_cy) * src_size[0])
151 | new_src_lms_list.append(item)
152 | return np.array(new_src_lms_list)
153 |
154 |
155 | def main():
156 | args = parse_args()
157 | infer_config = OmegaConf.load(args.config)
158 |
159 | # base_model_path = "./pretrained_weights/huggingface-models/sd-image-variations-diffusers/"
160 | base_model_path = infer_config.pretrained_base_model_path
161 | weight_dtype = torch.float16
162 |
163 | image_enc = CLIPVisionModelWithProjection.from_pretrained(
164 | # "./pretrained_weights/huggingface-models/sd-image-variations-diffusers/image_encoder"
165 | infer_config.image_encoder_path
166 | ).to(dtype=weight_dtype, device="cuda")
167 | vae = AutoencoderKL.from_pretrained(
168 | # "./pretrained_weights/huggingface-models/sd-vae-ft-mse"
169 | infer_config.pretrained_vae_path
170 | ).to("cuda", dtype=weight_dtype)
171 |     # Initialize the reference UNet, denoising UNet, and pose guiders
172 | reference_unet = UNet3DConditionModel.from_pretrained_2d(
173 | base_model_path,
174 | "",
175 | subfolder="unet",
176 | unet_additional_kwargs={
177 | "task_type": "reenact",
178 | "use_motion_module": False,
179 | "unet_use_temporal_attention": False,
180 | "mode": "write",
181 | },
182 | ).to(device="cuda", dtype=weight_dtype)
183 | denoising_unet = UNet3DConditionModel.from_pretrained_2d(
184 | base_model_path,
185 | "./pretrained_weights/mm_sd_v15_v2.ckpt",
186 | subfolder="unet",
187 | unet_additional_kwargs=OmegaConf.to_container(
188 | infer_config.unet_additional_kwargs
189 | ),
190 | # mm_zero_proj_out=True,
191 | ).to(device="cuda")
192 | pose_guider1 = PoseGuider(
193 | conditioning_embedding_channels=320, block_out_channels=(16, 32, 96, 256)
194 | ).to(device="cuda", dtype=weight_dtype)
195 | pose_guider2 = PoseGuider(
196 | conditioning_embedding_channels=320, block_out_channels=(16, 32, 96, 256)
197 | ).to(device="cuda", dtype=weight_dtype)
198 | print("------------------initial all networks------------------")
199 | # load model from pretrained models
200 | denoising_unet.load_state_dict(
201 | torch.load(
202 | infer_config.denoising_unet_path,
203 | map_location="cpu",
204 | ),
205 | strict=True,
206 | )
207 | reference_unet.load_state_dict(
208 | torch.load(
209 | infer_config.reference_unet_path,
210 | map_location="cpu",
211 | )
212 | )
213 | pose_guider1.load_state_dict(
214 | torch.load(
215 | infer_config.pose_guider1_path,
216 | map_location="cpu",
217 | )
218 | )
219 | pose_guider2.load_state_dict(
220 | torch.load(
221 | infer_config.pose_guider2_path,
222 | map_location="cpu",
223 | )
224 | )
225 | print("---------load pretrained denoising unet, reference unet and pose guider----------")
226 | # scheduler
227 | enable_zero_snr = True
228 | sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
229 | if enable_zero_snr:
230 | sched_kwargs.update(
231 | rescale_betas_zero_snr=True,
232 | timestep_spacing="trailing",
233 | prediction_type="v_prediction",
234 | )
235 | scheduler = DDIMScheduler(**sched_kwargs)
236 | pipe = Pose2VideoPipeline(
237 | vae=vae,
238 | image_encoder=image_enc,
239 | reference_unet=reference_unet,
240 | denoising_unet=denoising_unet,
241 | pose_guider1=pose_guider1,
242 | pose_guider2=pose_guider2,
243 | scheduler=scheduler,
244 | )
245 | pipe = pipe.to("cuda", dtype=weight_dtype)
246 | height, width, clip_length = args.H, args.W, args.L
247 |     generator = torch.manual_seed(args.seed)  # honor --seed (default 42)
248 | date_str = datetime.now().strftime("%Y%m%d")
249 | save_dir = Path(f"{args.save_dir}/{date_str}")
250 | save_dir.mkdir(exist_ok=True, parents=True)
251 |
252 | ref_image_path, pose_video_path = args.source_image_path, args.driving_video_path
253 | ref_name = Path(ref_image_path).stem
254 | pose_name = Path(pose_video_path).stem
255 | ref_image_pil = Image.open(ref_image_path).convert("RGB")
256 | ref_image = cv2.imread(ref_image_path)
257 | ref_h, ref_w, c = ref_image.shape
258 | ref_pose, ref_pose_lms = face_image(ref_image)
259 |     # Extract landmarks from the driving video
260 | pose_frames, pose_lms_list, pose_size = lmks_video_extract(pose_video_path)
261 | pose_lms_list = adjust_pose(pose_lms_list, pose_size, ref_pose_lms, [ref_h, ref_w])
262 | pose_h, pose_w = int(pose_size[0]), int(pose_size[1])
263 | pose_len = pose_lms_list.shape[0]
264 |     # Truncate the tail so the frame count is a multiple of 24, which gives more stable results.
265 | pose_len = pose_len // 24 * 24
266 | batch_index_list = batch_rearrange(pose_len, args.batch_size)
267 | pose_transform = transforms.Compose(
268 | [transforms.Resize((height, width)), transforms.ToTensor()]
269 | )
270 | videos = []
271 | zero_map = np.zeros_like(ref_pose)
272 | zero_map = cv2.resize(zero_map, (pose_w, pose_h))
273 | for batch_index in batch_index_list:
274 | pose_list, pose_up_list, pose_down_list = [], [], []
275 | pose_frame_list = []
276 | pose_tensor_list, pose_up_tensor_list, pose_down_tensor_list = [], [], []
277 | batch_len = len(batch_index)
278 | for pose_idx in batch_index:
279 | pose_lms = pose_lms_list[pose_idx]
280 | pose_frame = pose_frames[pose_idx][:, :, ::-1]
281 | pose_image, pose_mouth_image, _ = lmks_vis(zero_map, pose_lms)
282 | h, w, c = pose_image.shape
283 | pose_up_image = pose_image.copy()
284 | pose_up_image[int(h * args.mask_ratio):, :, :] = 0.
285 | pose_image_pil = Image.fromarray(pose_image)
286 | pose_frame = Image.fromarray(pose_frame)
287 | pose_up_pil = Image.fromarray(pose_up_image)
288 | pose_mouth_pil = Image.fromarray(pose_mouth_image)
289 | pose_list.append(pose_image_pil)
290 | pose_up_list.append(pose_up_pil)
291 | pose_down_list.append(pose_mouth_pil)
292 | pose_tensor_list.append(pose_transform(pose_image_pil))
293 | pose_up_tensor_list.append(pose_transform(pose_up_pil))
294 | pose_down_tensor_list.append(pose_transform(pose_mouth_pil))
295 | pose_frame_list.append(pose_transform(pose_frame))
296 | pose_tensor = torch.stack(pose_tensor_list, dim=0) # (f, c, h, w)
297 | pose_tensor = pose_tensor.transpose(0, 1)
298 | pose_tensor = pose_tensor.unsqueeze(0)
299 | pose_frames_tensor = torch.stack(pose_frame_list, dim=0) # (f, c, h, w)
300 | pose_frames_tensor = pose_frames_tensor.transpose(0, 1)
301 | pose_frames_tensor = pose_frames_tensor.unsqueeze(0)
302 | ref_image_tensor = pose_transform(ref_image_pil) # (c, h, w)
303 | ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(0) # (1, c, 1, h, w)
304 | ref_image_tensor = repeat(
305 | ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=batch_len
306 | )
307 |         # Disentangle head pose control (including eye blinks) from mouth motion control
308 | pipeline_output = pipe(
309 | ref_image_pil,
310 | pose_up_list,
311 | pose_down_list,
312 | width,
313 | height,
314 | batch_len,
315 | 20,
316 | 3.5,
317 | generator=generator,
318 | )
319 | video = pipeline_output.videos
320 | video = torch.cat([ref_image_tensor, pose_frames_tensor, video], dim=0)
321 | videos.append(video)
322 | videos = torch.cat(videos, dim=2)
323 | time_str = datetime.now().strftime("%H%M")
324 | save_video_path = f"{save_dir}/{ref_name}_{pose_name}_{time_str}.mp4"
325 | save_videos_grid(
326 | videos,
327 | save_video_path,
328 | n_rows=3,
329 | fps=args.fps,
330 | )
331 | print("infer results: {}".format(save_video_path))
332 | del pipe
333 | torch.cuda.empty_cache()
334 |
335 |
336 | if __name__ == "__main__":
337 | main()
338 |
--------------------------------------------------------------------------------
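
The reenactment script above is driven from the command line, e.g. `python scripts/lmks2vid.py --config ./configs/prompts/inference_reenact.yaml --source_image_path <image> --driving_video_path <video>` (the two paths are placeholders). Below is a minimal sketch, not part of the repo, of how the driving frames are chunked: the frame count is first truncated to a multiple of 24 and then split into index batches of at most --batch_size frames.

def batch_rearrange(pose_len, batch_size=24):
    # Same behaviour as the helper above, written as a comprehension.
    return [list(range(i, min(i + batch_size, pose_len))) for i in range(0, pose_len, batch_size)]

pose_len = 173                           # e.g. a 173-frame driving video
pose_len = pose_len // 24 * 24           # -> 168 frames kept
batches = batch_rearrange(pose_len, batch_size=320)
print(len(batches), [len(b) for b in batches])   # 1 batch of 168 frame indices
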
/scripts/pose2vid.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from datetime import datetime
4 | from pathlib import Path
5 | from typing import List
6 |
7 | import av
8 | import numpy as np
9 | import torch
10 | import torchvision
11 | from diffusers import AutoencoderKL, DDIMScheduler
12 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
13 | from einops import repeat
14 | from omegaconf import OmegaConf
15 | from PIL import Image
16 | from torchvision import transforms
17 | from transformers import CLIPVisionModelWithProjection
18 |
19 | from configs.prompts.test_cases import TestCasesDict
20 | from src.models.pose_guider import PoseGuider
21 | from src.models.unet_2d_condition import UNet2DConditionModel
22 | from src.models.unet_3d import UNet3DConditionModel
23 | from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
24 | from src.utils.util import get_fps, read_frames, save_videos_grid
25 |
26 |
27 | def parse_args():
28 | parser = argparse.ArgumentParser()
29 | parser.add_argument("--config")
30 | parser.add_argument("-W", type=int, default=512)
31 | parser.add_argument("-H", type=int, default=784)
32 | parser.add_argument("-L", type=int, default=24)
33 | parser.add_argument("--seed", type=int, default=42)
34 | parser.add_argument("--cfg", type=float, default=3.5)
35 | parser.add_argument("--steps", type=int, default=30)
36 | parser.add_argument("--fps", type=int)
37 | args = parser.parse_args()
38 |
39 | return args
40 |
41 |
42 | def main():
43 | args = parse_args()
44 |
45 | config = OmegaConf.load(args.config)
46 |
47 | if config.weight_dtype == "fp16":
48 | weight_dtype = torch.float16
49 | else:
50 | weight_dtype = torch.float32
51 |
52 | vae = AutoencoderKL.from_pretrained(
53 | config.pretrained_vae_path,
54 | ).to("cuda", dtype=weight_dtype)
55 |
56 | reference_unet = UNet2DConditionModel.from_pretrained(
57 | config.pretrained_base_model_path,
58 | subfolder="unet",
59 | ).to(dtype=weight_dtype, device="cuda")
60 |
61 | inference_config_path = config.inference_config
62 | infer_config = OmegaConf.load(inference_config_path)
63 | denoising_unet = UNet3DConditionModel.from_pretrained_2d(
64 | config.pretrained_base_model_path,
65 | config.motion_module_path,
66 | subfolder="unet",
67 | unet_additional_kwargs=infer_config.unet_additional_kwargs,
68 | ).to(dtype=weight_dtype, device="cuda")
69 |
70 | pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(
71 | dtype=weight_dtype, device="cuda"
72 | )
73 |
74 | image_enc = CLIPVisionModelWithProjection.from_pretrained(
75 | config.image_encoder_path
76 | ).to(dtype=weight_dtype, device="cuda")
77 |
78 | sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
79 | scheduler = DDIMScheduler(**sched_kwargs)
80 |
81 | generator = torch.manual_seed(args.seed)
82 |
83 | width, height = args.W, args.H
84 |
85 | # load pretrained weights
86 | denoising_unet.load_state_dict(
87 | torch.load(config.denoising_unet_path, map_location="cpu"),
88 | strict=False,
89 | )
90 | reference_unet.load_state_dict(
91 | torch.load(config.reference_unet_path, map_location="cpu"),
92 | )
93 | pose_guider.load_state_dict(
94 | torch.load(config.pose_guider_path, map_location="cpu"),
95 | )
96 |
97 | pipe = Pose2VideoPipeline(
98 | vae=vae,
99 | image_encoder=image_enc,
100 | reference_unet=reference_unet,
101 | denoising_unet=denoising_unet,
102 | pose_guider=pose_guider,
103 | scheduler=scheduler,
104 | )
105 | pipe = pipe.to("cuda", dtype=weight_dtype)
106 |
107 | date_str = datetime.now().strftime("%Y%m%d")
108 | time_str = datetime.now().strftime("%H%M")
109 | save_dir_name = f"{time_str}--seed_{args.seed}-{args.W}x{args.H}"
110 |
111 | save_dir = Path(f"output/{date_str}/{save_dir_name}")
112 | save_dir.mkdir(exist_ok=True, parents=True)
113 |
114 | for ref_image_path in config["test_cases"].keys():
115 | # Each ref_image may correspond to multiple actions
116 | for pose_video_path in config["test_cases"][ref_image_path]:
117 | ref_name = Path(ref_image_path).stem
118 | pose_name = Path(pose_video_path).stem.replace("_kps", "")
119 |
120 | ref_image_pil = Image.open(ref_image_path).convert("RGB")
121 |
122 | pose_list = []
123 | pose_tensor_list = []
124 | pose_images = read_frames(pose_video_path)
125 | src_fps = get_fps(pose_video_path)
126 | print(f"pose video has {len(pose_images)} frames, with {src_fps} fps")
127 | pose_transform = transforms.Compose(
128 | [transforms.Resize((height, width)), transforms.ToTensor()]
129 | )
130 | for pose_image_pil in pose_images[: args.L]:
131 | pose_tensor_list.append(pose_transform(pose_image_pil))
132 | pose_list.append(pose_image_pil)
133 |
134 | ref_image_tensor = pose_transform(ref_image_pil) # (c, h, w)
135 | ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(
136 | 0
137 | ) # (1, c, 1, h, w)
138 | ref_image_tensor = repeat(
139 | ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=args.L
140 | )
141 |
142 | pose_tensor = torch.stack(pose_tensor_list, dim=0) # (f, c, h, w)
143 | pose_tensor = pose_tensor.transpose(0, 1)
144 | pose_tensor = pose_tensor.unsqueeze(0)
145 |
146 | video = pipe(
147 | ref_image_pil,
148 | pose_list,
149 | width,
150 | height,
151 | args.L,
152 | args.steps,
153 | args.cfg,
154 | generator=generator,
155 | ).videos
156 |
157 | video = torch.cat([ref_image_tensor, pose_tensor, video], dim=0)
158 | save_videos_grid(
159 | video,
160 | f"{save_dir}/{ref_name}_{pose_name}_{args.H}x{args.W}_{int(args.cfg)}_{time_str}.mp4",
161 | n_rows=3,
162 | fps=src_fps if args.fps is None else args.fps,
163 | )
164 |
165 |
166 | if __name__ == "__main__":
167 | main()
168 |
--------------------------------------------------------------------------------
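
The script above is invoked as `python scripts/pose2vid.py --config <yaml>` (configs/prompts/animation.yaml in this repo is a likely candidate). The loop in main() expects the config to carry a `test_cases` mapping from a reference image to a list of driving pose videos, alongside the model-path keys it reads earlier (pretrained_vae_path, pretrained_base_model_path, image_encoder_path, denoising_unet_path, reference_unet_path, pose_guider_path, motion_module_path, inference_config, weight_dtype). A minimal sketch of that structure using the sample assets shipped under configs/inference/, with the model-path keys omitted:

from omegaconf import OmegaConf

config = OmegaConf.create(
    {
        "test_cases": {
            "./configs/inference/ref_images/anyone-1.png": [
                "./configs/inference/pose_videos/anyone-video-1_kps.mp4",
            ],
        },
    }
)
for ref_image_path in config["test_cases"].keys():
    for pose_video_path in config["test_cases"][ref_image_path]:
        print(ref_image_path, "->", pose_video_path)
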
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/src/__init__.py
--------------------------------------------------------------------------------
/src/dataset/dance_image.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 |
4 | import torch
5 | import torchvision.transforms as transforms
6 | from decord import VideoReader
7 | from PIL import Image
8 | from torch.utils.data import Dataset
9 | from transformers import CLIPImageProcessor
10 |
11 |
12 | class HumanDanceDataset(Dataset):
13 | def __init__(
14 | self,
15 | img_size,
16 | img_scale=(1.0, 1.0),
17 | img_ratio=(0.9, 1.0),
18 | drop_ratio=0.1,
19 |         data_meta_paths=["./data/fashion_meta.json"],
20 | sample_margin=30,
21 | ):
22 | super().__init__()
23 |
24 | self.img_size = img_size
25 | self.img_scale = img_scale
26 | self.img_ratio = img_ratio
27 | self.sample_margin = sample_margin
28 |
29 | # -----
30 | # vid_meta format:
31 | # [{'video_path': , 'kps_path': , 'other':},
32 | # {'video_path': , 'kps_path': , 'other':}]
33 | # -----
34 | vid_meta = []
35 | for data_meta_path in data_meta_paths:
36 | vid_meta.extend(json.load(open(data_meta_path, "r")))
37 | self.vid_meta = vid_meta
38 |
39 | self.clip_image_processor = CLIPImageProcessor()
40 |
41 | self.transform = transforms.Compose(
42 | [
43 | transforms.RandomResizedCrop(
44 | self.img_size,
45 | scale=self.img_scale,
46 | ratio=self.img_ratio,
47 | interpolation=transforms.InterpolationMode.BILINEAR,
48 | ),
49 | transforms.ToTensor(),
50 | transforms.Normalize([0.5], [0.5]),
51 | ]
52 | )
53 |
54 | self.cond_transform = transforms.Compose(
55 | [
56 | transforms.RandomResizedCrop(
57 | self.img_size,
58 | scale=self.img_scale,
59 | ratio=self.img_ratio,
60 | interpolation=transforms.InterpolationMode.BILINEAR,
61 | ),
62 | transforms.ToTensor(),
63 | ]
64 | )
65 |
66 | self.drop_ratio = drop_ratio
67 |
68 | def augmentation(self, image, transform, state=None):
69 | if state is not None:
70 | torch.set_rng_state(state)
71 | return transform(image)
72 |
73 | def __getitem__(self, index):
74 | video_meta = self.vid_meta[index]
75 | video_path = video_meta["video_path"]
76 | kps_path = video_meta["kps_path"]
77 |
78 | video_reader = VideoReader(video_path)
79 | kps_reader = VideoReader(kps_path)
80 |
81 | assert len(video_reader) == len(
82 | kps_reader
83 | ), f"{len(video_reader) = } != {len(kps_reader) = } in {video_path}"
84 |
85 | video_length = len(video_reader)
86 |
87 | margin = min(self.sample_margin, video_length)
88 |
89 | ref_img_idx = random.randint(0, video_length - 1)
90 | if ref_img_idx + margin < video_length:
91 | tgt_img_idx = random.randint(ref_img_idx + margin, video_length - 1)
92 | elif ref_img_idx - margin > 0:
93 | tgt_img_idx = random.randint(0, ref_img_idx - margin)
94 | else:
95 | tgt_img_idx = random.randint(0, video_length - 1)
96 |
97 | ref_img = video_reader[ref_img_idx]
98 | ref_img_pil = Image.fromarray(ref_img.asnumpy())
99 | tgt_img = video_reader[tgt_img_idx]
100 | tgt_img_pil = Image.fromarray(tgt_img.asnumpy())
101 |
102 | tgt_pose = kps_reader[tgt_img_idx]
103 | tgt_pose_pil = Image.fromarray(tgt_pose.asnumpy())
104 |
105 | state = torch.get_rng_state()
106 | tgt_img = self.augmentation(tgt_img_pil, self.transform, state)
107 | tgt_pose_img = self.augmentation(tgt_pose_pil, self.cond_transform, state)
108 | ref_img_vae = self.augmentation(ref_img_pil, self.transform, state)
109 | clip_image = self.clip_image_processor(
110 | images=ref_img_pil, return_tensors="pt"
111 | ).pixel_values[0]
112 |
113 | sample = dict(
114 | video_dir=video_path,
115 | img=tgt_img,
116 | tgt_pose=tgt_pose_img,
117 | ref_img=ref_img_vae,
118 | clip_images=clip_image,
119 | )
120 |
121 | return sample
122 |
123 | def __len__(self):
124 | return len(self.vid_meta)
125 |
--------------------------------------------------------------------------------
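
A minimal usage sketch for the image dataset above; the meta-file path and img_size are assumptions, while the printed keys match the dict returned by __getitem__:

from torch.utils.data import DataLoader

from src.dataset.dance_image import HumanDanceDataset

dataset = HumanDanceDataset(
    img_size=(768, 576),                            # assumed training resolution
    data_meta_paths=["./data/fashion_meta.json"],   # assumed output of tools/extract_meta_info.py
)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
batch = next(iter(loader))
print(batch["img"].shape, batch["tgt_pose"].shape, batch["ref_img"].shape, batch["clip_images"].shape)
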
/src/dataset/dance_video.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | from typing import List
4 |
5 | import numpy as np
6 | import pandas as pd
7 | import torch
8 | import torchvision.transforms as transforms
9 | from decord import VideoReader
10 | from PIL import Image
11 | from torch.utils.data import Dataset
12 | from transformers import CLIPImageProcessor
13 |
14 |
15 | class HumanDanceVideoDataset(Dataset):
16 | def __init__(
17 | self,
18 | sample_rate,
19 | n_sample_frames,
20 | width,
21 | height,
22 | img_scale=(1.0, 1.0),
23 | img_ratio=(0.9, 1.0),
24 | drop_ratio=0.1,
25 | data_meta_paths=["./data/fashion_meta.json"],
26 | ):
27 | super().__init__()
28 | self.sample_rate = sample_rate
29 | self.n_sample_frames = n_sample_frames
30 | self.width = width
31 | self.height = height
32 | self.img_scale = img_scale
33 | self.img_ratio = img_ratio
34 |
35 | vid_meta = []
36 | for data_meta_path in data_meta_paths:
37 | vid_meta.extend(json.load(open(data_meta_path, "r")))
38 | self.vid_meta = vid_meta
39 |
40 | self.clip_image_processor = CLIPImageProcessor()
41 |
42 | self.pixel_transform = transforms.Compose(
43 | [
44 | transforms.RandomResizedCrop(
45 | (height, width),
46 | scale=self.img_scale,
47 | ratio=self.img_ratio,
48 | interpolation=transforms.InterpolationMode.BILINEAR,
49 | ),
50 | transforms.ToTensor(),
51 | transforms.Normalize([0.5], [0.5]),
52 | ]
53 | )
54 |
55 | self.cond_transform = transforms.Compose(
56 | [
57 | transforms.RandomResizedCrop(
58 | (height, width),
59 | scale=self.img_scale,
60 | ratio=self.img_ratio,
61 | interpolation=transforms.InterpolationMode.BILINEAR,
62 | ),
63 | transforms.ToTensor(),
64 | ]
65 | )
66 |
67 | self.drop_ratio = drop_ratio
68 |
69 | def augmentation(self, images, transform, state=None):
70 | if state is not None:
71 | torch.set_rng_state(state)
72 | if isinstance(images, List):
73 | transformed_images = [transform(img) for img in images]
74 | ret_tensor = torch.stack(transformed_images, dim=0) # (f, c, h, w)
75 | else:
76 | ret_tensor = transform(images) # (c, h, w)
77 | return ret_tensor
78 |
79 | def __getitem__(self, index):
80 | video_meta = self.vid_meta[index]
81 | video_path = video_meta["video_path"]
82 | kps_path = video_meta["kps_path"]
83 |
84 | video_reader = VideoReader(video_path)
85 | kps_reader = VideoReader(kps_path)
86 |
87 | assert len(video_reader) == len(
88 | kps_reader
89 | ), f"{len(video_reader) = } != {len(kps_reader) = } in {video_path}"
90 |
91 | video_length = len(video_reader)
92 |
93 | clip_length = min(
94 | video_length, (self.n_sample_frames - 1) * self.sample_rate + 1
95 | )
96 | start_idx = random.randint(0, video_length - clip_length)
97 | batch_index = np.linspace(
98 | start_idx, start_idx + clip_length - 1, self.n_sample_frames, dtype=int
99 | ).tolist()
100 |
101 | # read frames and kps
102 | vid_pil_image_list = []
103 | pose_pil_image_list = []
104 | for index in batch_index:
105 | img = video_reader[index]
106 | vid_pil_image_list.append(Image.fromarray(img.asnumpy()))
107 | img = kps_reader[index]
108 | pose_pil_image_list.append(Image.fromarray(img.asnumpy()))
109 |
110 | ref_img_idx = random.randint(0, video_length - 1)
111 | ref_img = Image.fromarray(video_reader[ref_img_idx].asnumpy())
112 |
113 | # transform
114 | state = torch.get_rng_state()
115 | pixel_values_vid = self.augmentation(
116 | vid_pil_image_list, self.pixel_transform, state
117 | )
118 | pixel_values_pose = self.augmentation(
119 | pose_pil_image_list, self.cond_transform, state
120 | )
121 | pixel_values_ref_img = self.augmentation(ref_img, self.pixel_transform, state)
122 | clip_ref_img = self.clip_image_processor(
123 | images=ref_img, return_tensors="pt"
124 | ).pixel_values[0]
125 |
126 | sample = dict(
127 | video_dir=video_path,
128 | pixel_values_vid=pixel_values_vid,
129 | pixel_values_pose=pixel_values_pose,
130 | pixel_values_ref_img=pixel_values_ref_img,
131 | clip_ref_img=clip_ref_img,
132 | )
133 |
134 | return sample
135 |
136 | def __len__(self):
137 | return len(self.vid_meta)
138 |
--------------------------------------------------------------------------------
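
The distinctive part of the video dataset above is the clip sampler in __getitem__: it draws n_sample_frames frames evenly from a window of (n_sample_frames - 1) * sample_rate + 1 consecutive frames. A standalone sketch of that arithmetic, not part of the repo, with assumed example values:

import random

import numpy as np

video_length, n_sample_frames, sample_rate = 300, 24, 4
clip_length = min(video_length, (n_sample_frames - 1) * sample_rate + 1)   # 93 frames
start_idx = random.randint(0, video_length - clip_length)
batch_index = np.linspace(
    start_idx, start_idx + clip_length - 1, n_sample_frames, dtype=int
).tolist()
print(len(batch_index), batch_index[:4])   # 24 indices, roughly sample_rate apart
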
/src/dwpose/__init__.py:
--------------------------------------------------------------------------------
1 | # https://github.com/IDEA-Research/DWPose
2 | # Openpose
3 | # Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose
4 | # 2nd Edited by https://github.com/Hzzone/pytorch-openpose
5 | # 3rd Edited by ControlNet
6 | # 4th Edited by ControlNet (added face and correct hands)
7 |
8 | import copy
9 | import os
10 |
11 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
12 | import cv2
13 | import numpy as np
14 | import torch
15 | from controlnet_aux.util import HWC3, resize_image
16 | from PIL import Image
17 |
18 | from . import util
19 | from .wholebody import Wholebody
20 |
21 |
22 | def draw_pose(pose, H, W):
23 | bodies = pose["bodies"]
24 | faces = pose["faces"]
25 | hands = pose["hands"]
26 | candidate = bodies["candidate"]
27 | subset = bodies["subset"]
28 | canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
29 |
30 | canvas = util.draw_bodypose(canvas, candidate, subset)
31 |
32 | canvas = util.draw_handpose(canvas, hands)
33 |
34 | canvas = util.draw_facepose(canvas, faces)
35 |
36 | return canvas
37 |
38 |
39 | class DWposeDetector:
40 | def __init__(self):
41 | pass
42 |
43 | def to(self, device):
44 | self.pose_estimation = Wholebody(device)
45 | return self
46 |
47 | def cal_height(self, input_image):
48 | input_image = cv2.cvtColor(
49 | np.array(input_image, dtype=np.uint8), cv2.COLOR_RGB2BGR
50 | )
51 |
52 | input_image = HWC3(input_image)
53 | H, W, C = input_image.shape
54 | with torch.no_grad():
55 | candidate, subset = self.pose_estimation(input_image)
56 | nums, keys, locs = candidate.shape
57 | # candidate[..., 0] /= float(W)
58 | # candidate[..., 1] /= float(H)
59 | body = candidate
60 | return body[0, ..., 1].min(), body[..., 1].max() - body[..., 1].min()
61 |
62 | def __call__(
63 | self,
64 | input_image,
65 | detect_resolution=512,
66 | image_resolution=512,
67 | output_type="pil",
68 | **kwargs,
69 | ):
70 | input_image = cv2.cvtColor(
71 | np.array(input_image, dtype=np.uint8), cv2.COLOR_RGB2BGR
72 | )
73 |
74 | input_image = HWC3(input_image)
75 | input_image = resize_image(input_image, detect_resolution)
76 | H, W, C = input_image.shape
77 | with torch.no_grad():
78 | candidate, subset = self.pose_estimation(input_image)
79 | nums, keys, locs = candidate.shape
80 | candidate[..., 0] /= float(W)
81 | candidate[..., 1] /= float(H)
82 | score = subset[:, :18]
83 | max_ind = np.mean(score, axis=-1).argmax(axis=0)
84 | score = score[[max_ind]]
85 | body = candidate[:, :18].copy()
86 | body = body[[max_ind]]
87 | nums = 1
88 | body = body.reshape(nums * 18, locs)
89 | body_score = copy.deepcopy(score)
90 | for i in range(len(score)):
91 | for j in range(len(score[i])):
92 | if score[i][j] > 0.3:
93 | score[i][j] = int(18 * i + j)
94 | else:
95 | score[i][j] = -1
96 |
97 | un_visible = subset < 0.3
98 | candidate[un_visible] = -1
99 |
100 | foot = candidate[:, 18:24]
101 |
102 | faces = candidate[[max_ind], 24:92]
103 |
104 | hands = candidate[[max_ind], 92:113]
105 | hands = np.vstack([hands, candidate[[max_ind], 113:]])
106 |
107 | bodies = dict(candidate=body, subset=score)
108 | pose = dict(bodies=bodies, hands=hands, faces=faces)
109 |
110 | detected_map = draw_pose(pose, H, W)
111 | detected_map = HWC3(detected_map)
112 |
113 | img = resize_image(input_image, image_resolution)
114 | H, W, C = img.shape
115 |
116 | detected_map = cv2.resize(
117 | detected_map, (W, H), interpolation=cv2.INTER_LINEAR
118 | )
119 |
120 | if output_type == "pil":
121 | detected_map = Image.fromarray(detected_map)
122 |
123 | return detected_map, body_score
124 |
--------------------------------------------------------------------------------
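
A minimal usage sketch for the detector above. It assumes the DWPose ONNX weights have already been downloaded to ./pretrained_weights/DWPose/ (the paths hard-coded in wholebody.py below) and that some_frame.png is a placeholder for any RGB image of a person:

from PIL import Image

from src.dwpose import DWposeDetector

detector = DWposeDetector().to("cuda")   # .to() builds the ONNX Runtime sessions
image = Image.open("some_frame.png").convert("RGB")
pose_map, body_score = detector(image, detect_resolution=512, image_resolution=512)
pose_map.save("pose_map.png")            # the skeleton rendering used as the pose condition
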
/src/dwpose/onnxdet.py:
--------------------------------------------------------------------------------
1 | # https://github.com/IDEA-Research/DWPose
2 | import cv2
3 | import numpy as np
4 | import onnxruntime
5 |
6 |
7 | def nms(boxes, scores, nms_thr):
8 | """Single class NMS implemented in Numpy."""
9 | x1 = boxes[:, 0]
10 | y1 = boxes[:, 1]
11 | x2 = boxes[:, 2]
12 | y2 = boxes[:, 3]
13 |
14 | areas = (x2 - x1 + 1) * (y2 - y1 + 1)
15 | order = scores.argsort()[::-1]
16 |
17 | keep = []
18 | while order.size > 0:
19 | i = order[0]
20 | keep.append(i)
21 | xx1 = np.maximum(x1[i], x1[order[1:]])
22 | yy1 = np.maximum(y1[i], y1[order[1:]])
23 | xx2 = np.minimum(x2[i], x2[order[1:]])
24 | yy2 = np.minimum(y2[i], y2[order[1:]])
25 |
26 | w = np.maximum(0.0, xx2 - xx1 + 1)
27 | h = np.maximum(0.0, yy2 - yy1 + 1)
28 | inter = w * h
29 | ovr = inter / (areas[i] + areas[order[1:]] - inter)
30 |
31 | inds = np.where(ovr <= nms_thr)[0]
32 | order = order[inds + 1]
33 |
34 | return keep
35 |
36 |
37 | def multiclass_nms(boxes, scores, nms_thr, score_thr):
38 | """Multiclass NMS implemented in Numpy. Class-aware version."""
39 | final_dets = []
40 | num_classes = scores.shape[1]
41 | for cls_ind in range(num_classes):
42 | cls_scores = scores[:, cls_ind]
43 | valid_score_mask = cls_scores > score_thr
44 | if valid_score_mask.sum() == 0:
45 | continue
46 | else:
47 | valid_scores = cls_scores[valid_score_mask]
48 | valid_boxes = boxes[valid_score_mask]
49 | keep = nms(valid_boxes, valid_scores, nms_thr)
50 | if len(keep) > 0:
51 | cls_inds = np.ones((len(keep), 1)) * cls_ind
52 | dets = np.concatenate(
53 | [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
54 | )
55 | final_dets.append(dets)
56 | if len(final_dets) == 0:
57 | return None
58 | return np.concatenate(final_dets, 0)
59 |
60 |
61 | def demo_postprocess(outputs, img_size, p6=False):
62 | grids = []
63 | expanded_strides = []
64 | strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]
65 |
66 | hsizes = [img_size[0] // stride for stride in strides]
67 | wsizes = [img_size[1] // stride for stride in strides]
68 |
69 | for hsize, wsize, stride in zip(hsizes, wsizes, strides):
70 | xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
71 | grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
72 | grids.append(grid)
73 | shape = grid.shape[:2]
74 | expanded_strides.append(np.full((*shape, 1), stride))
75 |
76 | grids = np.concatenate(grids, 1)
77 | expanded_strides = np.concatenate(expanded_strides, 1)
78 | outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
79 | outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
80 |
81 | return outputs
82 |
83 |
84 | def preprocess(img, input_size, swap=(2, 0, 1)):
85 | if len(img.shape) == 3:
86 | padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
87 | else:
88 | padded_img = np.ones(input_size, dtype=np.uint8) * 114
89 |
90 | r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
91 | resized_img = cv2.resize(
92 | img,
93 | (int(img.shape[1] * r), int(img.shape[0] * r)),
94 | interpolation=cv2.INTER_LINEAR,
95 | ).astype(np.uint8)
96 | padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
97 |
98 | padded_img = padded_img.transpose(swap)
99 | padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
100 | return padded_img, r
101 |
102 |
103 | def inference_detector(session, oriImg):
104 | input_shape = (640, 640)
105 | img, ratio = preprocess(oriImg, input_shape)
106 |
107 | ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
108 | output = session.run(None, ort_inputs)
109 | predictions = demo_postprocess(output[0], input_shape)[0]
110 |
111 | boxes = predictions[:, :4]
112 | scores = predictions[:, 4:5] * predictions[:, 5:]
113 |
114 | boxes_xyxy = np.ones_like(boxes)
115 | boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2.0
116 | boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2.0
117 | boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2.0
118 | boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2.0
119 | boxes_xyxy /= ratio
120 | dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
121 | if dets is not None:
122 | final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
123 | isscore = final_scores > 0.3
124 | iscat = final_cls_inds == 0
125 | isbbox = [i and j for (i, j) in zip(isscore, iscat)]
126 | final_boxes = final_boxes[isbbox]
127 | else:
128 | return []
129 |
130 | return final_boxes
131 |
--------------------------------------------------------------------------------
/src/dwpose/onnxpose.py:
--------------------------------------------------------------------------------
1 | # https://github.com/IDEA-Research/DWPose
2 | from typing import List, Tuple
3 |
4 | import cv2
5 | import numpy as np
6 | import onnxruntime as ort
7 |
8 |
9 | def preprocess(
10 | img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256)
11 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
12 | """Do preprocessing for RTMPose model inference.
13 |
14 | Args:
15 |         img (np.ndarray): Input image in shape (H, W, C).
16 | input_size (tuple): Input image size in shape (w, h).
17 |
18 | Returns:
19 | tuple:
20 | - resized_img (np.ndarray): Preprocessed image.
21 | - center (np.ndarray): Center of image.
22 | - scale (np.ndarray): Scale of image.
23 | """
24 | # get shape of image
25 | img_shape = img.shape[:2]
26 | out_img, out_center, out_scale = [], [], []
27 | if len(out_bbox) == 0:
28 | out_bbox = [[0, 0, img_shape[1], img_shape[0]]]
29 | for i in range(len(out_bbox)):
30 | x0 = out_bbox[i][0]
31 | y0 = out_bbox[i][1]
32 | x1 = out_bbox[i][2]
33 | y1 = out_bbox[i][3]
34 | bbox = np.array([x0, y0, x1, y1])
35 |
36 | # get center and scale
37 | center, scale = bbox_xyxy2cs(bbox, padding=1.25)
38 |
39 | # do affine transformation
40 | resized_img, scale = top_down_affine(input_size, scale, center, img)
41 |
42 | # normalize image
43 | mean = np.array([123.675, 116.28, 103.53])
44 | std = np.array([58.395, 57.12, 57.375])
45 | resized_img = (resized_img - mean) / std
46 |
47 | out_img.append(resized_img)
48 | out_center.append(center)
49 | out_scale.append(scale)
50 |
51 | return out_img, out_center, out_scale
52 |
53 |
54 | def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray:
55 | """Inference RTMPose model.
56 |
57 | Args:
58 | sess (ort.InferenceSession): ONNXRuntime session.
59 |         img (list[np.ndarray]): Preprocessed image crops from preprocess().
60 |
61 | Returns:
62 | outputs (np.ndarray): Output of RTMPose model.
63 | """
64 | all_out = []
65 | # build input
66 | for i in range(len(img)):
67 | input = [img[i].transpose(2, 0, 1)]
68 |
69 | # build output
70 | sess_input = {sess.get_inputs()[0].name: input}
71 | sess_output = []
72 | for out in sess.get_outputs():
73 | sess_output.append(out.name)
74 |
75 | # run model
76 | outputs = sess.run(sess_output, sess_input)
77 | all_out.append(outputs)
78 |
79 | return all_out
80 |
81 |
82 | def postprocess(
83 | outputs: List[np.ndarray],
84 | model_input_size: Tuple[int, int],
85 | center: Tuple[int, int],
86 | scale: Tuple[int, int],
87 | simcc_split_ratio: float = 2.0,
88 | ) -> Tuple[np.ndarray, np.ndarray]:
89 | """Postprocess for RTMPose model output.
90 |
91 | Args:
92 | outputs (np.ndarray): Output of RTMPose model.
93 | model_input_size (tuple): RTMPose model Input image size.
94 | center (tuple): Center of bbox in shape (x, y).
95 | scale (tuple): Scale of bbox in shape (w, h).
96 | simcc_split_ratio (float): Split ratio of simcc.
97 |
98 | Returns:
99 | tuple:
100 | - keypoints (np.ndarray): Rescaled keypoints.
101 | - scores (np.ndarray): Model predict scores.
102 | """
103 | all_key = []
104 | all_score = []
105 | for i in range(len(outputs)):
106 | # use simcc to decode
107 | simcc_x, simcc_y = outputs[i]
108 | keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio)
109 |
110 | # rescale keypoints
111 | keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2
112 | all_key.append(keypoints[0])
113 | all_score.append(scores[0])
114 |
115 | return np.array(all_key), np.array(all_score)
116 |
117 |
118 | def bbox_xyxy2cs(
119 | bbox: np.ndarray, padding: float = 1.0
120 | ) -> Tuple[np.ndarray, np.ndarray]:
121 | """Transform the bbox format from (x,y,w,h) into (center, scale)
122 |
123 | Args:
124 | bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
125 | as (left, top, right, bottom)
126 |         padding (float): BBox padding factor that will be multiplied to the scale.
127 | Default: 1.0
128 |
129 | Returns:
130 | tuple: A tuple containing center and scale.
131 | - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
132 | (n, 2)
133 | - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
134 | (n, 2)
135 | """
136 | # convert single bbox from (4, ) to (1, 4)
137 | dim = bbox.ndim
138 | if dim == 1:
139 | bbox = bbox[None, :]
140 |
141 | # get bbox center and scale
142 | x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3])
143 | center = np.hstack([x1 + x2, y1 + y2]) * 0.5
144 | scale = np.hstack([x2 - x1, y2 - y1]) * padding
145 |
146 | if dim == 1:
147 | center = center[0]
148 | scale = scale[0]
149 |
150 | return center, scale
151 |
152 |
153 | def _fix_aspect_ratio(bbox_scale: np.ndarray, aspect_ratio: float) -> np.ndarray:
154 | """Extend the scale to match the given aspect ratio.
155 |
156 | Args:
157 |         bbox_scale (np.ndarray): The bbox scale (w, h) in shape (2, )
158 | aspect_ratio (float): The ratio of ``w/h``
159 |
160 | Returns:
161 | np.ndarray: The reshaped image scale in (2, )
162 | """
163 | w, h = np.hsplit(bbox_scale, [1])
164 | bbox_scale = np.where(
165 | w > h * aspect_ratio,
166 | np.hstack([w, w / aspect_ratio]),
167 | np.hstack([h * aspect_ratio, h]),
168 | )
169 | return bbox_scale
170 |
171 |
172 | def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
173 | """Rotate a point by an angle.
174 |
175 | Args:
176 | pt (np.ndarray): 2D point coordinates (x, y) in shape (2, )
177 | angle_rad (float): rotation angle in radian
178 |
179 | Returns:
180 | np.ndarray: Rotated point in shape (2, )
181 | """
182 | sn, cs = np.sin(angle_rad), np.cos(angle_rad)
183 | rot_mat = np.array([[cs, -sn], [sn, cs]])
184 | return rot_mat @ pt
185 |
186 |
187 | def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray:
188 | """To calculate the affine matrix, three pairs of points are required. This
189 | function is used to get the 3rd point, given 2D points a & b.
190 |
191 | The 3rd point is defined by rotating vector `a - b` by 90 degrees
192 | anticlockwise, using b as the rotation center.
193 |
194 | Args:
195 | a (np.ndarray): The 1st point (x,y) in shape (2, )
196 | b (np.ndarray): The 2nd point (x,y) in shape (2, )
197 |
198 | Returns:
199 | np.ndarray: The 3rd point.
200 | """
201 | direction = a - b
202 | c = b + np.r_[-direction[1], direction[0]]
203 | return c
204 |
205 |
206 | def get_warp_matrix(
207 | center: np.ndarray,
208 | scale: np.ndarray,
209 | rot: float,
210 | output_size: Tuple[int, int],
211 | shift: Tuple[float, float] = (0.0, 0.0),
212 | inv: bool = False,
213 | ) -> np.ndarray:
214 | """Calculate the affine transformation matrix that can warp the bbox area
215 | in the input image to the output size.
216 |
217 | Args:
218 | center (np.ndarray[2, ]): Center of the bounding box (x, y).
219 | scale (np.ndarray[2, ]): Scale of the bounding box
220 | wrt [width, height].
221 | rot (float): Rotation angle (degree).
222 | output_size (np.ndarray[2, ] | list(2,)): Size of the
223 | destination heatmaps.
224 | shift (0-100%): Shift translation ratio wrt the width/height.
225 | Default (0., 0.).
226 | inv (bool): Option to inverse the affine transform direction.
227 | (inv=False: src->dst or inv=True: dst->src)
228 |
229 | Returns:
230 | np.ndarray: A 2x3 transformation matrix
231 | """
232 | shift = np.array(shift)
233 | src_w = scale[0]
234 | dst_w = output_size[0]
235 | dst_h = output_size[1]
236 |
237 | # compute transformation matrix
238 | rot_rad = np.deg2rad(rot)
239 | src_dir = _rotate_point(np.array([0.0, src_w * -0.5]), rot_rad)
240 | dst_dir = np.array([0.0, dst_w * -0.5])
241 |
242 |     # get three reference points of the src rectangle in the original image
243 | src = np.zeros((3, 2), dtype=np.float32)
244 | src[0, :] = center + scale * shift
245 | src[1, :] = center + src_dir + scale * shift
246 | src[2, :] = _get_3rd_point(src[0, :], src[1, :])
247 |
248 |     # get three reference points of the dst rectangle in the input image
249 | dst = np.zeros((3, 2), dtype=np.float32)
250 | dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
251 | dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
252 | dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
253 |
254 | if inv:
255 | warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src))
256 | else:
257 | warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst))
258 |
259 | return warp_mat
260 |
261 |
262 | def top_down_affine(
263 |     input_size: Tuple[int, int], bbox_scale: np.ndarray, bbox_center: np.ndarray, img: np.ndarray
264 | ) -> Tuple[np.ndarray, np.ndarray]:
265 | """Get the bbox image as the model input by affine transform.
266 |
267 | Args:
268 |         input_size (tuple): The (w, h) input size of the model.
269 |         bbox_scale (np.ndarray): The bbox scale of the img.
270 |         bbox_center (np.ndarray): The bbox center of the img.
271 | img (np.ndarray): The original image.
272 |
273 | Returns:
274 |         tuple: A tuple containing the cropped image and the adjusted bbox scale.
275 | - np.ndarray[float32]: img after affine transform.
276 | - np.ndarray[float32]: bbox scale after affine transform.
277 | """
278 | w, h = input_size
279 | warp_size = (int(w), int(h))
280 |
281 | # reshape bbox to fixed aspect ratio
282 | bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h)
283 |
284 | # get the affine matrix
285 | center = bbox_center
286 | scale = bbox_scale
287 | rot = 0
288 | warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h))
289 |
290 | # do affine transform
291 | img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR)
292 |
293 | return img, bbox_scale
294 |
295 |
296 | def get_simcc_maximum(
297 | simcc_x: np.ndarray, simcc_y: np.ndarray
298 | ) -> Tuple[np.ndarray, np.ndarray]:
299 | """Get maximum response location and value from simcc representations.
300 |
301 | Note:
302 | instance number: N
303 | num_keypoints: K
304 | heatmap height: H
305 | heatmap width: W
306 |
307 | Args:
308 | simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
309 | simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy)
310 |
311 | Returns:
312 | tuple:
313 | - locs (np.ndarray): locations of maximum heatmap responses in shape
314 | (K, 2) or (N, K, 2)
315 | - vals (np.ndarray): values of maximum heatmap responses in shape
316 | (K,) or (N, K)
317 | """
318 | N, K, Wx = simcc_x.shape
319 | simcc_x = simcc_x.reshape(N * K, -1)
320 | simcc_y = simcc_y.reshape(N * K, -1)
321 |
322 | # get maximum value locations
323 | x_locs = np.argmax(simcc_x, axis=1)
324 | y_locs = np.argmax(simcc_y, axis=1)
325 | locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
326 | max_val_x = np.amax(simcc_x, axis=1)
327 | max_val_y = np.amax(simcc_y, axis=1)
328 |
329 | # get maximum value across x and y axis
330 | mask = max_val_x > max_val_y
331 | max_val_x[mask] = max_val_y[mask]
332 | vals = max_val_x
333 | locs[vals <= 0.0] = -1
334 |
335 | # reshape
336 | locs = locs.reshape(N, K, 2)
337 | vals = vals.reshape(N, K)
338 |
339 | return locs, vals
340 |
341 |
342 | def decode(
343 | simcc_x: np.ndarray, simcc_y: np.ndarray, simcc_split_ratio
344 | ) -> Tuple[np.ndarray, np.ndarray]:
345 | """Modulate simcc distribution with Gaussian.
346 |
347 | Args:
348 | simcc_x (np.ndarray[K, Wx]): model predicted simcc in x.
349 | simcc_y (np.ndarray[K, Wy]): model predicted simcc in y.
350 | simcc_split_ratio (int): The split ratio of simcc.
351 |
352 | Returns:
353 |         tuple: A tuple containing keypoints and scores.
354 | - np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2)
355 | - np.ndarray[float32]: scores in shape (K,) or (n, K)
356 | """
357 | keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)
358 | keypoints /= simcc_split_ratio
359 |
360 | return keypoints, scores
361 |
362 |
363 | def inference_pose(session, out_bbox, oriImg):
364 | h, w = session.get_inputs()[0].shape[2:]
365 | model_input_size = (w, h)
366 | resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size)
367 | outputs = inference(session, resized_img)
368 | keypoints, scores = postprocess(outputs, model_input_size, center, scale)
369 |
370 | return keypoints, scores
371 |
--------------------------------------------------------------------------------
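
A toy sketch, not part of the repo, of the box-to-crop geometry above: a detector box is converted to a (center, scale) pair with 1.25x padding, then the scale is widened (or heightened) to the model's 192x256 aspect ratio before the affine crop:

import numpy as np

from src.dwpose.onnxpose import _fix_aspect_ratio, bbox_xyxy2cs

bbox = np.array([100.0, 50.0, 300.0, 450.0])        # (x1, y1, x2, y2)
center, scale = bbox_xyxy2cs(bbox, padding=1.25)    # center = (200, 250), scale = (250, 500)
scale = _fix_aspect_ratio(scale, aspect_ratio=192 / 256)
print(center, scale)                                # width widened to 500 * 192 / 256 = 375
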
/src/dwpose/util.py:
--------------------------------------------------------------------------------
1 | # https://github.com/IDEA-Research/DWPose
2 | import math
3 | import numpy as np
4 | import matplotlib
5 | import cv2
6 |
7 |
8 | eps = 0.01
9 |
10 |
11 | def smart_resize(x, s):
12 | Ht, Wt = s
13 | if x.ndim == 2:
14 | Ho, Wo = x.shape
15 | Co = 1
16 | else:
17 | Ho, Wo, Co = x.shape
18 | if Co == 3 or Co == 1:
19 | k = float(Ht + Wt) / float(Ho + Wo)
20 | return cv2.resize(
21 | x,
22 | (int(Wt), int(Ht)),
23 | interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4,
24 | )
25 | else:
26 | return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2)
27 |
28 |
29 | def smart_resize_k(x, fx, fy):
30 | if x.ndim == 2:
31 | Ho, Wo = x.shape
32 | Co = 1
33 | else:
34 | Ho, Wo, Co = x.shape
35 | Ht, Wt = Ho * fy, Wo * fx
36 | if Co == 3 or Co == 1:
37 | k = float(Ht + Wt) / float(Ho + Wo)
38 | return cv2.resize(
39 | x,
40 | (int(Wt), int(Ht)),
41 | interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4,
42 | )
43 | else:
44 | return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2)
45 |
46 |
47 | def padRightDownCorner(img, stride, padValue):
48 | h = img.shape[0]
49 | w = img.shape[1]
50 |
51 | pad = 4 * [None]
52 | pad[0] = 0 # up
53 | pad[1] = 0 # left
54 | pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
55 | pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
56 |
57 | img_padded = img
58 | pad_up = np.tile(img_padded[0:1, :, :] * 0 + padValue, (pad[0], 1, 1))
59 | img_padded = np.concatenate((pad_up, img_padded), axis=0)
60 | pad_left = np.tile(img_padded[:, 0:1, :] * 0 + padValue, (1, pad[1], 1))
61 | img_padded = np.concatenate((pad_left, img_padded), axis=1)
62 | pad_down = np.tile(img_padded[-2:-1, :, :] * 0 + padValue, (pad[2], 1, 1))
63 | img_padded = np.concatenate((img_padded, pad_down), axis=0)
64 | pad_right = np.tile(img_padded[:, -2:-1, :] * 0 + padValue, (1, pad[3], 1))
65 | img_padded = np.concatenate((img_padded, pad_right), axis=1)
66 |
67 | return img_padded, pad
68 |
69 |
70 | def transfer(model, model_weights):
71 |     transferred_model_weights = {}
72 |     for weights_name in model.state_dict().keys():
73 |         transferred_model_weights[weights_name] = model_weights[
74 |             ".".join(weights_name.split(".")[1:])
75 |         ]
76 |     return transferred_model_weights
77 |
78 |
79 | def draw_bodypose(canvas, candidate, subset):
80 | H, W, C = canvas.shape
81 | candidate = np.array(candidate)
82 | subset = np.array(subset)
83 |
84 | stickwidth = 4
85 |
86 | limbSeq = [
87 | [2, 3],
88 | [2, 6],
89 | [3, 4],
90 | [4, 5],
91 | [6, 7],
92 | [7, 8],
93 | [2, 9],
94 | [9, 10],
95 | [10, 11],
96 | [2, 12],
97 | [12, 13],
98 | [13, 14],
99 | [2, 1],
100 | [1, 15],
101 | [15, 17],
102 | [1, 16],
103 | [16, 18],
104 | [3, 17],
105 | [6, 18],
106 | ]
107 |
108 | colors = [
109 | [255, 0, 0],
110 | [255, 85, 0],
111 | [255, 170, 0],
112 | [255, 255, 0],
113 | [170, 255, 0],
114 | [85, 255, 0],
115 | [0, 255, 0],
116 | [0, 255, 85],
117 | [0, 255, 170],
118 | [0, 255, 255],
119 | [0, 170, 255],
120 | [0, 85, 255],
121 | [0, 0, 255],
122 | [85, 0, 255],
123 | [170, 0, 255],
124 | [255, 0, 255],
125 | [255, 0, 170],
126 | [255, 0, 85],
127 | ]
128 |
129 | for i in range(17):
130 | for n in range(len(subset)):
131 | index = subset[n][np.array(limbSeq[i]) - 1]
132 | if -1 in index:
133 | continue
134 | Y = candidate[index.astype(int), 0] * float(W)
135 | X = candidate[index.astype(int), 1] * float(H)
136 | mX = np.mean(X)
137 | mY = np.mean(Y)
138 | length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
139 | angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
140 | polygon = cv2.ellipse2Poly(
141 | (int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1
142 | )
143 | cv2.fillConvexPoly(canvas, polygon, colors[i])
144 |
145 | canvas = (canvas * 0.6).astype(np.uint8)
146 |
147 | for i in range(18):
148 | for n in range(len(subset)):
149 | index = int(subset[n][i])
150 | if index == -1:
151 | continue
152 | x, y = candidate[index][0:2]
153 | x = int(x * W)
154 | y = int(y * H)
155 | cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
156 |
157 | return canvas
158 |
159 |
160 | def draw_handpose(canvas, all_hand_peaks):
161 | H, W, C = canvas.shape
162 |
163 | edges = [
164 | [0, 1],
165 | [1, 2],
166 | [2, 3],
167 | [3, 4],
168 | [0, 5],
169 | [5, 6],
170 | [6, 7],
171 | [7, 8],
172 | [0, 9],
173 | [9, 10],
174 | [10, 11],
175 | [11, 12],
176 | [0, 13],
177 | [13, 14],
178 | [14, 15],
179 | [15, 16],
180 | [0, 17],
181 | [17, 18],
182 | [18, 19],
183 | [19, 20],
184 | ]
185 |
186 | for peaks in all_hand_peaks:
187 | peaks = np.array(peaks)
188 |
189 | for ie, e in enumerate(edges):
190 | x1, y1 = peaks[e[0]]
191 | x2, y2 = peaks[e[1]]
192 | x1 = int(x1 * W)
193 | y1 = int(y1 * H)
194 | x2 = int(x2 * W)
195 | y2 = int(y2 * H)
196 | if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
197 | cv2.line(
198 | canvas,
199 | (x1, y1),
200 | (x2, y2),
201 | matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0])
202 | * 255,
203 | thickness=2,
204 | )
205 |
206 |         for i, keypoint in enumerate(peaks):
207 |             x, y = keypoint
208 | x = int(x * W)
209 | y = int(y * H)
210 | if x > eps and y > eps:
211 | cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
212 | return canvas
213 |
214 |
215 | def draw_facepose(canvas, all_lmks):
216 | H, W, C = canvas.shape
217 | for lmks in all_lmks:
218 | lmks = np.array(lmks)
219 | for lmk in lmks:
220 | x, y = lmk
221 | x = int(x * W)
222 | y = int(y * H)
223 | if x > eps and y > eps:
224 | cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1)
225 | return canvas
226 |
227 |
228 | # detect hand according to body pose keypoints
229 | # please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
230 | def handDetect(candidate, subset, oriImg):
231 | # right hand: wrist 4, elbow 3, shoulder 2
232 | # left hand: wrist 7, elbow 6, shoulder 5
233 | ratioWristElbow = 0.33
234 | detect_result = []
235 | image_height, image_width = oriImg.shape[0:2]
236 | for person in subset.astype(int):
237 | # if any of three not detected
238 | has_left = np.sum(person[[5, 6, 7]] == -1) == 0
239 | has_right = np.sum(person[[2, 3, 4]] == -1) == 0
240 | if not (has_left or has_right):
241 | continue
242 | hands = []
243 | # left hand
244 | if has_left:
245 | left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]]
246 | x1, y1 = candidate[left_shoulder_index][:2]
247 | x2, y2 = candidate[left_elbow_index][:2]
248 | x3, y3 = candidate[left_wrist_index][:2]
249 | hands.append([x1, y1, x2, y2, x3, y3, True])
250 | # right hand
251 | if has_right:
252 | right_shoulder_index, right_elbow_index, right_wrist_index = person[
253 | [2, 3, 4]
254 | ]
255 | x1, y1 = candidate[right_shoulder_index][:2]
256 | x2, y2 = candidate[right_elbow_index][:2]
257 | x3, y3 = candidate[right_wrist_index][:2]
258 | hands.append([x1, y1, x2, y2, x3, y3, False])
259 |
260 | for x1, y1, x2, y2, x3, y3, is_left in hands:
261 | # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox
262 | # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
263 | # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
264 | # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
265 | # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
266 | # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
267 | x = x3 + ratioWristElbow * (x3 - x2)
268 | y = y3 + ratioWristElbow * (y3 - y2)
269 | distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
270 | distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
271 | width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
272 | # x-y refers to the center --> offset to topLeft point
273 | # handRectangle.x -= handRectangle.width / 2.f;
274 | # handRectangle.y -= handRectangle.height / 2.f;
275 | x -= width / 2
276 | y -= width / 2 # width = height
277 | # overflow the image
278 | if x < 0:
279 | x = 0
280 | if y < 0:
281 | y = 0
282 | width1 = width
283 | width2 = width
284 | if x + width > image_width:
285 | width1 = image_width - x
286 | if y + width > image_height:
287 | width2 = image_height - y
288 | width = min(width1, width2)
289 |             # discard hand boxes smaller than 20 pixels
290 | if width >= 20:
291 | detect_result.append([int(x), int(y), int(width), is_left])
292 |
293 | """
294 | return value: [[x, y, w, True if left hand else False]].
295 | width=height since the network require squared input.
296 | x, y is the coordinate of top left
297 | """
298 | return detect_result
299 |
300 |
301 | # Written by Lvmin
302 | def faceDetect(candidate, subset, oriImg):
303 | # left right eye ear 14 15 16 17
304 | detect_result = []
305 | image_height, image_width = oriImg.shape[0:2]
306 | for person in subset.astype(int):
307 | has_head = person[0] > -1
308 | if not has_head:
309 | continue
310 |
311 | has_left_eye = person[14] > -1
312 | has_right_eye = person[15] > -1
313 | has_left_ear = person[16] > -1
314 | has_right_ear = person[17] > -1
315 |
316 | if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear):
317 | continue
318 |
319 | head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]]
320 |
321 | width = 0.0
322 | x0, y0 = candidate[head][:2]
323 |
324 | if has_left_eye:
325 | x1, y1 = candidate[left_eye][:2]
326 | d = max(abs(x0 - x1), abs(y0 - y1))
327 | width = max(width, d * 3.0)
328 |
329 | if has_right_eye:
330 | x1, y1 = candidate[right_eye][:2]
331 | d = max(abs(x0 - x1), abs(y0 - y1))
332 | width = max(width, d * 3.0)
333 |
334 | if has_left_ear:
335 | x1, y1 = candidate[left_ear][:2]
336 | d = max(abs(x0 - x1), abs(y0 - y1))
337 | width = max(width, d * 1.5)
338 |
339 | if has_right_ear:
340 | x1, y1 = candidate[right_ear][:2]
341 | d = max(abs(x0 - x1), abs(y0 - y1))
342 | width = max(width, d * 1.5)
343 |
344 | x, y = x0, y0
345 |
346 | x -= width
347 | y -= width
348 |
349 | if x < 0:
350 | x = 0
351 |
352 | if y < 0:
353 | y = 0
354 |
355 | width1 = width * 2
356 | width2 = width * 2
357 |
358 | if x + width > image_width:
359 | width1 = image_width - x
360 |
361 | if y + width > image_height:
362 | width2 = image_height - y
363 |
364 | width = min(width1, width2)
365 |
366 | if width >= 20:
367 | detect_result.append([int(x), int(y), int(width)])
368 |
369 | return detect_result
370 |
371 |
372 | # get max index of 2d array
373 | def npmax(array):
374 | arrayindex = array.argmax(1)
375 | arrayvalue = array.max(1)
376 | i = arrayvalue.argmax()
377 | j = arrayindex[i]
378 | return i, j
379 |
--------------------------------------------------------------------------------
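
A toy sketch, not part of the repo, showing the convention the draw_* helpers above rely on: keypoints are normalized to [0, 1], scaled by the canvas size internally, and points that map to non-positive pixel coordinates (e.g. the -1 placeholders for missing detections) are skipped:

import numpy as np

from src.dwpose.util import draw_facepose

canvas = np.zeros((512, 512, 3), dtype=np.uint8)
face_lmks = np.array([[[0.4, 0.4], [0.6, 0.4], [0.5, 0.6], [-1.0, -1.0]]])  # last point "missing"
canvas = draw_facepose(canvas, face_lmks)
print(canvas.sum() > 0)   # three white dots were drawn
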
/src/dwpose/wholebody.py:
--------------------------------------------------------------------------------
1 | # https://github.com/IDEA-Research/DWPose
2 | from pathlib import Path
3 |
4 | import cv2
5 | import numpy as np
6 | import onnxruntime as ort
7 |
8 | from .onnxdet import inference_detector
9 | from .onnxpose import inference_pose
10 |
11 | ModelDataPathPrefix = Path("./pretrained_weights")
12 |
13 |
14 | class Wholebody:
15 | def __init__(self, device="cuda:0"):
16 | providers = (
17 | ["CPUExecutionProvider"] if device == "cpu" else ["CUDAExecutionProvider"]
18 | )
19 | onnx_det = ModelDataPathPrefix.joinpath("DWPose/yolox_l.onnx")
20 | onnx_pose = ModelDataPathPrefix.joinpath("DWPose/dw-ll_ucoco_384.onnx")
21 |
22 | self.session_det = ort.InferenceSession(
23 | path_or_bytes=onnx_det, providers=providers
24 | )
25 | self.session_pose = ort.InferenceSession(
26 | path_or_bytes=onnx_pose, providers=providers
27 | )
28 |
29 | def __call__(self, oriImg):
30 | det_result = inference_detector(self.session_det, oriImg)
31 | keypoints, scores = inference_pose(self.session_pose, det_result, oriImg)
32 |
33 | keypoints_info = np.concatenate((keypoints, scores[..., None]), axis=-1)
34 | # compute neck joint
35 | neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
36 | # neck score when visualizing pred
37 | neck[:, 2:4] = np.logical_and(
38 | keypoints_info[:, 5, 2:4] > 0.3, keypoints_info[:, 6, 2:4] > 0.3
39 | ).astype(int)
40 | new_keypoints_info = np.insert(keypoints_info, 17, neck, axis=1)
41 | mmpose_idx = [17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3]
42 | openpose_idx = [1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17]
43 | new_keypoints_info[:, openpose_idx] = new_keypoints_info[:, mmpose_idx]
44 | keypoints_info = new_keypoints_info
45 |
46 | keypoints, scores = keypoints_info[..., :2], keypoints_info[..., 2]
47 |
48 | return keypoints, scores
49 |
--------------------------------------------------------------------------------
/src/models/motion_module.py:
--------------------------------------------------------------------------------
1 | # Adapt from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py
2 | import math
3 | from dataclasses import dataclass
4 | from typing import Callable, Optional
5 |
6 | import torch
7 | from diffusers.models.attention import FeedForward
8 | from diffusers.models.attention_processor import Attention, AttnProcessor
9 | from diffusers.utils import BaseOutput
10 | from diffusers.utils.import_utils import is_xformers_available
11 | from einops import rearrange, repeat
12 | from torch import nn
13 |
14 |
15 | def zero_module(module):
16 | # Zero out the parameters of a module and return it.
17 | for p in module.parameters():
18 | p.detach().zero_()
19 | return module
20 |
21 |
22 | @dataclass
23 | class TemporalTransformer3DModelOutput(BaseOutput):
24 | sample: torch.FloatTensor
25 |
26 |
27 | if is_xformers_available():
28 | import xformers
29 | import xformers.ops
30 | else:
31 | xformers = None
32 |
33 |
34 | def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict):
35 | if motion_module_type == "Vanilla":
36 | return VanillaTemporalModule(
37 | in_channels=in_channels,
38 | **motion_module_kwargs,
39 | )
40 | else:
41 | raise ValueError(f"Unknown motion_module_type: {motion_module_type}")
42 |
43 |
44 | class VanillaTemporalModule(nn.Module):
45 | def __init__(
46 | self,
47 | in_channels,
48 | num_attention_heads=8,
49 | num_transformer_block=2,
50 | attention_block_types=("Temporal_Self", "Temporal_Self"),
51 | cross_frame_attention_mode=None,
52 | temporal_position_encoding=False,
53 | temporal_position_encoding_max_len=24,
54 | temporal_attention_dim_div=1,
55 | zero_initialize=True,
56 | ):
57 | super().__init__()
58 |
59 | self.temporal_transformer = TemporalTransformer3DModel(
60 | in_channels=in_channels,
61 | num_attention_heads=num_attention_heads,
62 | attention_head_dim=in_channels
63 | // num_attention_heads
64 | // temporal_attention_dim_div,
65 | num_layers=num_transformer_block,
66 | attention_block_types=attention_block_types,
67 | cross_frame_attention_mode=cross_frame_attention_mode,
68 | temporal_position_encoding=temporal_position_encoding,
69 | temporal_position_encoding_max_len=temporal_position_encoding_max_len,
70 | )
71 |
72 | if zero_initialize:
73 | self.temporal_transformer.proj_out = zero_module(
74 | self.temporal_transformer.proj_out
75 | )
76 |
77 | def forward(
78 | self,
79 | input_tensor,
80 | temb,
81 | encoder_hidden_states,
82 | attention_mask=None,
83 | anchor_frame_idx=None,
84 | ):
85 | hidden_states = input_tensor
86 | hidden_states = self.temporal_transformer(
87 | hidden_states, encoder_hidden_states, attention_mask
88 | )
89 |
90 | output = hidden_states
91 | return output
92 |
93 |
94 | class TemporalTransformer3DModel(nn.Module):
95 | def __init__(
96 | self,
97 | in_channels,
98 | num_attention_heads,
99 | attention_head_dim,
100 | num_layers,
101 | attention_block_types=(
102 | "Temporal_Self",
103 | "Temporal_Self",
104 | ),
105 | dropout=0.0,
106 | norm_num_groups=32,
107 | cross_attention_dim=768,
108 | activation_fn="geglu",
109 | attention_bias=False,
110 | upcast_attention=False,
111 | cross_frame_attention_mode=None,
112 | temporal_position_encoding=False,
113 | temporal_position_encoding_max_len=24,
114 | ):
115 | super().__init__()
116 |
117 | inner_dim = num_attention_heads * attention_head_dim
118 |
119 | self.norm = torch.nn.GroupNorm(
120 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
121 | )
122 | self.proj_in = nn.Linear(in_channels, inner_dim)
123 |
124 | self.transformer_blocks = nn.ModuleList(
125 | [
126 | TemporalTransformerBlock(
127 | dim=inner_dim,
128 | num_attention_heads=num_attention_heads,
129 | attention_head_dim=attention_head_dim,
130 | attention_block_types=attention_block_types,
131 | dropout=dropout,
132 | norm_num_groups=norm_num_groups,
133 | cross_attention_dim=cross_attention_dim,
134 | activation_fn=activation_fn,
135 | attention_bias=attention_bias,
136 | upcast_attention=upcast_attention,
137 | cross_frame_attention_mode=cross_frame_attention_mode,
138 | temporal_position_encoding=temporal_position_encoding,
139 | temporal_position_encoding_max_len=temporal_position_encoding_max_len,
140 | )
141 | for d in range(num_layers)
142 | ]
143 | )
144 | self.proj_out = nn.Linear(inner_dim, in_channels)
145 |
146 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
147 | assert (
148 | hidden_states.dim() == 5
149 | ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
150 | video_length = hidden_states.shape[2]
151 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
152 |
153 | batch, channel, height, weight = hidden_states.shape
154 | residual = hidden_states
155 |
156 | hidden_states = self.norm(hidden_states)
157 | inner_dim = hidden_states.shape[1]
158 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
159 | batch, height * weight, inner_dim
160 | )
161 | hidden_states = self.proj_in(hidden_states)
162 |
163 | # Transformer Blocks
164 | for block in self.transformer_blocks:
165 | hidden_states = block(
166 | hidden_states,
167 | encoder_hidden_states=encoder_hidden_states,
168 | video_length=video_length,
169 | )
170 |
171 | # output
172 | hidden_states = self.proj_out(hidden_states)
173 | hidden_states = (
174 | hidden_states.reshape(batch, height, weight, inner_dim)
175 | .permute(0, 3, 1, 2)
176 | .contiguous()
177 | )
178 |
179 | output = hidden_states + residual
180 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
181 |
182 | return output
183 |
184 |
185 | class TemporalTransformerBlock(nn.Module):
186 | def __init__(
187 | self,
188 | dim,
189 | num_attention_heads,
190 | attention_head_dim,
191 | attention_block_types=(
192 | "Temporal_Self",
193 | "Temporal_Self",
194 | ),
195 | dropout=0.0,
196 | norm_num_groups=32,
197 | cross_attention_dim=768,
198 | activation_fn="geglu",
199 | attention_bias=False,
200 | upcast_attention=False,
201 | cross_frame_attention_mode=None,
202 | temporal_position_encoding=False,
203 | temporal_position_encoding_max_len=24,
204 | ):
205 | super().__init__()
206 |
207 | attention_blocks = []
208 | norms = []
209 |
210 | for block_name in attention_block_types:
211 | attention_blocks.append(
212 | VersatileAttention(
213 | attention_mode=block_name.split("_")[0],
214 | cross_attention_dim=cross_attention_dim
215 | if block_name.endswith("_Cross")
216 | else None,
217 | query_dim=dim,
218 | heads=num_attention_heads,
219 | dim_head=attention_head_dim,
220 | dropout=dropout,
221 | bias=attention_bias,
222 | upcast_attention=upcast_attention,
223 | cross_frame_attention_mode=cross_frame_attention_mode,
224 | temporal_position_encoding=temporal_position_encoding,
225 | temporal_position_encoding_max_len=temporal_position_encoding_max_len,
226 | )
227 | )
228 | norms.append(nn.LayerNorm(dim))
229 |
230 | self.attention_blocks = nn.ModuleList(attention_blocks)
231 | self.norms = nn.ModuleList(norms)
232 |
233 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
234 | self.ff_norm = nn.LayerNorm(dim)
235 |
236 | def forward(
237 | self,
238 | hidden_states,
239 | encoder_hidden_states=None,
240 | attention_mask=None,
241 | video_length=None,
242 | ):
243 | for attention_block, norm in zip(self.attention_blocks, self.norms):
244 | norm_hidden_states = norm(hidden_states)
245 | hidden_states = (
246 | attention_block(
247 | norm_hidden_states,
248 | encoder_hidden_states=encoder_hidden_states
249 | if attention_block.is_cross_attention
250 | else None,
251 | video_length=video_length,
252 | )
253 | + hidden_states
254 | )
255 |
256 | hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
257 |
258 | output = hidden_states
259 | return output
260 |
261 |
262 | class PositionalEncoding(nn.Module):
263 | def __init__(self, d_model, dropout=0.0, max_len=24):
264 | super().__init__()
265 | self.dropout = nn.Dropout(p=dropout)
266 | position = torch.arange(max_len).unsqueeze(1)
267 | div_term = torch.exp(
268 | torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
269 | )
270 | pe = torch.zeros(1, max_len, d_model)
271 | pe[0, :, 0::2] = torch.sin(position * div_term)
272 | pe[0, :, 1::2] = torch.cos(position * div_term)
273 | self.register_buffer("pe", pe)
274 |
275 | def forward(self, x):
276 | x = x + self.pe[:, : x.size(1)]
277 | return self.dropout(x)
278 |
279 |
280 | class VersatileAttention(Attention):
281 | def __init__(
282 | self,
283 | attention_mode=None,
284 | cross_frame_attention_mode=None,
285 | temporal_position_encoding=False,
286 | temporal_position_encoding_max_len=24,
287 | *args,
288 | **kwargs,
289 | ):
290 | super().__init__(*args, **kwargs)
291 | assert attention_mode == "Temporal"
292 |
293 | self.attention_mode = attention_mode
294 | self.is_cross_attention = kwargs["cross_attention_dim"] is not None
295 |
296 | self.pos_encoder = (
297 | PositionalEncoding(
298 | kwargs["query_dim"],
299 | dropout=0.0,
300 | max_len=temporal_position_encoding_max_len,
301 | )
302 | if (temporal_position_encoding and attention_mode == "Temporal")
303 | else None
304 | )
305 |
306 | def extra_repr(self):
307 | return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
308 |
309 | def set_use_memory_efficient_attention_xformers(
310 | self,
311 | use_memory_efficient_attention_xformers: bool,
312 | attention_op: Optional[Callable] = None,
313 | ):
314 | if use_memory_efficient_attention_xformers:
315 | if not is_xformers_available():
316 | raise ModuleNotFoundError(
317 | (
318 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
319 | " xformers"
320 | ),
321 | name="xformers",
322 | )
323 | elif not torch.cuda.is_available():
324 | raise ValueError(
325 | "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
326 | " only available for GPU "
327 | )
328 | else:
329 | try:
330 | # Make sure we can run the memory efficient attention
331 | _ = xformers.ops.memory_efficient_attention(
332 | torch.randn((1, 2, 40), device="cuda"),
333 | torch.randn((1, 2, 40), device="cuda"),
334 | torch.randn((1, 2, 40), device="cuda"),
335 | )
336 | except Exception as e:
337 | raise e
338 |
339 | # XFormersAttnProcessor corrupts video generation and only works with PyTorch 1.13.
340 | # In PyTorch 2.0.1 the default AttnProcessor behaves the same as XFormersAttnProcessor
341 | # did in PyTorch 1.13, so XFormersAttnProcessor is not needed here.
342 | # processor = XFormersAttnProcessor(
343 | # attention_op=attention_op,
344 | # )
345 | processor = AttnProcessor()
346 | else:
347 | processor = AttnProcessor()
348 |
349 | self.set_processor(processor)
350 |
351 | def forward(
352 | self,
353 | hidden_states,
354 | encoder_hidden_states=None,
355 | attention_mask=None,
356 | video_length=None,
357 | **cross_attention_kwargs,
358 | ):
359 | if self.attention_mode == "Temporal":
360 | d = hidden_states.shape[1] # d means HxW
361 | hidden_states = rearrange(
362 | hidden_states, "(b f) d c -> (b d) f c", f=video_length
363 | )
364 |
365 | if self.pos_encoder is not None:
366 | hidden_states = self.pos_encoder(hidden_states)
367 |
368 | encoder_hidden_states = (
369 | repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d)
370 | if encoder_hidden_states is not None
371 | else encoder_hidden_states
372 | )
373 |
374 | else:
375 | raise NotImplementedError
376 |
377 | hidden_states = self.processor(
378 | self,
379 | hidden_states,
380 | encoder_hidden_states=encoder_hidden_states,
381 | attention_mask=attention_mask,
382 | **cross_attention_kwargs,
383 | )
384 |
385 | if self.attention_mode == "Temporal":
386 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
387 |
388 | return hidden_states
389 |
--------------------------------------------------------------------------------
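A small sketch of how the factory above is used: `get_motion_module` returns a `VanillaTemporalModule` whose temporal transformer expects 5-D features of shape (batch, channels, frames, height, width) and, thanks to the zero-initialized `proj_out`, starts out as an identity mapping. The channel count below is an illustrative assumption, not a value taken from the training configs:

import torch

from src.models.motion_module import get_motion_module

motion_module = get_motion_module(
    in_channels=320,                      # assumed UNet block width
    motion_module_type="Vanilla",
    motion_module_kwargs=dict(
        temporal_position_encoding=True,
        temporal_position_encoding_max_len=24,
    ),
)

x = torch.randn(1, 320, 8, 32, 32)        # (b, c, f, h, w)
out = motion_module(x, temb=None, encoder_hidden_states=None)
assert out.shape == x.shape               # residual temporal attention preserves the shape
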
/src/models/mutual_self_attention.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/models/mutual_self_attention.py
2 | from typing import Any, Dict, Optional
3 |
4 | import torch
5 | from einops import rearrange
6 |
7 | from src.models.attention import TemporalBasicTransformerBlock
8 |
9 | from .attention import BasicTransformerBlock
10 |
11 |
12 | def torch_dfs(model: torch.nn.Module):
13 | result = [model]
14 | for child in model.children():
15 | result += torch_dfs(child)
16 | return result
17 |
18 |
19 | class ReferenceAttentionControl:
20 | def __init__(
21 | self,
22 | unet,
23 | mode="write",
24 | do_classifier_free_guidance=False,
25 | attention_auto_machine_weight=float("inf"),
26 | gn_auto_machine_weight=1.0,
27 | style_fidelity=1.0,
28 | reference_attn=True,
29 | reference_adain=False,
30 | fusion_blocks="midup",
31 | batch_size=1,
32 | ) -> None:
33 | # 10. Modify self attention and group norm
34 | self.unet = unet
35 | assert mode in ["read", "write"]
36 | assert fusion_blocks in ["midup", "full"]
37 | self.reference_attn = reference_attn
38 | self.reference_adain = reference_adain
39 | self.fusion_blocks = fusion_blocks
40 | self.register_reference_hooks(
41 | mode,
42 | do_classifier_free_guidance,
43 | attention_auto_machine_weight,
44 | gn_auto_machine_weight,
45 | style_fidelity,
46 | reference_attn,
47 | reference_adain,
48 | fusion_blocks,
49 | batch_size=batch_size,
50 | )
51 |
52 | def register_reference_hooks(
53 | self,
54 | mode,
55 | do_classifier_free_guidance,
56 | attention_auto_machine_weight,
57 | gn_auto_machine_weight,
58 | style_fidelity,
59 | reference_attn,
60 | reference_adain,
61 | dtype=torch.float16,
62 | batch_size=1,
63 | num_images_per_prompt=1,
64 | device=torch.device("cpu"),
65 | fusion_blocks="midup",
66 | ):
67 | MODE = mode
68 | do_classifier_free_guidance = do_classifier_free_guidance
69 | attention_auto_machine_weight = attention_auto_machine_weight
70 | gn_auto_machine_weight = gn_auto_machine_weight
71 | style_fidelity = style_fidelity
72 | reference_attn = reference_attn
73 | reference_adain = reference_adain
74 | fusion_blocks = fusion_blocks
75 | num_images_per_prompt = num_images_per_prompt
76 | dtype = dtype
77 | if do_classifier_free_guidance:
78 | uc_mask = (
79 | torch.Tensor(
80 | [1] * batch_size * num_images_per_prompt * 16
81 | + [0] * batch_size * num_images_per_prompt * 16
82 | )
83 | .to(device)
84 | .bool()
85 | )
86 | else:
87 | uc_mask = (
88 | torch.Tensor([0] * batch_size * num_images_per_prompt * 2)
89 | .to(device)
90 | .bool()
91 | )
92 |
93 | def hacked_basic_transformer_inner_forward(
94 | self,
95 | hidden_states: torch.FloatTensor,
96 | attention_mask: Optional[torch.FloatTensor] = None,
97 | encoder_hidden_states: Optional[torch.FloatTensor] = None,
98 | encoder_attention_mask: Optional[torch.FloatTensor] = None,
99 | timestep: Optional[torch.LongTensor] = None,
100 | cross_attention_kwargs: Dict[str, Any] = None,
101 | class_labels: Optional[torch.LongTensor] = None,
102 | video_length=None,
103 | self_attention_additional_feats=None,
104 | mode=None,
105 | ):
106 | if self.use_ada_layer_norm: # False
107 | norm_hidden_states = self.norm1(hidden_states, timestep)
108 | elif self.use_ada_layer_norm_zero:
109 | (
110 | norm_hidden_states,
111 | gate_msa,
112 | shift_mlp,
113 | scale_mlp,
114 | gate_mlp,
115 | ) = self.norm1(
116 | hidden_states,
117 | timestep,
118 | class_labels,
119 | hidden_dtype=hidden_states.dtype,
120 | )
121 | else:
122 | norm_hidden_states = self.norm1(hidden_states)
123 |
124 | # 1. Self-Attention
125 | # self.only_cross_attention = False
126 | cross_attention_kwargs = (
127 | cross_attention_kwargs if cross_attention_kwargs is not None else {}
128 | )
129 | if self.only_cross_attention:
130 | attn_output = self.attn1(
131 | norm_hidden_states,
132 | encoder_hidden_states=encoder_hidden_states
133 | if self.only_cross_attention
134 | else None,
135 | attention_mask=attention_mask,
136 | **cross_attention_kwargs,
137 | )
138 | else:
139 | if MODE == "write":
140 | self.bank.append(norm_hidden_states.clone())
141 | attn_output = self.attn1(
142 | norm_hidden_states,
143 | encoder_hidden_states=encoder_hidden_states
144 | if self.only_cross_attention
145 | else None,
146 | attention_mask=attention_mask,
147 | **cross_attention_kwargs,
148 | )
149 | if MODE == "read":
150 | bank_fea = [
151 | rearrange(
152 | d.unsqueeze(1).repeat(1, video_length, 1, 1),
153 | "b t l c -> (b t) l c",
154 | )
155 | for d in self.bank
156 | ]
157 | modify_norm_hidden_states = torch.cat(
158 | [norm_hidden_states] + bank_fea, dim=1
159 | )
160 | hidden_states_uc = (
161 | self.attn1(
162 | norm_hidden_states,
163 | encoder_hidden_states=modify_norm_hidden_states,
164 | attention_mask=attention_mask,
165 | )
166 | + hidden_states
167 | )
168 | if do_classifier_free_guidance:
169 | hidden_states_c = hidden_states_uc.clone()
170 | _uc_mask = uc_mask.clone()
171 | if hidden_states.shape[0] != _uc_mask.shape[0]:
172 | _uc_mask = (
173 | torch.Tensor(
174 | [1] * (hidden_states.shape[0] // 2)
175 | + [0] * (hidden_states.shape[0] // 2)
176 | )
177 | .to(device)
178 | .bool()
179 | )
180 | hidden_states_c[_uc_mask] = (
181 | self.attn1(
182 | norm_hidden_states[_uc_mask],
183 | encoder_hidden_states=norm_hidden_states[_uc_mask],
184 | attention_mask=attention_mask,
185 | )
186 | + hidden_states[_uc_mask]
187 | )
188 | hidden_states = hidden_states_c.clone()
189 | else:
190 | hidden_states = hidden_states_uc
191 |
192 | # self.bank.clear()
193 | if self.attn2 is not None:
194 | # Cross-Attention
195 | norm_hidden_states = (
196 | self.norm2(hidden_states, timestep)
197 | if self.use_ada_layer_norm
198 | else self.norm2(hidden_states)
199 | )
200 | hidden_states = (
201 | self.attn2(
202 | norm_hidden_states,
203 | encoder_hidden_states=encoder_hidden_states,
204 | attention_mask=attention_mask,
205 | )
206 | + hidden_states
207 | )
208 |
209 | # Feed-forward
210 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
211 |
212 | # Temporal-Attention
213 | if self.unet_use_temporal_attention:
214 | d = hidden_states.shape[1]
215 | hidden_states = rearrange(
216 | hidden_states, "(b f) d c -> (b d) f c", f=video_length
217 | )
218 | norm_hidden_states = (
219 | self.norm_temp(hidden_states, timestep)
220 | if self.use_ada_layer_norm
221 | else self.norm_temp(hidden_states)
222 | )
223 | hidden_states = (
224 | self.attn_temp(norm_hidden_states) + hidden_states
225 | )
226 | hidden_states = rearrange(
227 | hidden_states, "(b d) f c -> (b f) d c", d=d
228 | )
229 |
230 | return hidden_states
231 |
232 | if self.use_ada_layer_norm_zero:
233 | attn_output = gate_msa.unsqueeze(1) * attn_output
234 | hidden_states = attn_output + hidden_states
235 |
236 | if self.attn2 is not None:
237 | norm_hidden_states = (
238 | self.norm2(hidden_states, timestep)
239 | if self.use_ada_layer_norm
240 | else self.norm2(hidden_states)
241 | )
242 |
243 | # 2. Cross-Attention
244 | attn_output = self.attn2(
245 | norm_hidden_states,
246 | encoder_hidden_states=encoder_hidden_states,
247 | attention_mask=encoder_attention_mask,
248 | **cross_attention_kwargs,
249 | )
250 | hidden_states = attn_output + hidden_states
251 |
252 | # 3. Feed-forward
253 | norm_hidden_states = self.norm3(hidden_states)
254 |
255 | if self.use_ada_layer_norm_zero:
256 | norm_hidden_states = (
257 | norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
258 | )
259 |
260 | ff_output = self.ff(norm_hidden_states)
261 |
262 | if self.use_ada_layer_norm_zero:
263 | ff_output = gate_mlp.unsqueeze(1) * ff_output
264 |
265 | hidden_states = ff_output + hidden_states
266 |
267 | return hidden_states
268 |
269 | if self.reference_attn:
270 | if self.fusion_blocks == "midup":
271 | attn_modules = [
272 | module
273 | for module in (
274 | torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
275 | )
276 | if isinstance(module, BasicTransformerBlock)
277 | or isinstance(module, TemporalBasicTransformerBlock)
278 | ]
279 | elif self.fusion_blocks == "full":
280 | attn_modules = [
281 | module
282 | for module in torch_dfs(self.unet)
283 | if isinstance(module, BasicTransformerBlock)
284 | or isinstance(module, TemporalBasicTransformerBlock)
285 | ]
286 | attn_modules = sorted(
287 | attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
288 | )
289 |
290 | for i, module in enumerate(attn_modules):
291 | module._original_inner_forward = module.forward
292 | if isinstance(module, BasicTransformerBlock):
293 | module.forward = hacked_basic_transformer_inner_forward.__get__(
294 | module, BasicTransformerBlock
295 | )
296 | if isinstance(module, TemporalBasicTransformerBlock):
297 | module.forward = hacked_basic_transformer_inner_forward.__get__(
298 | module, TemporalBasicTransformerBlock
299 | )
300 |
301 | module.bank = []
302 | module.attn_weight = float(i) / float(len(attn_modules))
303 |
304 | def update(self, writer, dtype=torch.float16):
305 | if self.reference_attn:
306 | if self.fusion_blocks == "midup":
307 | reader_attn_modules = [
308 | module
309 | for module in (
310 | torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
311 | )
312 | if isinstance(module, TemporalBasicTransformerBlock)
313 | ]
314 | writer_attn_modules = [
315 | module
316 | for module in (
317 | torch_dfs(writer.unet.mid_block)
318 | + torch_dfs(writer.unet.up_blocks)
319 | )
320 | if isinstance(module, BasicTransformerBlock)
321 | ]
322 | elif self.fusion_blocks == "full":
323 | reader_attn_modules = [
324 | module
325 | for module in torch_dfs(self.unet)
326 | if isinstance(module, TemporalBasicTransformerBlock)
327 | ]
328 | writer_attn_modules = [
329 | module
330 | for module in torch_dfs(writer.unet)
331 | if isinstance(module, BasicTransformerBlock)
332 | ]
333 | reader_attn_modules = sorted(
334 | reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
335 | )
336 | writer_attn_modules = sorted(
337 | writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
338 | )
339 | for r, w in zip(reader_attn_modules, writer_attn_modules):
340 | r.bank = [v.clone().to(dtype) for v in w.bank]
341 | # w.bank.clear()
342 |
343 | def clear(self):
344 | if self.reference_attn:
345 | if self.fusion_blocks == "midup":
346 | reader_attn_modules = [
347 | module
348 | for module in (
349 | torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
350 | )
351 | if isinstance(module, BasicTransformerBlock)
352 | or isinstance(module, TemporalBasicTransformerBlock)
353 | ]
354 | elif self.fusion_blocks == "full":
355 | reader_attn_modules = [
356 | module
357 | for module in torch_dfs(self.unet)
358 | if isinstance(module, BasicTransformerBlock)
359 | or isinstance(module, TemporalBasicTransformerBlock)
360 | ]
361 | reader_attn_modules = sorted(
362 | reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
363 | )
364 | for r in reader_attn_modules:
365 | r.bank.clear()
366 |
--------------------------------------------------------------------------------
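The writer/reader pair above is what carries reference-image features into the denoising UNet. Below is a condensed, hypothetical helper showing the write → update → read → clear cycle, mirroring how `Pose2ImagePipeline.__call__` later in this repo drives it (the UNets, latents and embeddings are assumed to be built elsewhere):

import torch

from src.models.mutual_self_attention import ReferenceAttentionControl

def denoise_with_reference(reference_unet, denoising_unet, ref_latents,
                           noisy_latents, t, clip_embeds, pose_fea):
    """Hypothetical helper; all tensors and modules are assumed to exist."""
    writer = ReferenceAttentionControl(
        reference_unet, mode="write", fusion_blocks="full",
        do_classifier_free_guidance=False, batch_size=1,
    )
    reader = ReferenceAttentionControl(
        denoising_unet, mode="read", fusion_blocks="full",
        do_classifier_free_guidance=False, batch_size=1,
    )

    # 1. "write": run the reference UNet once; each hacked self-attention
    #    block stores its normalized hidden states in module.bank.
    reference_unet(ref_latents, torch.zeros_like(t),
                   encoder_hidden_states=clip_embeds, return_dict=False)

    # 2. copy the banks from the writer blocks into the matching reader blocks
    reader.update(writer)

    # 3. "read": the denoising UNet's self-attention now attends over
    #    [its own tokens, reference tokens] concatenated along the sequence dim.
    noise_pred = denoising_unet(noisy_latents, t,
                                encoder_hidden_states=clip_embeds,
                                pose_cond_fea=pose_fea, return_dict=False)[0]

    # 4. drop the cached reference features before the next call
    reader.clear()
    writer.clear()
    return noise_pred
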
/src/models/pose_guider.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.nn.init as init
6 | from diffusers.models.modeling_utils import ModelMixin
7 |
8 | from src.models.motion_module import zero_module
9 | from src.models.resnet import InflatedConv3d
10 |
11 |
12 | class PoseGuider(ModelMixin):
13 | def __init__(
14 | self,
15 | conditioning_embedding_channels: int,
16 | conditioning_channels: int = 3,
17 | block_out_channels: Tuple[int] = (16, 32, 64, 128),
18 | ):
19 | super().__init__()
20 | self.conv_in = InflatedConv3d(
21 | conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
22 | )
23 |
24 | self.blocks = nn.ModuleList([])
25 |
26 | for i in range(len(block_out_channels) - 1):
27 | channel_in = block_out_channels[i]
28 | channel_out = block_out_channels[i + 1]
29 | self.blocks.append(
30 | InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)
31 | )
32 | self.blocks.append(
33 | InflatedConv3d(
34 | channel_in, channel_out, kernel_size=3, padding=1, stride=2
35 | )
36 | )
37 |
38 | self.conv_out = zero_module(
39 | InflatedConv3d(
40 | block_out_channels[-1],
41 | conditioning_embedding_channels,
42 | kernel_size=3,
43 | padding=1,
44 | )
45 | )
46 |
47 | def forward(self, conditioning):
48 | embedding = self.conv_in(conditioning)
49 | embedding = F.silu(embedding)
50 |
51 | for block in self.blocks:
52 | embedding = block(embedding)
53 | embedding = F.silu(embedding)
54 |
55 | embedding = self.conv_out(embedding)
56 |
57 | return embedding
58 |
--------------------------------------------------------------------------------
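A quick shape check for the `PoseGuider` above: with the default `block_out_channels=(16, 32, 64, 128)` there are three stride-2 convolutions, so pose frames of spatial size H×W come out at H/8×W/8, matching the 8× VAE downsampling applied to the latents. The embedding channel count below is an illustrative assumption:

import torch

from src.models.pose_guider import PoseGuider

pose_guider = PoseGuider(conditioning_embedding_channels=320)  # assumed UNet conv_in width

pose = torch.randn(1, 3, 8, 512, 512)   # (b, 3, frames, H, W) pose frames
fea = pose_guider(pose)
print(fea.shape)                        # torch.Size([1, 320, 8, 64, 64]); all zeros at init
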
/src/models/resnet.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from einops import rearrange
7 |
8 |
9 | class InflatedConv3d(nn.Conv2d):
10 | def forward(self, x):
11 | video_length = x.shape[2]
12 |
13 | x = rearrange(x, "b c f h w -> (b f) c h w")
14 | x = super().forward(x)
15 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
16 |
17 | return x
18 |
19 |
20 | class InflatedGroupNorm(nn.GroupNorm):
21 | def forward(self, x):
22 | video_length = x.shape[2]
23 |
24 | x = rearrange(x, "b c f h w -> (b f) c h w")
25 | x = super().forward(x)
26 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
27 |
28 | return x
29 |
30 |
31 | class Upsample3D(nn.Module):
32 | def __init__(
33 | self,
34 | channels,
35 | use_conv=False,
36 | use_conv_transpose=False,
37 | out_channels=None,
38 | name="conv",
39 | ):
40 | super().__init__()
41 | self.channels = channels
42 | self.out_channels = out_channels or channels
43 | self.use_conv = use_conv
44 | self.use_conv_transpose = use_conv_transpose
45 | self.name = name
46 |
47 | conv = None
48 | if use_conv_transpose:
49 | raise NotImplementedError
50 | elif use_conv:
51 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
52 |
53 | def forward(self, hidden_states, output_size=None):
54 | assert hidden_states.shape[1] == self.channels
55 |
56 | if self.use_conv_transpose:
57 | raise NotImplementedError
58 |
59 | # Cast to float32, as the 'upsample_nearest2d_out_frame' op does not support bfloat16
60 | dtype = hidden_states.dtype
61 | if dtype == torch.bfloat16:
62 | hidden_states = hidden_states.to(torch.float32)
63 |
64 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
65 | if hidden_states.shape[0] >= 64:
66 | hidden_states = hidden_states.contiguous()
67 |
68 | # if `output_size` is passed we force the interpolation output
69 | # size and do not make use of `scale_factor=2`
70 | if output_size is None:
71 | hidden_states = F.interpolate(
72 | hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest"
73 | )
74 | else:
75 | hidden_states = F.interpolate(
76 | hidden_states, size=output_size, mode="nearest"
77 | )
78 |
79 | # If the input is bfloat16, we cast back to bfloat16
80 | if dtype == torch.bfloat16:
81 | hidden_states = hidden_states.to(dtype)
82 |
83 | # if self.use_conv:
84 | # if self.name == "conv":
85 | # hidden_states = self.conv(hidden_states)
86 | # else:
87 | # hidden_states = self.Conv2d_0(hidden_states)
88 | hidden_states = self.conv(hidden_states)
89 |
90 | return hidden_states
91 |
92 |
93 | class Downsample3D(nn.Module):
94 | def __init__(
95 | self, channels, use_conv=False, out_channels=None, padding=1, name="conv"
96 | ):
97 | super().__init__()
98 | self.channels = channels
99 | self.out_channels = out_channels or channels
100 | self.use_conv = use_conv
101 | self.padding = padding
102 | stride = 2
103 | self.name = name
104 |
105 | if use_conv:
106 | self.conv = InflatedConv3d(
107 | self.channels, self.out_channels, 3, stride=stride, padding=padding
108 | )
109 | else:
110 | raise NotImplementedError
111 |
112 | def forward(self, hidden_states):
113 | assert hidden_states.shape[1] == self.channels
114 | if self.use_conv and self.padding == 0:
115 | raise NotImplementedError
116 |
117 | assert hidden_states.shape[1] == self.channels
118 | hidden_states = self.conv(hidden_states)
119 |
120 | return hidden_states
121 |
122 |
123 | class ResnetBlock3D(nn.Module):
124 | def __init__(
125 | self,
126 | *,
127 | in_channels,
128 | out_channels=None,
129 | conv_shortcut=False,
130 | dropout=0.0,
131 | temb_channels=512,
132 | groups=32,
133 | groups_out=None,
134 | pre_norm=True,
135 | eps=1e-6,
136 | non_linearity="swish",
137 | time_embedding_norm="default",
138 | output_scale_factor=1.0,
139 | use_in_shortcut=None,
140 | use_inflated_groupnorm=None,
141 | ):
142 | super().__init__()
143 | self.pre_norm = pre_norm
144 | self.pre_norm = True
145 | self.in_channels = in_channels
146 | out_channels = in_channels if out_channels is None else out_channels
147 | self.out_channels = out_channels
148 | self.use_conv_shortcut = conv_shortcut
149 | self.time_embedding_norm = time_embedding_norm
150 | self.output_scale_factor = output_scale_factor
151 |
152 | if groups_out is None:
153 | groups_out = groups
154 |
155 | assert use_inflated_groupnorm is not None
156 | if use_inflated_groupnorm:
157 | self.norm1 = InflatedGroupNorm(
158 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True
159 | )
160 | else:
161 | self.norm1 = torch.nn.GroupNorm(
162 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True
163 | )
164 |
165 | self.conv1 = InflatedConv3d(
166 | in_channels, out_channels, kernel_size=3, stride=1, padding=1
167 | )
168 |
169 | if temb_channels is not None:
170 | if self.time_embedding_norm == "default":
171 | time_emb_proj_out_channels = out_channels
172 | elif self.time_embedding_norm == "scale_shift":
173 | time_emb_proj_out_channels = out_channels * 2
174 | else:
175 | raise ValueError(
176 | f"unknown time_embedding_norm : {self.time_embedding_norm} "
177 | )
178 |
179 | self.time_emb_proj = torch.nn.Linear(
180 | temb_channels, time_emb_proj_out_channels
181 | )
182 | else:
183 | self.time_emb_proj = None
184 |
185 | if use_inflated_groupnorm:
186 | self.norm2 = InflatedGroupNorm(
187 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
188 | )
189 | else:
190 | self.norm2 = torch.nn.GroupNorm(
191 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
192 | )
193 | self.dropout = torch.nn.Dropout(dropout)
194 | self.conv2 = InflatedConv3d(
195 | out_channels, out_channels, kernel_size=3, stride=1, padding=1
196 | )
197 |
198 | if non_linearity == "swish":
199 | self.nonlinearity = lambda x: F.silu(x)
200 | elif non_linearity == "mish":
201 | self.nonlinearity = Mish()
202 | elif non_linearity == "silu":
203 | self.nonlinearity = nn.SiLU()
204 |
205 | self.use_in_shortcut = (
206 | self.in_channels != self.out_channels
207 | if use_in_shortcut is None
208 | else use_in_shortcut
209 | )
210 |
211 | self.conv_shortcut = None
212 | if self.use_in_shortcut:
213 | self.conv_shortcut = InflatedConv3d(
214 | in_channels, out_channels, kernel_size=1, stride=1, padding=0
215 | )
216 |
217 | def forward(self, input_tensor, temb):
218 | hidden_states = input_tensor
219 |
220 | hidden_states = self.norm1(hidden_states)
221 | hidden_states = self.nonlinearity(hidden_states)
222 |
223 | hidden_states = self.conv1(hidden_states)
224 |
225 | if temb is not None:
226 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
227 |
228 | if temb is not None and self.time_embedding_norm == "default":
229 | hidden_states = hidden_states + temb
230 |
231 | hidden_states = self.norm2(hidden_states)
232 |
233 | if temb is not None and self.time_embedding_norm == "scale_shift":
234 | scale, shift = torch.chunk(temb, 2, dim=1)
235 | hidden_states = hidden_states * (1 + scale) + shift
236 |
237 | hidden_states = self.nonlinearity(hidden_states)
238 |
239 | hidden_states = self.dropout(hidden_states)
240 | hidden_states = self.conv2(hidden_states)
241 |
242 | if self.conv_shortcut is not None:
243 | input_tensor = self.conv_shortcut(input_tensor)
244 |
245 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
246 |
247 | return output_tensor
248 |
249 |
250 | class Mish(torch.nn.Module):
251 | def forward(self, hidden_states):
252 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
253 |
--------------------------------------------------------------------------------
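The "inflation" trick above just folds the frame axis into the batch axis, runs the ordinary 2-D op, and unfolds again, so pretrained 2-D weights can be applied to 5-D video tensors unchanged. A small equivalence check:

import torch

from src.models.resnet import InflatedConv3d

conv = InflatedConv3d(4, 8, kernel_size=3, padding=1)

video = torch.randn(2, 4, 6, 32, 32)     # (b, c, f, h, w)
out = conv(video)                        # the 2-D conv is applied frame by frame
print(out.shape)                         # torch.Size([2, 8, 6, 32, 32])

# same result as convolving a single frame with the underlying nn.Conv2d weights
frame0 = torch.nn.functional.conv2d(video[:, :, 0], conv.weight, conv.bias, padding=1)
assert torch.allclose(out[:, :, 0], frame0, atol=1e-6)
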
/src/models/transformer_3d.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional
3 |
4 | import torch
5 | from diffusers.configuration_utils import ConfigMixin, register_to_config
6 | from diffusers.models import ModelMixin
7 | from diffusers.utils import BaseOutput
8 | from diffusers.utils.import_utils import is_xformers_available
9 | from einops import rearrange, repeat
10 | from torch import nn
11 |
12 | from .attention import TemporalBasicTransformerBlock
13 |
14 |
15 | @dataclass
16 | class Transformer3DModelOutput(BaseOutput):
17 | sample: torch.FloatTensor
18 |
19 |
20 | if is_xformers_available():
21 | import xformers
22 | import xformers.ops
23 | else:
24 | xformers = None
25 |
26 |
27 | class Transformer3DModel(ModelMixin, ConfigMixin):
28 | _supports_gradient_checkpointing = True
29 |
30 | @register_to_config
31 | def __init__(
32 | self,
33 | num_attention_heads: int = 16,
34 | attention_head_dim: int = 88,
35 | in_channels: Optional[int] = None,
36 | num_layers: int = 1,
37 | dropout: float = 0.0,
38 | norm_num_groups: int = 32,
39 | cross_attention_dim: Optional[int] = None,
40 | attention_bias: bool = False,
41 | activation_fn: str = "geglu",
42 | num_embeds_ada_norm: Optional[int] = None,
43 | use_linear_projection: bool = False,
44 | only_cross_attention: bool = False,
45 | upcast_attention: bool = False,
46 | unet_use_cross_frame_attention=None,
47 | unet_use_temporal_attention=None,
48 | name=None,
49 | ):
50 | super().__init__()
51 | self.use_linear_projection = use_linear_projection
52 | self.num_attention_heads = num_attention_heads
53 | self.attention_head_dim = attention_head_dim
54 | inner_dim = num_attention_heads * attention_head_dim
55 |
56 | # Define input layers
57 | self.in_channels = in_channels
58 | self.name = name
59 |
60 | self.norm = torch.nn.GroupNorm(
61 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
62 | )
63 | if use_linear_projection:
64 | self.proj_in = nn.Linear(in_channels, inner_dim)
65 | else:
66 | self.proj_in = nn.Conv2d(
67 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0
68 | )
69 |
70 | # Define transformers blocks
71 | self.transformer_blocks = nn.ModuleList(
72 | [
73 | TemporalBasicTransformerBlock(
74 | inner_dim,
75 | num_attention_heads,
76 | attention_head_dim,
77 | dropout=dropout,
78 | cross_attention_dim=cross_attention_dim,
79 | activation_fn=activation_fn,
80 | num_embeds_ada_norm=num_embeds_ada_norm,
81 | attention_bias=attention_bias,
82 | only_cross_attention=only_cross_attention,
83 | upcast_attention=upcast_attention,
84 | unet_use_cross_frame_attention=unet_use_cross_frame_attention,
85 | unet_use_temporal_attention=unet_use_temporal_attention,
86 | name=f"{self.name}_{d}_TransformerBlock" if self.name else None,
87 | )
88 | for d in range(num_layers)
89 | ]
90 | )
91 |
92 | # 4. Define output layers
93 | if use_linear_projection:
94 | self.proj_out = nn.Linear(in_channels, inner_dim)
95 | else:
96 | self.proj_out = nn.Conv2d(
97 | inner_dim, in_channels, kernel_size=1, stride=1, padding=0
98 | )
99 |
100 | self.gradient_checkpointing = False
101 |
102 | def _set_gradient_checkpointing(self, module, value=False):
103 | if hasattr(module, "gradient_checkpointing"):
104 | module.gradient_checkpointing = value
105 |
106 | def forward(
107 | self,
108 | hidden_states,
109 | encoder_hidden_states=None,
110 | self_attention_additional_feats=None,
111 | mode=None,
112 | timestep=None,
113 | return_dict: bool = True,
114 | ):
115 | # Input
116 | assert (
117 | hidden_states.dim() == 5
118 | ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
119 | video_length = hidden_states.shape[2]
120 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
121 | if encoder_hidden_states.shape[0] != hidden_states.shape[0]:
122 | encoder_hidden_states = repeat(
123 | encoder_hidden_states, "b n c -> (b f) n c", f=video_length
124 | )
125 |
126 | batch, channel, height, weight = hidden_states.shape
127 | residual = hidden_states
128 |
129 | hidden_states = self.norm(hidden_states)
130 | if not self.use_linear_projection:
131 | hidden_states = self.proj_in(hidden_states)
132 | inner_dim = hidden_states.shape[1]
133 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
134 | batch, height * weight, inner_dim
135 | )
136 | else:
137 | inner_dim = hidden_states.shape[1]
138 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
139 | batch, height * weight, inner_dim
140 | )
141 | hidden_states = self.proj_in(hidden_states)
142 |
143 | # Blocks
144 | for i, block in enumerate(self.transformer_blocks):
145 |
146 | if self.training and self.gradient_checkpointing:
147 |
148 | def create_custom_forward(module, return_dict=None):
149 | def custom_forward(*inputs):
150 | if return_dict is not None:
151 | return module(*inputs, return_dict=return_dict)
152 | else:
153 | return module(*inputs)
154 |
155 | return custom_forward
156 |
157 | # if hasattr(self.block, 'bank') and len(self.block.bank) > 0:
158 | # hidden_states
159 | hidden_states = torch.utils.checkpoint.checkpoint(
160 | create_custom_forward(block),
161 | hidden_states,
162 | encoder_hidden_states=encoder_hidden_states,
163 | timestep=timestep,
164 | attention_mask=None,
165 | video_length=video_length,
166 | self_attention_additional_feats=self_attention_additional_feats,
167 | mode=mode,
168 | )
169 | else:
170 |
171 | hidden_states = block(
172 | hidden_states,
173 | encoder_hidden_states=encoder_hidden_states,
174 | timestep=timestep,
175 | self_attention_additional_feats=self_attention_additional_feats,
176 | mode=mode,
177 | video_length=video_length,
178 | )
179 |
180 | # Output
181 | if not self.use_linear_projection:
182 | hidden_states = (
183 | hidden_states.reshape(batch, height, weight, inner_dim)
184 | .permute(0, 3, 1, 2)
185 | .contiguous()
186 | )
187 | hidden_states = self.proj_out(hidden_states)
188 | else:
189 | hidden_states = self.proj_out(hidden_states)
190 | hidden_states = (
191 | hidden_states.reshape(batch, height, weight, inner_dim)
192 | .permute(0, 3, 1, 2)
193 | .contiguous()
194 | )
195 |
196 | output = hidden_states + residual
197 |
198 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
199 | if not return_dict:
200 | return (output,)
201 |
202 | return Transformer3DModelOutput(sample=output)
203 |
--------------------------------------------------------------------------------
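The first lines of `Transformer3DModel.forward` define the shape contract used throughout the 3-D blocks: video features arrive as (b, c, f, h, w), are flattened to per-frame 2-D maps, and the per-clip conditioning embedding is repeated once per frame. A standalone illustration of just that reshaping, with toy sizes:

import torch
from einops import rearrange, repeat

b, c, f, h, w = 2, 320, 8, 32, 32
hidden_states = torch.randn(b, c, f, h, w)
encoder_hidden_states = torch.randn(b, 1, 768)   # e.g. one CLIP image token per clip

hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b f) n c", f=f)

print(hidden_states.shape)           # torch.Size([16, 320, 32, 32])
print(encoder_hidden_states.shape)   # torch.Size([16, 1, 768])
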
/src/pipelines/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MooreThreads/Moore-AnimateAnyone/a914ef38aae3733c2f02f29853dd0593372e0cc9/src/pipelines/__init__.py
--------------------------------------------------------------------------------
/src/pipelines/context.py:
--------------------------------------------------------------------------------
1 | # TODO: Adapted from cli
2 | from typing import Callable, List, Optional
3 |
4 | import numpy as np
5 |
6 |
7 | def ordered_halving(val):
8 | bin_str = f"{val:064b}"
9 | bin_flip = bin_str[::-1]
10 | as_int = int(bin_flip, 2)
11 |
12 | return as_int / (1 << 64)
13 |
14 |
15 | def uniform(
16 | step: int = ...,
17 | num_steps: Optional[int] = None,
18 | num_frames: int = ...,
19 | context_size: Optional[int] = None,
20 | context_stride: int = 3,
21 | context_overlap: int = 4,
22 | closed_loop: bool = True,
23 | ):
24 | if num_frames <= context_size:
25 | yield list(range(num_frames))
26 | return
27 |
28 | context_stride = min(
29 | context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1
30 | )
31 |
32 | for context_step in 1 << np.arange(context_stride):
33 | pad = int(round(num_frames * ordered_halving(step)))
34 | for j in range(
35 | int(ordered_halving(step) * context_step) + pad,
36 | num_frames + pad + (0 if closed_loop else -context_overlap),
37 | (context_size * context_step - context_overlap),
38 | ):
39 | yield [
40 | e % num_frames
41 | for e in range(j, j + context_size * context_step, context_step)
42 | ]
43 |
44 |
45 | def get_context_scheduler(name: str) -> Callable:
46 | if name == "uniform":
47 | return uniform
48 | else:
49 | raise ValueError(f"Unknown context_overlap policy {name}")
50 |
51 |
52 | def get_total_steps(
53 | scheduler,
54 | timesteps: List[int],
55 | num_steps: Optional[int] = None,
56 | num_frames: int = ...,
57 | context_size: Optional[int] = None,
58 | context_stride: int = 3,
59 | context_overlap: int = 4,
60 | closed_loop: bool = True,
61 | ):
62 | return sum(
63 | len(
64 | list(
65 | scheduler(
66 | i,
67 | num_steps,
68 | num_frames,
69 | context_size,
70 | context_stride,
71 | context_overlap,
72 | )
73 | )
74 | )
75 | for i in range(len(timesteps))
76 | )
77 |
--------------------------------------------------------------------------------
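A quick way to see what the `uniform` scheduler above yields: for clips longer than the context window it produces overlapping (and, at higher strides, dilated) lists of `context_size` frame indices that wrap around the clip when `closed_loop=True`. A minimal sketch with illustrative frame counts:

from src.pipelines.context import get_context_scheduler

scheduler = get_context_scheduler("uniform")

# a 24-frame latent video processed in 16-frame windows with 4 frames of overlap
for context in scheduler(
    step=0,
    num_steps=25,
    num_frames=24,
    context_size=16,
    context_stride=1,
    context_overlap=4,
):
    print(context)   # each item is a list of 16 frame indices, taken modulo 24
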
/src/pipelines/pipeline_pose2img.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | from dataclasses import dataclass
3 | from typing import Callable, List, Optional, Union
4 |
5 | import numpy as np
6 | import torch
7 | from diffusers import DiffusionPipeline
8 | from diffusers.image_processor import VaeImageProcessor
9 | from diffusers.schedulers import (
10 | DDIMScheduler,
11 | DPMSolverMultistepScheduler,
12 | EulerAncestralDiscreteScheduler,
13 | EulerDiscreteScheduler,
14 | LMSDiscreteScheduler,
15 | PNDMScheduler,
16 | )
17 | from diffusers.utils import BaseOutput, is_accelerate_available
18 | from diffusers.utils.torch_utils import randn_tensor
19 | from einops import rearrange
20 | from tqdm import tqdm
21 | from transformers import CLIPImageProcessor
22 |
23 | from src.models.mutual_self_attention import ReferenceAttentionControl
24 |
25 |
26 | @dataclass
27 | class Pose2ImagePipelineOutput(BaseOutput):
28 | images: Union[torch.Tensor, np.ndarray]
29 |
30 |
31 | class Pose2ImagePipeline(DiffusionPipeline):
32 | _optional_components = []
33 |
34 | def __init__(
35 | self,
36 | vae,
37 | image_encoder,
38 | reference_unet,
39 | denoising_unet,
40 | pose_guider,
41 | scheduler: Union[
42 | DDIMScheduler,
43 | PNDMScheduler,
44 | LMSDiscreteScheduler,
45 | EulerDiscreteScheduler,
46 | EulerAncestralDiscreteScheduler,
47 | DPMSolverMultistepScheduler,
48 | ],
49 | ):
50 | super().__init__()
51 |
52 | self.register_modules(
53 | vae=vae,
54 | image_encoder=image_encoder,
55 | reference_unet=reference_unet,
56 | denoising_unet=denoising_unet,
57 | pose_guider=pose_guider,
58 | scheduler=scheduler,
59 | )
60 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
61 | self.clip_image_processor = CLIPImageProcessor()
62 | self.ref_image_processor = VaeImageProcessor(
63 | vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
64 | )
65 | self.cond_image_processor = VaeImageProcessor(
66 | vae_scale_factor=self.vae_scale_factor,
67 | do_convert_rgb=True,
68 | do_normalize=False,
69 | )
70 |
71 | def enable_vae_slicing(self):
72 | self.vae.enable_slicing()
73 |
74 | def disable_vae_slicing(self):
75 | self.vae.disable_slicing()
76 |
77 | def enable_sequential_cpu_offload(self, gpu_id=0):
78 | if is_accelerate_available():
79 | from accelerate import cpu_offload
80 | else:
81 | raise ImportError("Please install accelerate via `pip install accelerate`")
82 |
83 | device = torch.device(f"cuda:{gpu_id}")
84 |
85 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
86 | if cpu_offloaded_model is not None:
87 | cpu_offload(cpu_offloaded_model, device)
88 |
89 | @property
90 | def _execution_device(self):
91 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
92 | return self.device
93 | for module in self.unet.modules():
94 | if (
95 | hasattr(module, "_hf_hook")
96 | and hasattr(module._hf_hook, "execution_device")
97 | and module._hf_hook.execution_device is not None
98 | ):
99 | return torch.device(module._hf_hook.execution_device)
100 | return self.device
101 |
102 | def decode_latents(self, latents):
103 | video_length = latents.shape[2]
104 | latents = 1 / 0.18215 * latents
105 | latents = rearrange(latents, "b c f h w -> (b f) c h w")
106 | # video = self.vae.decode(latents).sample
107 | video = []
108 | for frame_idx in tqdm(range(latents.shape[0])):
109 | video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
110 | video = torch.cat(video)
111 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
112 | video = (video / 2 + 0.5).clamp(0, 1)
114 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
114 | video = video.cpu().float().numpy()
115 | return video
116 |
117 | def prepare_extra_step_kwargs(self, generator, eta):
118 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
119 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
120 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
121 | # and should be between [0, 1]
122 |
123 | accepts_eta = "eta" in set(
124 | inspect.signature(self.scheduler.step).parameters.keys()
125 | )
126 | extra_step_kwargs = {}
127 | if accepts_eta:
128 | extra_step_kwargs["eta"] = eta
129 |
130 | # check if the scheduler accepts generator
131 | accepts_generator = "generator" in set(
132 | inspect.signature(self.scheduler.step).parameters.keys()
133 | )
134 | if accepts_generator:
135 | extra_step_kwargs["generator"] = generator
136 | return extra_step_kwargs
137 |
138 | def prepare_latents(
139 | self,
140 | batch_size,
141 | num_channels_latents,
142 | width,
143 | height,
144 | dtype,
145 | device,
146 | generator,
147 | latents=None,
148 | ):
149 | shape = (
150 | batch_size,
151 | num_channels_latents,
152 | height // self.vae_scale_factor,
153 | width // self.vae_scale_factor,
154 | )
155 | if isinstance(generator, list) and len(generator) != batch_size:
156 | raise ValueError(
157 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
158 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
159 | )
160 |
161 | if latents is None:
162 | latents = randn_tensor(
163 | shape, generator=generator, device=device, dtype=dtype
164 | )
165 | else:
166 | latents = latents.to(device)
167 |
168 | # scale the initial noise by the standard deviation required by the scheduler
169 | latents = latents * self.scheduler.init_noise_sigma
170 | return latents
171 |
172 | def prepare_condition(
173 | self,
174 | cond_image,
175 | width,
176 | height,
177 | device,
178 | dtype,
179 | do_classififer_free_guidance=False,
180 | ):
181 | image = self.cond_image_processor.preprocess(
182 | cond_image, height=height, width=width
183 | ).to(dtype=torch.float32)
184 |
185 | image = image.to(device=device, dtype=dtype)
186 |
187 | if do_classififer_free_guidance:
188 | image = torch.cat([image] * 2)
189 |
190 | return image
191 |
192 | @torch.no_grad()
193 | def __call__(
194 | self,
195 | ref_image,
196 | pose_image,
197 | width,
198 | height,
199 | num_inference_steps,
200 | guidance_scale,
201 | num_images_per_prompt=1,
202 | eta: float = 0.0,
203 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
204 | output_type: Optional[str] = "tensor",
205 | return_dict: bool = True,
206 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
207 | callback_steps: Optional[int] = 1,
208 | **kwargs,
209 | ):
210 | # Default height and width to unet
211 | height = height or self.unet.config.sample_size * self.vae_scale_factor
212 | width = width or self.unet.config.sample_size * self.vae_scale_factor
213 |
214 | device = self._execution_device
215 |
216 | do_classifier_free_guidance = guidance_scale > 1.0
217 |
218 | # Prepare timesteps
219 | self.scheduler.set_timesteps(num_inference_steps, device=device)
220 | timesteps = self.scheduler.timesteps
221 |
222 | batch_size = 1
223 |
224 | # Prepare clip image embeds
225 | clip_image = self.clip_image_processor.preprocess(
226 | ref_image.resize((224, 224)), return_tensors="pt"
227 | ).pixel_values
228 | clip_image_embeds = self.image_encoder(
229 | clip_image.to(device, dtype=self.image_encoder.dtype)
230 | ).image_embeds
231 | image_prompt_embeds = clip_image_embeds.unsqueeze(1)
232 | uncond_image_prompt_embeds = torch.zeros_like(image_prompt_embeds)
233 |
234 | if do_classifier_free_guidance:
235 | image_prompt_embeds = torch.cat(
236 | [uncond_image_prompt_embeds, image_prompt_embeds], dim=0
237 | )
238 |
239 | reference_control_writer = ReferenceAttentionControl(
240 | self.reference_unet,
241 | do_classifier_free_guidance=do_classifier_free_guidance,
242 | mode="write",
243 | batch_size=batch_size,
244 | fusion_blocks="full",
245 | )
246 | reference_control_reader = ReferenceAttentionControl(
247 | self.denoising_unet,
248 | do_classifier_free_guidance=do_classifier_free_guidance,
249 | mode="read",
250 | batch_size=batch_size,
251 | fusion_blocks="full",
252 | )
253 |
254 | num_channels_latents = self.denoising_unet.in_channels
255 | latents = self.prepare_latents(
256 | batch_size * num_images_per_prompt,
257 | num_channels_latents,
258 | width,
259 | height,
260 | clip_image_embeds.dtype,
261 | device,
262 | generator,
263 | )
264 | latents = latents.unsqueeze(2) # (bs, c, 1, h', w')
265 | latents_dtype = latents.dtype
266 |
267 | # Prepare extra step kwargs.
268 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
269 |
270 | # Prepare ref image latents
271 | ref_image_tensor = self.ref_image_processor.preprocess(
272 | ref_image, height=height, width=width
273 | ) # (bs, c, height, width)
274 | ref_image_tensor = ref_image_tensor.to(
275 | dtype=self.vae.dtype, device=self.vae.device
276 | )
277 | ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
278 | ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w)
279 |
280 | # Prepare pose condition image
281 | pose_cond_tensor = self.cond_image_processor.preprocess(
282 | pose_image, height=height, width=width
283 | )
284 | pose_cond_tensor = pose_cond_tensor.unsqueeze(2) # (bs, c, 1, h, w)
285 | pose_cond_tensor = pose_cond_tensor.to(
286 | device=device, dtype=self.pose_guider.dtype
287 | )
288 | pose_fea = self.pose_guider(pose_cond_tensor)
289 | pose_fea = (
290 | torch.cat([pose_fea] * 2) if do_classifier_free_guidance else pose_fea
291 | )
292 |
293 | # denoising loop
294 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
295 | with self.progress_bar(total=num_inference_steps) as progress_bar:
296 | for i, t in enumerate(timesteps):
297 | # 1. Forward reference image
298 | if i == 0:
299 | self.reference_unet(
300 | ref_image_latents.repeat(
301 | (2 if do_classifier_free_guidance else 1), 1, 1, 1
302 | ),
303 | torch.zeros_like(t),
304 | encoder_hidden_states=image_prompt_embeds,
305 | return_dict=False,
306 | )
307 |
308 | # 2. Update reference UNet features into the denoising UNet
309 | reference_control_reader.update(reference_control_writer)
310 |
311 | # 3.1 expand the latents if we are doing classifier free guidance
312 | latent_model_input = (
313 | torch.cat([latents] * 2) if do_classifier_free_guidance else latents
314 | )
315 | latent_model_input = self.scheduler.scale_model_input(
316 | latent_model_input, t
317 | )
318 |
319 | noise_pred = self.denoising_unet(
320 | latent_model_input,
321 | t,
322 | encoder_hidden_states=image_prompt_embeds,
323 | pose_cond_fea=pose_fea,
324 | return_dict=False,
325 | )[0]
326 |
327 | # perform guidance
328 | if do_classifier_free_guidance:
329 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
330 | noise_pred = noise_pred_uncond + guidance_scale * (
331 | noise_pred_text - noise_pred_uncond
332 | )
333 |
334 | # compute the previous noisy sample x_t -> x_t-1
335 | latents = self.scheduler.step(
336 | noise_pred, t, latents, **extra_step_kwargs, return_dict=False
337 | )[0]
338 |
339 | # call the callback, if provided
340 | if i == len(timesteps) - 1 or (
341 | (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
342 | ):
343 | progress_bar.update()
344 | if callback is not None and i % callback_steps == 0:
345 | step_idx = i // getattr(self.scheduler, "order", 1)
346 | callback(step_idx, t, latents)
347 | reference_control_reader.clear()
348 | reference_control_writer.clear()
349 |
350 | # Post-processing
351 | image = self.decode_latents(latents) # (b, c, 1, h, w)
352 |
353 | # Convert to tensor
354 | if output_type == "tensor":
355 | image = torch.from_numpy(image)
356 |
357 | if not return_dict:
358 | return image
359 |
360 | return Pose2ImagePipelineOutput(images=image)
361 |
--------------------------------------------------------------------------------
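A minimal sketch of driving the pipeline above once its components exist. Building the VAE, image encoder, the two UNets, the pose guider and a scheduler happens elsewhere in this repo, so they are passed in here as already-constructed arguments; the image paths and sampling settings are illustrative:

from PIL import Image

from src.pipelines.pipeline_pose2img import Pose2ImagePipeline

def run_pose2img(vae, image_encoder, reference_unet, denoising_unet,
                 pose_guider, scheduler, ref_path, pose_path):
    """Hypothetical driver; all components are assumed to be built and on GPU."""
    pipe = Pose2ImagePipeline(
        vae=vae,
        image_encoder=image_encoder,
        reference_unet=reference_unet,
        denoising_unet=denoising_unet,
        pose_guider=pose_guider,
        scheduler=scheduler,
    ).to("cuda")

    ref_image = Image.open(ref_path).convert("RGB")
    pose_image = Image.open(pose_path).convert("RGB")

    # illustrative settings; the result is a (b, c, 1, h, w) tensor in [0, 1]
    result = pipe(
        ref_image,
        pose_image,
        width=512,
        height=768,
        num_inference_steps=25,
        guidance_scale=3.5,
    )
    return result.images
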
/src/pipelines/pipeline_pose2vid.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | from dataclasses import dataclass
3 | from typing import Callable, List, Optional, Union
4 |
5 | import numpy as np
6 | import torch
7 | from diffusers import DiffusionPipeline
8 | from diffusers.image_processor import VaeImageProcessor
9 | from diffusers.schedulers import (DDIMScheduler, DPMSolverMultistepScheduler,
10 | EulerAncestralDiscreteScheduler,
11 | EulerDiscreteScheduler, LMSDiscreteScheduler,
12 | PNDMScheduler)
13 | from diffusers.utils import BaseOutput, is_accelerate_available
14 | from diffusers.utils.torch_utils import randn_tensor
15 | from einops import rearrange
16 | from tqdm import tqdm
17 | from transformers import CLIPImageProcessor
18 |
19 | from src.models.mutual_self_attention import ReferenceAttentionControl
20 |
21 |
22 | @dataclass
23 | class Pose2VideoPipelineOutput(BaseOutput):
24 | videos: Union[torch.Tensor, np.ndarray]
25 |
26 |
27 | class Pose2VideoPipeline(DiffusionPipeline):
28 | _optional_components = []
29 |
30 | def __init__(
31 | self,
32 | vae,
33 | image_encoder,
34 | reference_unet,
35 | denoising_unet,
36 | pose_guider,
37 | scheduler: Union[
38 | DDIMScheduler,
39 | PNDMScheduler,
40 | LMSDiscreteScheduler,
41 | EulerDiscreteScheduler,
42 | EulerAncestralDiscreteScheduler,
43 | DPMSolverMultistepScheduler,
44 | ],
45 | image_proj_model=None,
46 | tokenizer=None,
47 | text_encoder=None,
48 | ):
49 | super().__init__()
50 |
51 | self.register_modules(
52 | vae=vae,
53 | image_encoder=image_encoder,
54 | reference_unet=reference_unet,
55 | denoising_unet=denoising_unet,
56 | pose_guider=pose_guider,
57 | scheduler=scheduler,
58 | image_proj_model=image_proj_model,
59 | tokenizer=tokenizer,
60 | text_encoder=text_encoder,
61 | )
62 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
63 | self.clip_image_processor = CLIPImageProcessor()
64 | self.ref_image_processor = VaeImageProcessor(
65 | vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
66 | )
67 | self.cond_image_processor = VaeImageProcessor(
68 | vae_scale_factor=self.vae_scale_factor,
69 | do_convert_rgb=True,
70 | do_normalize=False,
71 | )
72 |
73 | def enable_vae_slicing(self):
74 | self.vae.enable_slicing()
75 |
76 | def disable_vae_slicing(self):
77 | self.vae.disable_slicing()
78 |
79 | def enable_sequential_cpu_offload(self, gpu_id=0):
80 | if is_accelerate_available():
81 | from accelerate import cpu_offload
82 | else:
83 | raise ImportError("Please install accelerate via `pip install accelerate`")
84 |
85 | device = torch.device(f"cuda:{gpu_id}")
86 |
87 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
88 | if cpu_offloaded_model is not None:
89 | cpu_offload(cpu_offloaded_model, device)
90 |
91 | @property
92 | def _execution_device(self):
93 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
94 | return self.device
95 | for module in self.unet.modules():
96 | if (
97 | hasattr(module, "_hf_hook")
98 | and hasattr(module._hf_hook, "execution_device")
99 | and module._hf_hook.execution_device is not None
100 | ):
101 | return torch.device(module._hf_hook.execution_device)
102 | return self.device
103 |
104 | def decode_latents(self, latents):
105 | video_length = latents.shape[2]
106 | latents = 1 / 0.18215 * latents
107 | latents = rearrange(latents, "b c f h w -> (b f) c h w")
108 | # video = self.vae.decode(latents).sample
109 | video = []
110 | for frame_idx in tqdm(range(latents.shape[0])):
111 | video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
112 | video = torch.cat(video)
113 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
114 | video = (video / 2 + 0.5).clamp(0, 1)
115 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
116 | video = video.cpu().float().numpy()
117 | return video
118 |
119 | def prepare_extra_step_kwargs(self, generator, eta):
120 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
121 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
122 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
123 | # and should be between [0, 1]
124 |
125 | accepts_eta = "eta" in set(
126 | inspect.signature(self.scheduler.step).parameters.keys()
127 | )
128 | extra_step_kwargs = {}
129 | if accepts_eta:
130 | extra_step_kwargs["eta"] = eta
131 |
132 | # check if the scheduler accepts generator
133 | accepts_generator = "generator" in set(
134 | inspect.signature(self.scheduler.step).parameters.keys()
135 | )
136 | if accepts_generator:
137 | extra_step_kwargs["generator"] = generator
138 | return extra_step_kwargs
139 |
140 | def prepare_latents(
141 | self,
142 | batch_size,
143 | num_channels_latents,
144 | width,
145 | height,
146 | video_length,
147 | dtype,
148 | device,
149 | generator,
150 | latents=None,
151 | ):
152 | shape = (
153 | batch_size,
154 | num_channels_latents,
155 | video_length,
156 | height // self.vae_scale_factor,
157 | width // self.vae_scale_factor,
158 | )
159 | if isinstance(generator, list) and len(generator) != batch_size:
160 | raise ValueError(
161 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
162 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
163 | )
164 |
165 | if latents is None:
166 | latents = randn_tensor(
167 | shape, generator=generator, device=device, dtype=dtype
168 | )
169 | else:
170 | latents = latents.to(device)
171 |
172 | # scale the initial noise by the standard deviation required by the scheduler
173 | latents = latents * self.scheduler.init_noise_sigma
174 | return latents
175 |
176 | def _encode_prompt(
177 | self,
178 | prompt,
179 | device,
180 | num_videos_per_prompt,
181 | do_classifier_free_guidance,
182 | negative_prompt,
183 | ):
184 | batch_size = len(prompt) if isinstance(prompt, list) else 1
185 |
186 | text_inputs = self.tokenizer(
187 | prompt,
188 | padding="max_length",
189 | max_length=self.tokenizer.model_max_length,
190 | truncation=True,
191 | return_tensors="pt",
192 | )
193 | text_input_ids = text_inputs.input_ids
194 | untruncated_ids = self.tokenizer(
195 | prompt, padding="longest", return_tensors="pt"
196 | ).input_ids
197 |
198 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
199 | text_input_ids, untruncated_ids
200 | ):
201 | removed_text = self.tokenizer.batch_decode(
202 | untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
203 | )
204 |
205 | if (
206 | hasattr(self.text_encoder.config, "use_attention_mask")
207 | and self.text_encoder.config.use_attention_mask
208 | ):
209 | attention_mask = text_inputs.attention_mask.to(device)
210 | else:
211 | attention_mask = None
212 |
213 | text_embeddings = self.text_encoder(
214 | text_input_ids.to(device),
215 | attention_mask=attention_mask,
216 | )
217 | text_embeddings = text_embeddings[0]
218 |
219 | # duplicate text embeddings for each generation per prompt, using mps friendly method
220 | bs_embed, seq_len, _ = text_embeddings.shape
221 | text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
222 | text_embeddings = text_embeddings.view(
223 | bs_embed * num_videos_per_prompt, seq_len, -1
224 | )
225 |
226 | # get unconditional embeddings for classifier free guidance
227 | if do_classifier_free_guidance:
228 | uncond_tokens: List[str]
229 | if negative_prompt is None:
230 | uncond_tokens = [""] * batch_size
231 | elif type(prompt) is not type(negative_prompt):
232 | raise TypeError(
233 |                     f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
234 | f" {type(prompt)}."
235 | )
236 | elif isinstance(negative_prompt, str):
237 | uncond_tokens = [negative_prompt]
238 | elif batch_size != len(negative_prompt):
239 | raise ValueError(
240 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
241 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
242 | " the batch size of `prompt`."
243 | )
244 | else:
245 | uncond_tokens = negative_prompt
246 |
247 | max_length = text_input_ids.shape[-1]
248 | uncond_input = self.tokenizer(
249 | uncond_tokens,
250 | padding="max_length",
251 | max_length=max_length,
252 | truncation=True,
253 | return_tensors="pt",
254 | )
255 |
256 | if (
257 | hasattr(self.text_encoder.config, "use_attention_mask")
258 | and self.text_encoder.config.use_attention_mask
259 | ):
260 | attention_mask = uncond_input.attention_mask.to(device)
261 | else:
262 | attention_mask = None
263 |
264 | uncond_embeddings = self.text_encoder(
265 | uncond_input.input_ids.to(device),
266 | attention_mask=attention_mask,
267 | )
268 | uncond_embeddings = uncond_embeddings[0]
269 |
270 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
271 | seq_len = uncond_embeddings.shape[1]
272 | uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
273 | uncond_embeddings = uncond_embeddings.view(
274 | batch_size * num_videos_per_prompt, seq_len, -1
275 | )
276 |
277 | # For classifier free guidance, we need to do two forward passes.
278 | # Here we concatenate the unconditional and text embeddings into a single batch
279 | # to avoid doing two forward passes
280 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
281 |
282 | return text_embeddings
283 |
284 | @torch.no_grad()
285 | def __call__(
286 | self,
287 | ref_image,
288 | pose_images,
289 | width,
290 | height,
291 | video_length,
292 | num_inference_steps,
293 | guidance_scale,
294 | num_images_per_prompt=1,
295 | eta: float = 0.0,
296 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
297 | output_type: Optional[str] = "tensor",
298 | return_dict: bool = True,
299 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
300 | callback_steps: Optional[int] = 1,
301 | **kwargs,
302 | ):
303 | # Default height and width to unet
304 |         height = height or self.denoising_unet.config.sample_size * self.vae_scale_factor
305 |         width = width or self.denoising_unet.config.sample_size * self.vae_scale_factor
306 |
307 | device = self._execution_device
308 |
309 | do_classifier_free_guidance = guidance_scale > 1.0
310 |
311 | # Prepare timesteps
312 | self.scheduler.set_timesteps(num_inference_steps, device=device)
313 | timesteps = self.scheduler.timesteps
314 |
315 | batch_size = 1
316 |
317 | # Prepare clip image embeds
318 | clip_image = self.clip_image_processor.preprocess(
319 | ref_image, return_tensors="pt"
320 | ).pixel_values
321 | clip_image_embeds = self.image_encoder(
322 | clip_image.to(device, dtype=self.image_encoder.dtype)
323 | ).image_embeds
324 | encoder_hidden_states = clip_image_embeds.unsqueeze(1)
325 | uncond_encoder_hidden_states = torch.zeros_like(encoder_hidden_states)
326 |
327 | if do_classifier_free_guidance:
328 | encoder_hidden_states = torch.cat(
329 | [uncond_encoder_hidden_states, encoder_hidden_states], dim=0
330 | )
331 | reference_control_writer = ReferenceAttentionControl(
332 | self.reference_unet,
333 | do_classifier_free_guidance=do_classifier_free_guidance,
334 | mode="write",
335 | batch_size=batch_size,
336 | fusion_blocks="full",
337 | )
338 | reference_control_reader = ReferenceAttentionControl(
339 | self.denoising_unet,
340 | do_classifier_free_guidance=do_classifier_free_guidance,
341 | mode="read",
342 | batch_size=batch_size,
343 | fusion_blocks="full",
344 | )
345 |
346 | num_channels_latents = self.denoising_unet.in_channels
347 | latents = self.prepare_latents(
348 | batch_size * num_images_per_prompt,
349 | num_channels_latents,
350 | width,
351 | height,
352 | video_length,
353 | clip_image_embeds.dtype,
354 | device,
355 | generator,
356 | )
357 |
358 | # Prepare extra step kwargs.
359 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
360 |
361 | # Prepare ref image latents
362 | ref_image_tensor = self.ref_image_processor.preprocess(
363 | ref_image, height=height, width=width
364 |         ) # (bs, c, height, width)
365 | ref_image_tensor = ref_image_tensor.to(
366 | dtype=self.vae.dtype, device=self.vae.device
367 | )
368 | ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
369 | ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w)
370 |
371 | # Prepare a list of pose condition images
372 | pose_cond_tensor_list = []
373 | for pose_image in pose_images:
374 | pose_cond_tensor = (
375 | torch.from_numpy(np.array(pose_image.resize((width, height)))) / 255.0
376 | )
377 | pose_cond_tensor = pose_cond_tensor.permute(2, 0, 1).unsqueeze(
378 | 1
379 | ) # (c, 1, h, w)
380 | pose_cond_tensor_list.append(pose_cond_tensor)
381 | pose_cond_tensor = torch.cat(pose_cond_tensor_list, dim=1) # (c, t, h, w)
382 | pose_cond_tensor = pose_cond_tensor.unsqueeze(0)
383 | pose_cond_tensor = pose_cond_tensor.to(
384 | device=device, dtype=self.pose_guider.dtype
385 | )
386 | pose_fea = self.pose_guider(pose_cond_tensor)
387 | pose_fea = (
388 | torch.cat([pose_fea] * 2) if do_classifier_free_guidance else pose_fea
389 | )
390 |
391 | # denoising loop
392 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
393 | with self.progress_bar(total=num_inference_steps) as progress_bar:
394 | for i, t in enumerate(timesteps):
395 | # 1. Forward reference image
396 | if i == 0:
397 | self.reference_unet(
398 | ref_image_latents.repeat(
399 | (2 if do_classifier_free_guidance else 1), 1, 1, 1
400 | ),
401 | torch.zeros_like(t),
402 | # t,
403 | encoder_hidden_states=encoder_hidden_states,
404 | return_dict=False,
405 | )
406 | reference_control_reader.update(reference_control_writer)
407 |
408 | # 3.1 expand the latents if we are doing classifier free guidance
409 | latent_model_input = (
410 | torch.cat([latents] * 2) if do_classifier_free_guidance else latents
411 | )
412 | latent_model_input = self.scheduler.scale_model_input(
413 | latent_model_input, t
414 | )
415 |
416 | noise_pred = self.denoising_unet(
417 | latent_model_input,
418 | t,
419 | encoder_hidden_states=encoder_hidden_states,
420 | pose_cond_fea=pose_fea,
421 | return_dict=False,
422 | )[0]
423 |
424 | # perform guidance
425 | if do_classifier_free_guidance:
426 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
427 | noise_pred = noise_pred_uncond + guidance_scale * (
428 | noise_pred_text - noise_pred_uncond
429 | )
430 |
431 | # compute the previous noisy sample x_t -> x_t-1
432 | latents = self.scheduler.step(
433 | noise_pred, t, latents, **extra_step_kwargs, return_dict=False
434 | )[0]
435 |
436 | # call the callback, if provided
437 | if i == len(timesteps) - 1 or (
438 | (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
439 | ):
440 | progress_bar.update()
441 | if callback is not None and i % callback_steps == 0:
442 | step_idx = i // getattr(self.scheduler, "order", 1)
443 | callback(step_idx, t, latents)
444 |
445 | reference_control_reader.clear()
446 | reference_control_writer.clear()
447 |
448 | # Post-processing
449 | images = self.decode_latents(latents) # (b, c, f, h, w)
450 |
451 | # Convert to tensor
452 | if output_type == "tensor":
453 | images = torch.from_numpy(images)
454 |
455 | if not return_dict:
456 | return images
457 |
458 | return Pose2VideoPipelineOutput(videos=images)
459 |
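Usage sketch (illustrative): the interface above can be exercised roughly as follows, assuming the project-specific modules (`reference_unet`, `denoising_unet`, `pose_guider`) have already been constructed and loaded from `pretrained_weights/` (e.g. the way `scripts/pose2vid.py` is expected to do). The scheduler settings, paths, and variable names here are assumptions, not prescribed values.

# Hypothetical invocation sketch (not the project's official entry point). The three
# project-specific modules -- reference_unet, denoising_unet, pose_guider -- are assumed
# to be constructed and loaded from pretrained_weights/ elsewhere.
import torch
from diffusers import AutoencoderKL, DDIMScheduler
from PIL import Image
from transformers import CLIPVisionModelWithProjection

from src.pipelines.pipeline_pose2vid import Pose2VideoPipeline
from src.utils.util import read_frames, save_videos_grid

weight_dtype = torch.float16
vae = AutoencoderKL.from_pretrained("./pretrained_weights/sd-vae-ft-mse")
image_enc = CLIPVisionModelWithProjection.from_pretrained("./pretrained_weights/image_encoder")
scheduler = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    steps_offset=1,
)

pipe = Pose2VideoPipeline(
    vae=vae,
    image_encoder=image_enc,
    reference_unet=reference_unet,  # assumed: loaded 2D reference UNet
    denoising_unet=denoising_unet,  # assumed: loaded 3D denoising UNet with motion module
    pose_guider=pose_guider,        # assumed: loaded PoseGuider
    scheduler=scheduler,
).to("cuda", weight_dtype)

ref_image = Image.open("./configs/inference/ref_images/anyone-1.png").convert("RGB")
pose_frames = read_frames("./configs/inference/pose_videos/anyone-video-1_kps.mp4")[:24]

video = pipe(
    ref_image,
    pose_frames,
    width=512,
    height=768,
    video_length=len(pose_frames),
    num_inference_steps=25,
    guidance_scale=3.5,
    generator=torch.manual_seed(42),
).videos  # (b, c, f, h, w) tensor in [0, 1]

save_videos_grid(video, "./output/demo.mp4", n_rows=1, fps=8)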
--------------------------------------------------------------------------------
/src/pipelines/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | tensor_interpolation = None
4 |
5 |
6 | def get_tensor_interpolation_method():
7 | return tensor_interpolation
8 |
9 |
10 | def set_tensor_interpolation_method(is_slerp):
11 | global tensor_interpolation
12 | tensor_interpolation = slerp if is_slerp else linear
13 |
14 |
15 | def linear(v1, v2, t):
16 | return (1.0 - t) * v1 + t * v2
17 |
18 |
19 | def slerp(
20 | v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995
21 | ) -> torch.Tensor:
22 | u0 = v0 / v0.norm()
23 | u1 = v1 / v1.norm()
24 | dot = (u0 * u1).sum()
25 | if dot.abs() > DOT_THRESHOLD:
26 | # logger.info(f'warning: v0 and v1 close to parallel, using linear interpolation instead.')
27 | return (1.0 - t) * v0 + t * v1
28 | omega = dot.acos()
29 | return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin()
30 |
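A brief, illustrative note on how these helpers are meant to be used: callers are expected to select the method once via `set_tensor_interpolation_method` and then fetch the chosen function with `get_tensor_interpolation_method`.

# Minimal sketch of the interpolation helpers (illustrative only).
import torch
from src.pipelines.utils import (
    get_tensor_interpolation_method,
    set_tensor_interpolation_method,
)

set_tensor_interpolation_method(is_slerp=True)  # False selects plain linear interpolation
interp = get_tensor_interpolation_method()

v0, v1 = torch.randn(4), torch.randn(4)
midpoint = interp(v0, v1, 0.5)  # interpolated tensor halfway between v0 and v1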
--------------------------------------------------------------------------------
/src/utils/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 | import os.path as osp
4 | import shutil
5 | import sys
6 | from pathlib import Path
7 |
8 | import av
9 | import numpy as np
10 | import torch
11 | import torchvision
12 | from einops import rearrange
13 | from PIL import Image
14 |
15 |
16 | def seed_everything(seed):
17 | import random
18 |
19 | import numpy as np
20 |
21 | torch.manual_seed(seed)
22 | torch.cuda.manual_seed_all(seed)
23 | np.random.seed(seed % (2**32))
24 | random.seed(seed)
25 |
26 |
27 | def import_filename(filename):
28 | spec = importlib.util.spec_from_file_location("mymodule", filename)
29 | module = importlib.util.module_from_spec(spec)
30 | sys.modules[spec.name] = module
31 | spec.loader.exec_module(module)
32 | return module
33 |
34 |
35 | def delete_additional_ckpt(base_path, num_keep):
36 | dirs = []
37 | for d in os.listdir(base_path):
38 | if d.startswith("checkpoint-"):
39 | dirs.append(d)
40 | num_tot = len(dirs)
41 | if num_tot <= num_keep:
42 | return
43 |     # ensure checkpoints are sorted and delete the earliest ones
44 | del_dirs = sorted(dirs, key=lambda x: int(x.split("-")[-1]))[: num_tot - num_keep]
45 | for d in del_dirs:
46 | path_to_dir = osp.join(base_path, d)
47 | if osp.exists(path_to_dir):
48 | shutil.rmtree(path_to_dir)
49 |
50 |
51 | def save_videos_from_pil(pil_images, path, fps=8):
52 | import av
53 |
54 | save_fmt = Path(path).suffix
55 | os.makedirs(os.path.dirname(path), exist_ok=True)
56 | width, height = pil_images[0].size
57 |
58 | if save_fmt == ".mp4":
59 | codec = "libx264"
60 | container = av.open(path, "w")
61 | stream = container.add_stream(codec, rate=fps)
62 |
63 | stream.width = width
64 | stream.height = height
65 |
66 | for pil_image in pil_images:
67 | # pil_image = Image.fromarray(image_arr).convert("RGB")
68 | av_frame = av.VideoFrame.from_image(pil_image)
69 | container.mux(stream.encode(av_frame))
70 | container.mux(stream.encode())
71 | container.close()
72 |
73 | elif save_fmt == ".gif":
74 | pil_images[0].save(
75 | fp=path,
76 | format="GIF",
77 | append_images=pil_images[1:],
78 | save_all=True,
79 | duration=(1 / fps * 1000),
80 | loop=0,
81 | )
82 | else:
83 | raise ValueError("Unsupported file type. Use .mp4 or .gif.")
84 |
85 |
86 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
87 | videos = rearrange(videos, "b c t h w -> t b c h w")
88 | height, width = videos.shape[-2:]
89 | outputs = []
90 |
91 | for x in videos:
92 | x = torchvision.utils.make_grid(x, nrow=n_rows) # (c h w)
93 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) # (h w c)
94 | if rescale:
95 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1
96 | x = (x * 255).numpy().astype(np.uint8)
97 | x = Image.fromarray(x)
98 |
99 | outputs.append(x)
100 |
101 | os.makedirs(os.path.dirname(path), exist_ok=True)
102 |
103 | save_videos_from_pil(outputs, path, fps)
104 |
105 |
106 | def read_frames(video_path):
107 | container = av.open(video_path)
108 |
109 | video_stream = next(s for s in container.streams if s.type == "video")
110 | frames = []
111 | for packet in container.demux(video_stream):
112 | for frame in packet.decode():
113 | image = Image.frombytes(
114 | "RGB",
115 | (frame.width, frame.height),
116 | frame.to_rgb().to_ndarray(),
117 | )
118 | frames.append(image)
119 |
120 | return frames
121 |
122 |
123 | def get_fps(video_path):
124 | container = av.open(video_path)
125 | video_stream = next(s for s in container.streams if s.type == "video")
126 | fps = video_stream.average_rate
127 | container.close()
128 | return fps
129 |
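Illustrative example of how the video I/O helpers above compose; the input path is one of the sample pose videos shipped with the repo and the output path is a placeholder.

# Read a video into PIL frames, keep its frame rate, and write it back out.
from src.utils.util import get_fps, read_frames, save_videos_from_pil

src_path = "./configs/inference/pose_videos/anyone-video-1_kps.mp4"
frames = read_frames(src_path)  # list of PIL.Image frames
fps = get_fps(src_path)         # average frame rate as reported by PyAV
save_videos_from_pil(frames, "./output/roundtrip.mp4", fps=fps)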
--------------------------------------------------------------------------------
/tools/download_weights.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path, PurePosixPath
3 |
4 | from huggingface_hub import hf_hub_download
5 |
6 |
7 | def prepare_base_model():
8 |     print("Preparing base stable-diffusion-v1-5 weights...")
9 | local_dir = "./pretrained_weights/stable-diffusion-v1-5"
10 | os.makedirs(local_dir, exist_ok=True)
11 | for hub_file in ["unet/config.json", "unet/diffusion_pytorch_model.bin"]:
12 | path = Path(hub_file)
13 | saved_path = local_dir / path
14 | if os.path.exists(saved_path):
15 | continue
16 | hf_hub_download(
17 | repo_id="runwayml/stable-diffusion-v1-5",
18 | subfolder=PurePosixPath(path.parent),
19 | filename=PurePosixPath(path.name),
20 | local_dir=local_dir,
21 | )
22 |
23 |
24 | def prepare_image_encoder():
25 | print(f"Preparing image encoder weights...")
26 | local_dir = "./pretrained_weights"
27 | os.makedirs(local_dir, exist_ok=True)
28 | for hub_file in ["image_encoder/config.json", "image_encoder/pytorch_model.bin"]:
29 | path = Path(hub_file)
30 | saved_path = local_dir / path
31 | if os.path.exists(saved_path):
32 | continue
33 | hf_hub_download(
34 | repo_id="lambdalabs/sd-image-variations-diffusers",
35 | subfolder=PurePosixPath(path.parent),
36 | filename=PurePosixPath(path.name),
37 | local_dir=local_dir,
38 | )
39 |
40 |
41 | def prepare_dwpose():
42 | print(f"Preparing DWPose weights...")
43 | local_dir = "./pretrained_weights/DWPose"
44 | os.makedirs(local_dir, exist_ok=True)
45 | for hub_file in [
46 | "dw-ll_ucoco_384.onnx",
47 | "yolox_l.onnx",
48 | ]:
49 | path = Path(hub_file)
50 | saved_path = local_dir / path
51 | if os.path.exists(saved_path):
52 | continue
53 |
54 | hf_hub_download(
55 | repo_id="yzd-v/DWPose",
56 | subfolder=PurePosixPath(path.parent),
57 | filename=PurePosixPath(path.name),
58 | local_dir=local_dir,
59 | )
60 |
61 |
62 | def prepare_vae():
63 | print(f"Preparing vae weights...")
64 | local_dir = "./pretrained_weights/sd-vae-ft-mse"
65 | os.makedirs(local_dir, exist_ok=True)
66 | for hub_file in [
67 | "config.json",
68 | "diffusion_pytorch_model.bin",
69 | ]:
70 | path = Path(hub_file)
71 | saved_path = local_dir / path
72 | if os.path.exists(saved_path):
73 | continue
74 |
75 | hf_hub_download(
76 | repo_id="stabilityai/sd-vae-ft-mse",
77 | subfolder=PurePosixPath(path.parent),
78 | filename=PurePosixPath(path.name),
79 | local_dir=local_dir,
80 | )
81 |
82 |
83 | def prepare_anyone():
84 | print(f"Preparing AnimateAnyone weights...")
85 | local_dir = "./pretrained_weights"
86 | os.makedirs(local_dir, exist_ok=True)
87 | for hub_file in [
88 | "denoising_unet.pth",
89 | "motion_module.pth",
90 | "pose_guider.pth",
91 | "reference_unet.pth",
92 | ]:
93 | path = Path(hub_file)
94 | saved_path = local_dir / path
95 | if os.path.exists(saved_path):
96 | continue
97 |
98 | hf_hub_download(
99 | repo_id="patrolli/AnimateAnyone",
100 | subfolder=PurePosixPath(path.parent),
101 | filename=PurePosixPath(path.name),
102 | local_dir=local_dir,
103 | )
104 |
105 | if __name__ == '__main__':
106 | prepare_base_model()
107 | prepare_image_encoder()
108 | prepare_dwpose()
109 | prepare_vae()
110 | prepare_anyone()
111 |
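The prepare functions can also be imported and called selectively when only part of the weights is needed; a sketch, assumed to be run from the repository root.

# Example: fetch only the DWPose detector and VAE weights instead of everything.
from tools.download_weights import prepare_dwpose, prepare_vae

prepare_dwpose()
prepare_vae()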
--------------------------------------------------------------------------------
/tools/extract_dwpose_from_vid.py:
--------------------------------------------------------------------------------
1 | import concurrent.futures
2 | import os
3 | import random
4 | from pathlib import Path
5 |
6 | import numpy as np
7 |
8 | from src.dwpose import DWposeDetector
9 | from src.utils.util import get_fps, read_frames, save_videos_from_pil
10 |
11 | # Extract dwpose mp4 videos from raw videos
12 | # /path/to/video_dataset/*/*.mp4 -> /path/to/video_dataset_dwpose/*/*.mp4
13 |
14 |
15 | def process_single_video(video_path, detector, root_dir, save_dir):
16 | relative_path = os.path.relpath(video_path, root_dir)
17 | print(relative_path, video_path, root_dir)
18 | out_path = os.path.join(save_dir, relative_path)
19 | if os.path.exists(out_path):
20 | return
21 |
22 | output_dir = Path(os.path.dirname(os.path.join(save_dir, relative_path)))
23 | if not output_dir.exists():
24 | output_dir.mkdir(parents=True, exist_ok=True)
25 |
26 | fps = get_fps(video_path)
27 | frames = read_frames(video_path)
28 | kps_results = []
29 | for i, frame_pil in enumerate(frames):
30 | result, score = detector(frame_pil)
31 | score = np.mean(score, axis=-1)
32 |
33 | kps_results.append(result)
34 |
35 | save_videos_from_pil(kps_results, out_path, fps=fps)
36 |
37 |
38 | def process_batch_videos(video_list, detector, root_dir, save_dir):
39 | for i, video_path in enumerate(video_list):
40 | print(f"Process {i}/{len(video_list)} video")
41 | process_single_video(video_path, detector, root_dir, save_dir)
42 |
43 |
44 | if __name__ == "__main__":
45 | # -----
46 | # NOTE:
47 | # python tools/extract_dwpose_from_vid.py --video_root /path/to/video_dir
48 | # -----
49 | import argparse
50 |
51 | parser = argparse.ArgumentParser()
52 | parser.add_argument("--video_root", type=str)
53 | parser.add_argument(
54 | "--save_dir", type=str, help="Path to save extracted pose videos"
55 | )
56 | parser.add_argument("-j", type=int, default=4, help="Num workers")
57 | args = parser.parse_args()
58 | num_workers = args.j
59 | if args.save_dir is None:
60 | save_dir = args.video_root + "_dwpose"
61 | else:
62 | save_dir = args.save_dir
63 | if not os.path.exists(save_dir):
64 | os.makedirs(save_dir)
65 | cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
66 |     gpu_ids = list(range(len(cuda_visible_devices.split(","))))
67 |     print(f"available gpu ids: {gpu_ids}")
68 |
69 | # collect all video_folder paths
70 | video_mp4_paths = set()
71 | for root, dirs, files in os.walk(args.video_root):
72 | for name in files:
73 | if name.endswith(".mp4"):
74 | video_mp4_paths.add(os.path.join(root, name))
75 | video_mp4_paths = list(video_mp4_paths)
76 | random.shuffle(video_mp4_paths)
77 |
78 |     # split videos into per-worker chunks
79 | batch_size = (len(video_mp4_paths) + num_workers - 1) // num_workers
80 | print(f"Num videos: {len(video_mp4_paths)} {batch_size = }")
81 | video_chunks = [
82 | video_mp4_paths[i : i + batch_size]
83 | for i in range(0, len(video_mp4_paths), batch_size)
84 | ]
85 |
86 | with concurrent.futures.ThreadPoolExecutor() as executor:
87 | futures = []
88 | for i, chunk in enumerate(video_chunks):
89 | # init detector
90 | gpu_id = gpu_ids[i % len(gpu_ids)]
91 | detector = DWposeDetector()
92 | # torch.cuda.set_device(gpu_id)
93 | detector = detector.to(f"cuda:{gpu_id}")
94 |
95 | futures.append(
96 | executor.submit(
97 | process_batch_videos, chunk, detector, args.video_root, save_dir
98 | )
99 | )
100 | for future in concurrent.futures.as_completed(futures):
101 | future.result()
102 |
--------------------------------------------------------------------------------
/tools/extract_meta_info.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 |
5 | # -----
6 | # [{'vid': , 'kps': , 'other':},
7 | # {'vid': , 'kps': , 'other':}]
8 | # -----
9 | # python tools/extract_meta_info.py --root_path /path/to/video_dir --dataset_name fashion
10 | # -----
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument("--root_path", type=str)
13 | parser.add_argument("--dataset_name", type=str)
14 | parser.add_argument("--meta_info_name", type=str)
15 |
16 | args = parser.parse_args()
17 |
18 | if args.meta_info_name is None:
19 | args.meta_info_name = args.dataset_name
20 |
21 | pose_dir = args.root_path + "_dwpose"
22 |
23 | # collect all video_folder paths
24 | video_mp4_paths = set()
25 | for root, dirs, files in os.walk(args.root_path):
26 | for name in files:
27 | if name.endswith(".mp4"):
28 | video_mp4_paths.add(os.path.join(root, name))
29 | video_mp4_paths = list(video_mp4_paths)
30 |
31 | meta_infos = []
32 | for video_mp4_path in video_mp4_paths:
33 | relative_video_name = os.path.relpath(video_mp4_path, args.root_path)
34 | kps_path = os.path.join(pose_dir, relative_video_name)
35 | meta_infos.append({"video_path": video_mp4_path, "kps_path": kps_path})
36 |
37 | os.makedirs("./data", exist_ok=True)
38 | with open(f"./data/{args.meta_info_name}_meta.json", "w") as fp:
39 |     json.dump(meta_infos, fp)
40 |
--------------------------------------------------------------------------------
/tools/facetracker_api.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import os, sys
3 | import math
4 | import numpy as np
5 | import cv2
6 | sys.path.append("OpenSeeFace/")
7 | from tracker import Tracker, get_model_base_path
8 |
9 | features = ["eye_l", "eye_r", "eyebrow_steepness_l", "eyebrow_updown_l", "eyebrow_quirk_l", "eyebrow_steepness_r", "eyebrow_updown_r", "eyebrow_quirk_r", "mouth_corner_updown_l", "mouth_corner_inout_l", "mouth_corner_updown_r", "mouth_corner_inout_r", "mouth_open", "mouth_wide"]
10 |
11 |
12 | def face_image(frame, save_path=None):
13 | height, width, c = frame.shape
14 | tracker = Tracker(width, height, threshold=None, max_threads=1, max_faces=1, discard_after=10, scan_every=3, silent=False, model_type=3, model_dir=None,
15 | no_gaze=False, detection_threshold=0.4, use_retinaface=0, max_feature_updates=900, static_model=True, try_hard=False)
16 | faces = tracker.predict(frame)
17 | frame = np.zeros_like(frame)
18 | detected = False
19 | face_lms = None
20 | for face_num, f in enumerate(faces):
21 | f = copy.copy(f)
22 | if f.eye_blink is None:
23 | f.eye_blink = [1, 1]
24 | right_state = "O" if f.eye_blink[0] > 0.30 else "-"
25 | left_state = "O" if f.eye_blink[1] > 0.30 else "-"
26 | detected = True
27 | if not f.success:
28 | pts_3d = np.zeros((70, 3), np.float32)
29 | if face_num == 0:
30 | face_lms = f.lms
31 | for pt_num, (x,y,c) in enumerate(f.lms):
32 | if pt_num == 66 and (f.eye_blink[0] < 0.30 or c < 0.20):
33 | continue
34 | if pt_num == 67 and (f.eye_blink[1] < 0.30 or c < 0.20):
35 | continue
36 | x = int(x + 0.5)
37 | y = int(y + 0.5)
38 |
39 | color = (0, 255, 0)
40 | if pt_num >= 66:
41 | color = (255, 255, 0)
42 | if not (x < 0 or y < 0 or x >= height or y >= width):
43 | cv2.circle(frame, (y, x), 1, color, -1)
44 | if f.rotation is not None:
45 | projected = cv2.projectPoints(f.contour, f.rotation, f.translation, tracker.camera, tracker.dist_coeffs)
46 | for [(x,y)] in projected[0]:
47 | x = int(x + 0.5)
48 | y = int(y + 0.5)
49 | if not (x < 0 or y < 0 or x >= height or y >= width):
50 | frame[int(x), int(y)] = (0, 255, 255)
51 | x += 1
52 | if not (x < 0 or y < 0 or x >= height or y >= width):
53 | frame[int(x), int(y)] = (0, 255, 255)
54 | y += 1
55 | if not (x < 0 or y < 0 or x >= height or y >= width):
56 | frame[int(x), int(y)] = (0, 255, 255)
57 | x -= 1
58 | if not (x < 0 or y < 0 or x >= height or y >= width):
59 | frame[int(x), int(y)] = (0, 255, 255)
60 | if save_path is not None:
61 | cv2.imwrite(save_path, frame)
62 | return frame, face_lms
63 |
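Hypothetical single-image call of `face_image` (requires an OpenSeeFace checkout reachable from the working directory, since the module appends `OpenSeeFace/` to `sys.path` at import time); the image path is a placeholder taken from the sample assets.

# Run the landmark tracker on one image and get the rendered landmark frame back.
import cv2
from tools.facetracker_api import face_image

frame = cv2.imread("./configs/inference/talkinghead_images/1.png")  # BGR ndarray
lms_frame, face_lms = face_image(frame)  # rendered landmark image and raw landmark list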
--------------------------------------------------------------------------------
/tools/vid2pose.py:
--------------------------------------------------------------------------------
1 | from src.dwpose import DWposeDetector
2 | import os
3 | from pathlib import Path
4 |
5 | from src.utils.util import get_fps, read_frames, save_videos_from_pil
6 | import numpy as np
7 |
8 |
9 | if __name__ == "__main__":
10 | import argparse
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument("--video_path", type=str)
14 | args = parser.parse_args()
15 |
16 | if not os.path.exists(args.video_path):
17 | raise ValueError(f"Path: {args.video_path} not exists")
18 |
19 | dir_path, video_name = (
20 | os.path.dirname(args.video_path),
21 | os.path.splitext(os.path.basename(args.video_path))[0],
22 | )
23 | out_path = os.path.join(dir_path, video_name + "_kps.mp4")
24 |
25 | detector = DWposeDetector()
26 | detector = detector.to(f"cuda")
27 |
28 | fps = get_fps(args.video_path)
29 | frames = read_frames(args.video_path)
30 | kps_results = []
31 | for i, frame_pil in enumerate(frames):
32 | result, score = detector(frame_pil)
33 | score = np.mean(score, axis=-1)
34 |
35 | kps_results.append(result)
36 |
37 | print(out_path)
38 | save_videos_from_pil(kps_results, out_path, fps=fps)
39 |
--------------------------------------------------------------------------------